更新灰度图像兼容,修复灰度图像示例工程失败的bug

This commit is contained in:
xujiao 2025-07-21 10:39:39 +08:00
parent 08ab84e8c7
commit 97620a1eec
1 changed files with 13 additions and 7 deletions

View File

@ -133,7 +133,7 @@ class Netrans():
data (dict): 加载的配置文件内容
**kwargs: 需要更新的参数
"""
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params']['preproc_type'] == 'IMAGE_GRAY'
if 'mean' in kwargs:
mean = self._format_preprocess_param(kwargs['mean'], grey)
data = self._upload_config_mean(data, mean)
@ -141,7 +141,7 @@ class Netrans():
scale = self._format_preprocess_param(kwargs['scale'], grey)
data = self._upload_config_scale(data, scale)
if 'reverse_channel' in kwargs:
data = self._upload_config_reverse_channel(data, kwargs['reverse_channel'])
data = self._upload_config_reverse_channel(data, kwargs['reverse_channel'],grey)
return data
@ -195,9 +195,17 @@ class Netrans():
Returns:
list: 处理后的参数值
"""
if grey:
return param
return param if isinstance(param, list) else [param] * 3
ch = 1 if grey else 3
if isinstance(param, (int, float)):
return [float(param)] * ch
if isinstance(param, (list, tuple)):
if len(param) != ch:
raise ValueError(
f"灰度图需 1 个值RGB 图需 3 个值,"
f"当前通道数={ch},但提供 {len(param)} 个值"
)
return [float(v) for v in param]
raise TypeError("mean / scale 必须是数字或 list/tuple")
def _verify_preprocess_value(self):
"""单元测试中用于判断是否成功修改配置文件中的参数
@ -241,8 +249,6 @@ class Netrans():
TypeError: 仅支持量化成 uint8, int8, int16
"""
if quantize_type not in ['uint8', 'int8', 'int16']:
print(quantize_type)
print(type(quantize_type))
raise TypeError(f"不支持的量化类型: {quantize_type},仅支持 uint8, int8, int16")
self.quantize_type = quantize_type
Quantize(self).quantize_network()