使用Pycoral库和GoogleCoralUSB加速器进行二进制图像分类
我有一个使用 sigmoid 函数进行二元分类的 Keras 模型。我编译了我的模型以.tflite按照 Coral USB 的要求进行格式化以运行推理。但是,我注意到该脚本classify_image.py执行多类分类。因此,当我尝试对图像进行分类时,我对任何图像都得到了 100% 的预测。例如,我的模型将红外图像分类为发烧状态。即使我传球图像,它也会为发烧级提供 100% 阳性。
因此,我再次使用完全自定义的模型对植物使用 layer softmax 测试了多类模型,这一次它起作用了。它为植物 A、植物 B 和植物 C 提供了合理的 85% 准确度。
因此,我想知道我需要做哪些更改,才能使用二进制分类自定义模型与 Pycoral 配合使用。
这是我用于分类的代码:
import argparse
import time
from PIL import Image
from pycoral.adapters import classify
from pycoral.adapters import common
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
import cv2 as cv
import numpy as np
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-m', '--model', required=True,
help='File path of .tflite file.')
parser.add_argument('-i', '--input', required=True,
help='Image to be classified.')
parser.add_argument('-l', '--labels',
help='File path of labels file.')
parser.add_argument('-k', '--top_k', type=int, default=2,
help='Max number of classification results')
parser.add_argument('-t', '--threshold', type=float, default=0.0,
help='Classification score threshold')
parser.add_argument('-c', '--count', type=int, default=5,
help='Number of times to run inference')
args = parser.parse_args()
labels = read_label_file(args.labels) if args.labels else {}
interpreter = make_interpreter(*args.model.split('@'))
interpreter.allocate_tensors()
print(interpreter)
size = common.input_size(interpreter)
image = cv.imread(args.input)
image = cv.normalize(image, image, 0, 255, cv.NORM_MINMAX)
common.set_input(interpreter, image)
print('----INFERENCE TIME----')
print('Note: The first inference on Edge TPU is slow because it includes',
'loading the model into Edge TPU memory.')
for _ in range(args.count):
start = time.perf_counter()
interpreter.invoke()
inference_time = time.perf_counter() - start
classes = classify.get_classes(interpreter, args.top_k, args.threshold)
print('%.1fms' % (inference_time * 1000))
print('-------RESULTS--------')
for c in classes:
print('%s: %.5f' % (labels.get(c.id, c.id), c.score))
if __name__ == '__main__':
main()
我的labels.txt只有发烧(正类)和健康(负类)两个标签。用于二元模型分类的阈值为 0.50,据我所知,模型层与 Coral USB Accelerator 设备完全兼容。
THE END
二维码