92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
import copy
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from utils.quantize import Dtype
|
|
from utils.common import load_image_cv, load_image_pil, init_service
|
|
|
|
# RGB colour used to paint each predicted class index.  The values follow the
# standard PASCAL VOC colormap: entry i is built by spreading the bits of i
# across the R, G and B channels, most significant bit first.  The DeepLab
# output handled in this file has 21 classes (indices 0-20); entry 21 is spare.
colors = [
    (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
    (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
    # Previously (128, 64, 12): a typo for the VOC colour of index 21.
    (128, 64, 128),
]
|
|
|
|
|
|
def resize_image(image, size):
    """Letterbox-resize an image onto a grey canvas of the requested size.

    The image is scaled by one common factor so that it fits inside ``size``
    without changing its aspect ratio.  It is then pasted, centred, onto a new
    RGB canvas filled with (128, 128, 128).

    Args:
        image: Input PIL image; its ``size`` attribute is read as (width, height).
        size: Target canvas size as (width, height).

    Returns:
        A tuple ``(new_image, nw, nh)``: the padded canvas, plus the width and
        height of the scaled image placed inside it.
    """
    target_w, target_h = size
    src_w, src_h = image.size

    # A single scale factor keeps the aspect ratio; the smaller ratio guarantees a fit.
    ratio = min(target_w / src_w, target_h / src_h)
    scaled_w = int(src_w * ratio)
    scaled_h = int(src_h * ratio)
    scaled = image.resize((scaled_w, scaled_h), Image.BICUBIC)

    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled image; the leftover border stays grey.
    offset = ((target_w - scaled_w) // 2, (target_h - scaled_h) // 2)
    canvas.paste(scaled, offset)

    return canvas, scaled_w, scaled_h
|
|
|
|
|
|
def post_process(rsp, nw, nh, orininal_w, orininal_h, old_img=None,
                 num_classes=21, input_hw=(512, 512), alpha=0.7):
    """Convert raw segmentation logits into a colourised, optionally blended, image.

    Args:
        rsp: Inference response. ``rsp['output_0.dat']`` must hold raw float32
            logits laid out as (num_classes, height, width).
        nw (int): Width of the scaled image inside the letterboxed network input.
        nh (int): Height of the scaled image inside the letterboxed network input.
        orininal_w (int): Width of the original image.
        orininal_h (int): Height of the original image.
        old_img: Original PIL image to blend the mask onto. If None, the
            colourised mask is returned without blending. (Previously this was
            read from a module-level global that only existed when the file was
            run as a script.)
        num_classes (int): Number of classes in the model output. Must satisfy
            ``2 <= num_classes <= len(colors)``.
        input_hw (tuple): (height, width) of the network input.
        alpha (float): Weight of the mask passed to ``Image.blend``.

    Returns:
        PIL image of size (orininal_w, orininal_h).

    Raises:
        ValueError: If the output buffer does not contain exactly
            ``num_classes * height * width`` float32 values.
    """
    in_h, in_w = input_hw

    # frombuffer returns a read-only view; copy so torch.from_numpy gets a writable array.
    logits = np.frombuffer(rsp['output_0.dat'], dtype=np.float32).copy()
    expected = num_classes * in_h * in_w
    if logits.size != expected:
        raise ValueError(
            f"expected {expected} float32 values in 'output_0.dat', got {logits.size}")
    logits = logits.reshape(num_classes, in_h, in_w)

    # Softmax over the class axis, moved last so OpenCV can resize every class plane at once.
    probs = F.softmax(torch.from_numpy(logits).permute(1, 2, 0), dim=-1).numpy()

    # Crop away the grey letterbox padding added around the scaled image.
    top = (in_h - nh) // 2
    left = (in_w - nw) // 2
    probs = np.ascontiguousarray(probs[top:top + nh, left:left + nw])

    probs = cv2.resize(probs, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
    class_map = probs.argmax(axis=-1)  # (orininal_h, orininal_w) class indices

    palette = np.array(colors, np.uint8)
    seg_img = palette[class_map]  # (orininal_h, orininal_w, 3) uint8
    image = Image.fromarray(seg_img)

    if old_img is None:
        return image
    # Image.blend requires matching modes; the mask is RGB.
    return Image.blend(old_img.convert('RGB'), image, alpha)


if __name__ == "__main__":
    img_path = './models/deeplab_mobilnet_v3/1.jpg'

    image = load_image_pil(img_path)
    # Untouched copy of the original, used as the blend background.
    old_img = image.copy()
    # PIL reports size as (width, height).
    orininal_w, orininal_h = image.size

    image_data, nw, nh = resize_image(image, (512, 512))
    # NOTE(review): presumably the service expects (1, H, W*3) row-packed RGB — confirm.
    image_data = np.array(image_data).reshape(1, 512, 512 * 3)

    web_service = init_service('./models/deeplab_mobilnet_v3', Dtype.I16)
    rsp = web_service.infer(image_data)

    image_result = post_process(rsp, nw, nh, orininal_w, orininal_h, old_img=old_img)
    image_result.save('./models/deeplab_mobilnet_v3/result.jpg')

    # Prints "deeplab inference finished".
    print('deeplab推理完成')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|