64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from utils.quantize import Dtype
|
|
from utils.common import load_image_cv, load_image_pil, init_service
|
|
|
|
|
|
def infer(img_data, web_service, num_classes=1000):
    """Run inference through the device-side service and return the raw output.

    Args:
        img_data: Input forwarded unchanged to ``web_service.infer``. The
            original docstring described it as an ``np.array``, while the
            script entry point passes a list of images — confirm which form
            the service expects.
        web_service: Initialised service object exposing ``infer(img_data)``.
            Its return value is a mapping whose ``'output_0.dat'`` entry holds
            the model output as raw float32 bytes.
        num_classes: Number of float32 values expected in ``'output_0.dat'``.
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        torch.Tensor: float32 tensor of shape ``(1, num_classes)`` holding the
        raw model output. No softmax is applied here.

    Raises:
        ValueError: If the output buffer does not decode to exactly
            ``num_classes`` float32 values.
    """
    rsp = web_service.infer(img_data)

    # np.frombuffer returns a read-only view of the response bytes. The copy
    # gives torch.from_numpy a writable array that the tensor can own.
    logits = np.frombuffer(rsp['output_0.dat'], dtype=np.float32).copy()

    # reshape raises ValueError if the buffer size does not match num_classes.
    return torch.from_numpy(logits.reshape(1, num_classes))
|
|
|
|
def post_process(rsp, image_src, num_classes=1000, top_k=5, output_size=(640, 640)):
    """Draw the top-k classification results onto a resized copy of an image.

    Args:
        rsp: Response mapping returned by ``web_service.infer``. The entry
            ``rsp['output_0.dat']`` must hold ``num_classes`` float32 values
            as raw bytes. (The original docstring called this a list of
            prediction boxes, which does not match how it is used.)
        image_src: Image to annotate, as accepted by ``cv2.resize``. It is not
            modified, because ``cv2.resize`` returns a new array that is drawn on.
        num_classes: Number of values expected in the output buffer
            (default 1000, the previously hard-coded value).
        top_k: Number of highest-probability classes to draw (default 5).
            It is clamped to ``num_classes``.
        output_size: ``(width, height)`` passed to ``cv2.resize``
            (default ``(640, 640)``).

    Returns:
        The resized image with one ``"id: <class>: <probability>"`` text line
        per reported class.

    Raises:
        ValueError: If the output buffer does not decode to exactly
            ``num_classes`` float32 values.
    """
    # Copy out of the read-only frombuffer view before handing it to torch.
    logits = np.frombuffer(rsp['output_0.dat'], dtype=np.float32).copy()
    probs = F.softmax(torch.from_numpy(logits.reshape(1, num_classes)), dim=1)

    # torch.topk raises if k exceeds the size of the dimension, so clamp it.
    k = min(top_k, probs.size(1))
    top_conf, top_labels = torch.topk(probs, k=k)

    image_out = cv2.resize(image_src, output_size)
    y_offset = 30  # y coordinate of the first text line, in pixels
    for conf, cls_id in zip(top_conf[0].numpy(), top_labels[0].numpy()):
        text = f"id: {cls_id}: {conf:.4f}"
        # (255, 0, 0) is blue in OpenCV's BGR channel order.
        cv2.putText(image_out, text, (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                    (255, 0, 0), 2)
        y_offset += 30  # move each following line 30 px down
    return image_out
|
|
|
|
def main():
    """Classify one sample image with ResNet-18 and save the annotated result."""
    image_path = './models/resnet18/1.jpg'

    # The loaded image is treated as BGR (it is converted with COLOR_BGR2RGB
    # below), which is also the channel order cv2.imwrite expects.
    image_bgr = load_image_cv(image_path)

    # Build the model input as a separate 224x224 RGB copy. The BGR original
    # is kept for drawing and saving. Previously the RGB/224px array was
    # written out, which swapped red and blue and upscaled a 224px image.
    model_input = cv2.resize(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), (224, 224))
    image_list = [model_input]

    web_service = init_service('./models/resnet18', Dtype.U8)
    rsp = web_service.infer(image_list)

    result_image = post_process(rsp, image_bgr)
    cv2.imwrite('./models/resnet18/result.jpg', result_image)


if __name__ == "__main__":
    main()
|