# demo_system_algorithm/test_src/demo_resnet18.py
# (Python, 64 lines, 1.9 KiB)

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from utils.quantize import Dtype
from utils.common import load_image_cv, load_image_pil, init_service
def infer(img_data, web_service):
    """Run on-device inference and return the raw classifier output.

    Args:
        img_data: Input forwarded unchanged to ``web_service.infer``. The
            ``__main__`` block passes a list of images -- TODO confirm the
            exact type the service accepts.
        web_service: An instantiated inference service exposing ``infer``.

    Returns:
        torch.Tensor: float32 tensor of shape (1, 1000) decoded from the
        ``'output_0.dat'`` buffer of the service response (raw scores; no
        softmax is applied here).
    """
    raw_output = web_service.infer(img_data)['output_0.dat']
    # flatten() yields a fresh, writable copy, so torch.from_numpy is safe.
    scores = np.frombuffer(raw_output, dtype=np.float32).flatten().reshape(1, 1000)
    return torch.from_numpy(scores)
def post_process(rsp, image_src, top_k=5, out_size=(640, 640)):
    """Decode classifier output and draw the top-k predictions on an image.

    Args:
        rsp (dict): Inference response; ``rsp['output_0.dat']`` holds the raw
            float32 scores, reshaped here to (1, 1000).
        image_src (np.ndarray): Image to annotate. It is resized with
            ``cv2.resize``, which returns a new array, so the caller's image
            is not modified.
        top_k (int): Number of highest-probability classes to draw; clamped
            to the number of classes.
        out_size (tuple[int, int]): ``(width, height)`` of the returned image.

    Returns:
        np.ndarray: Resized copy of ``image_src`` with one line of text
        ``"id: <class>: <probability>"`` per prediction.
    """
    scores = np.frombuffer(rsp['output_0.dat'], dtype=np.float32).flatten().reshape(1, 1000)
    probs = F.softmax(torch.from_numpy(scores), dim=1)  # probabilities
    k = min(top_k, probs.shape[1])
    top_conf, top_labels = torch.topk(probs, k=k)

    canvas = cv2.resize(image_src, out_size)
    line_height = 30  # vertical spacing between text lines, in pixels
    # One line per predicted class, starting at y=30 and moving down.
    for row, (conf, cls_id) in enumerate(
            zip(top_conf[0].numpy(), top_labels[0].numpy()), start=1):
        text = f"id: {cls_id}: {conf:.4f}"
        # (255, 0, 0) is blue on a BGR image and red on an RGB image.
        cv2.putText(canvas, text, (20, row * line_height),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    return canvas
if __name__ == "__main__":
    model_dir = './models/resnet18'
    image_path = './models/resnet18/1.jpg'

    # load_image_cv's result is treated as BGR (it is converted with
    # COLOR_BGR2RGB below); keep it untouched for visualisation.
    image_bgr = load_image_cv(image_path)
    if image_bgr is None:  # NOTE(review): guard in case the loader returns None on failure
        raise SystemExit(f"failed to load image: {image_path}")

    # The network input is a separate 224x224 RGB copy.
    model_input = cv2.resize(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), (224, 224))

    web_service = init_service(model_dir, Dtype.U8)
    rsp = web_service.infer([model_input])

    # Annotate the original BGR image: cv2.imwrite expects BGR, so drawing on
    # the RGB model input would save a file with red/blue channels swapped
    # (and upscaled from 224 px).
    result_image = post_process(rsp, image_bgr)
    cv2.imwrite('./models/resnet18/result.jpg', result_image)