# demo_system_algorithm/test_src/demo_deeplabv3.py
#
# Source-viewer metadata: 92 lines, 2.6 KiB, Python

import cv2
import numpy as np
from PIL import Image
import copy
import torch
import torch.nn.functional as F
from utils.quantize import Dtype
from utils.common import load_image_cv, load_image_pil, init_service
# Per-class RGB palette used to paint the segmentation mask: entry i is the
# colour for class index i. The entries follow the PASCAL VOC colour map. The
# last entry previously read (128, 64, 12), a typo for the VOC value
# (128, 64, 128). The demo's post-processing reshapes the output to 21 classes,
# so index 21 is normally unused.
colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
          (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
          (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
          (128, 64, 128)]
def resize_image(image, size):
    """Letterbox ``image`` into a canvas of ``size``, preserving its aspect ratio.

    The image is scaled (bicubic) by the largest factor that still fits inside
    ``size`` and pasted, centred, onto a grey (128, 128, 128) RGB canvas.

    Args:
        image: input PIL image.
        size: (width, height) of the output canvas.

    Returns:
        A tuple ``(canvas, nw, nh)``: the letterboxed RGB image plus the width
        and height of the scaled picture inside it, which a caller can use to
        crop the grey padding back off.
    """
    src_w, src_h = image.size
    target_w, target_h = size

    ratio = min(target_w / src_w, target_h / src_h)
    scaled_w = int(src_w * ratio)
    scaled_h = int(src_h * ratio)
    scaled = image.resize((scaled_w, scaled_h), Image.BICUBIC)

    # Centre the scaled image; the leftover border stays grey.
    left = (target_w - scaled_w) // 2
    top = (target_h - scaled_h) // 2
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(scaled, (left, top))
    return canvas, scaled_w, scaled_h
def post_process(rsp, nw, nh, orininal_w, orininal_h, old_img=None, alpha=0.7,
                 num_classes=21, input_size=512):
    """Convert the raw segmentation output into a colourised overlay image.

    Args:
        rsp: inference response; ``rsp['output_0.dat']`` must be a bytes-like
            buffer of float32 values laid out as
            (num_classes, input_size, input_size).
        nw (int): width of the resized image inside the letterboxed input.
        nh (int): height of the resized image inside the letterboxed input.
        orininal_w (int): width of the original image.
        orininal_h (int): height of the original image.
        old_img: original PIL image to blend the mask onto. If omitted, the
            module-level ``old_img`` (set by the ``__main__`` demo) is used for
            backward compatibility. If that does not exist either, the bare
            colour mask is returned instead of raising NameError.
        alpha (float): ``Image.blend`` weight of the mask (0.7 means the
            mask contributes 70%).
        num_classes (int): number of class channels in the output.
        input_size (int): side length of the square network input.

    Returns:
        PIL.Image.Image: the colour mask blended onto ``old_img``, or the mask
        alone when no original image is available.

    Raises:
        ValueError: if the output buffer does not contain exactly
            ``num_classes * input_size * input_size`` float32 values.
    """
    logits = np.frombuffer(rsp['output_0.dat'], dtype=np.float32)
    expected = num_classes * input_size * input_size
    if logits.size != expected:
        raise ValueError(
            f"unexpected output size {logits.size}, expected {expected} "
            f"({num_classes}x{input_size}x{input_size})")
    # np.frombuffer yields a read-only view. Copy it so torch.from_numpy does
    # not wrap non-writable memory (which emits a UserWarning).
    logits = logits.reshape(num_classes, input_size, input_size).copy()

    # Per-pixel class probabilities, channels-last: (input_size, input_size, num_classes).
    pr = F.softmax(torch.from_numpy(logits).permute(1, 2, 0), dim=-1).numpy()

    # Crop away the letterbox padding; offsets use the same centring as resize_image.
    top = int((input_size - nh) // 2)
    bottom = int((input_size - nh) // 2 + nh)
    left = int((input_size - nw) // 2)
    right = int((input_size - nw) // 2 + nw)
    pr = pr[top:bottom, left:right]

    # Interpolate probabilities back to the original resolution, then pick a class per pixel.
    pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1)

    # Fancy-index the palette with the class map -> (orininal_h, orininal_w, 3) uint8.
    seg_img = np.array(colors, np.uint8)[pr]
    image = Image.fromarray(seg_img)

    if old_img is None:
        # Backward compatibility: the original implementation read the global
        # ``old_img`` defined by the __main__ block.
        old_img = globals().get('old_img')
    if old_img is None:
        return image
    # Image.blend requires matching modes; the mask is RGB.
    return Image.blend(old_img.convert('RGB'), image, alpha)
if __name__ == "__main__":
    # Demo: letterbox one image, run it through the inference service and save
    # the segmentation overlay into the model directory.
    model_dir = './models/deeplab_mobilnet_v3'
    img_path = model_dir + '/1.jpg'
    image = load_image_pil(img_path)
    # Must stay a module-level name: post_process() blends the mask onto ``old_img``.
    old_img = image.copy()
    # PIL reports (width, height) directly; this avoids converting the whole
    # image to a numpy array twice just to read its shape.
    orininal_w, orininal_h = image.size
    image_data, nw, nh = resize_image(image, (512, 512))
    # Flatten each RGB row: (1, 512, 512 * 3). Presumably the layout the
    # inference service expects - TODO confirm against init_service.
    image_data = np.array(image_data).reshape(1, 512, 512 * 3)
    web_service = init_service(model_dir, Dtype.I16)
    rsp = web_service.infer(image_data)
    image_result = post_process(rsp, nw, nh, orininal_w, orininal_h)
    image_result.save(model_dir + '/result.jpg')
    # Prints "deeplab inference finished".
    print('deeplab推理完成')