fixed some error
This commit is contained in:
parent
58ec0311b7
commit
8da8917e3b
|
@ -36,7 +36,7 @@ class Netrans():
|
|||
raise FileNotFoundError(f"Directory not found: {model_path}")
|
||||
self.model_path = os.path.abspath(model_path)
|
||||
self.model_name = os.path.basename(self.model_path)
|
||||
self.set_netrans(netrans)
|
||||
self._set_netrans_path(netrans)
|
||||
|
||||
def model2nbg(self, quantize_type, inputmeta=False, **kargs):
|
||||
"""
|
||||
|
@ -53,7 +53,7 @@ class Netrans():
|
|||
self.export(**kargs)
|
||||
|
||||
|
||||
def get_os_netrans_path(self):
|
||||
def _get_os_netrans_path(self):
|
||||
"""
|
||||
获取系统环境变量中的 NETRANS_PATH。
|
||||
|
||||
|
@ -62,9 +62,10 @@ class Netrans():
|
|||
"""
|
||||
return os.environ.get('NETRANS_PATH')
|
||||
|
||||
def set_netrans(self, netrans_path=None):
|
||||
def _set_netrans_path(self, netrans_path=None):
|
||||
"""
|
||||
设置 Netrans 路径。
|
||||
如果未设置环境变量 NETRANS_PATH,则可以通过此参数指定。
|
||||
|
||||
Args:
|
||||
netrans_path (str, optional): 如果未设置环境变量 NETRANS_PATH,则可以通过此参数指定。
|
||||
|
@ -72,7 +73,7 @@ class Netrans():
|
|||
if netrans_path is not None :
|
||||
netrans_path = os.path.abspath(netrans_path)
|
||||
else :
|
||||
netrans_path = self.get_os_netrans_path()
|
||||
netrans_path = self._get_os_netrans_path()
|
||||
if not os.path.exists(netrans_path):
|
||||
raise FileExistsError('未找到 Netrans 路径,请设置 NETRANS_PATH 或指定 netrans_path 参数')
|
||||
self.netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
|
@ -80,7 +81,11 @@ class Netrans():
|
|||
|
||||
def config(self, inputmeta=False, **kwargs):
|
||||
"""
|
||||
配置模型转换参数
|
||||
用户处理inputmate的入口,和shell一致所以叫config.
|
||||
根据用户的实际场景,设置inputmeta参数swith对应的分支
|
||||
False: 生成inputmeta
|
||||
True:使用原本的inputmeta
|
||||
str:使用指定的inputmeta
|
||||
|
||||
Args:
|
||||
inputmeta (bool or str, optional): 是否更新模型转换配置参数。
|
||||
|
@ -96,27 +101,18 @@ class Netrans():
|
|||
self.input_meta = inputmeta
|
||||
elif isinstance(inputmeta, bool):
|
||||
if inputmeta is False :
|
||||
self.inputmeta_gen()
|
||||
self._config_gen_inputmeta_file()
|
||||
else :
|
||||
raise ValueError("inputmeta 参数无效,请设置为 False 或指定配置文件路径")
|
||||
if not os.path.exists(self.input_meta):
|
||||
raise FileExistsError(f"未找到配置文件: {self.input_meta}")
|
||||
if kwargs:
|
||||
self.update_config(**kwargs)
|
||||
# if len(kargs) == 0 : return
|
||||
# if kargs['mean']==0 and kargs['scale'] ==1 : return
|
||||
# if isinstance(kargs['mean'], list) or isinstance(kargs['scale'], (int, float)) or isinstance(kargs['reverse_channel'], bool):
|
||||
# with open(self.input_meta,'r') as f :
|
||||
# yaml = YAML()
|
||||
# data = yaml.load(f)
|
||||
# data = self.upload_cfg(data ,**kargs)
|
||||
# with open(self.input_meta,'w') as f :
|
||||
# yaml = YAML()
|
||||
# yaml.dump(data, f)
|
||||
self._update_config(**kwargs)
|
||||
|
||||
def update_config(self, **kwargs):
|
||||
def _update_config(self, **kwargs):
|
||||
"""
|
||||
更新配置文件中的参数。
|
||||
如果用户通过kwargs[配置预处理参数,则调用该函数更新配置文件中的参数。
|
||||
包括文件读写和更新
|
||||
|
||||
Args:
|
||||
kwargs (dict): 包含需要更新的参数,如 mean、scale、reverse_channel 等。
|
||||
|
@ -124,12 +120,12 @@ class Netrans():
|
|||
with open(self.input_meta, 'r') as f:
|
||||
yaml = YAML()
|
||||
data = yaml.load(f)
|
||||
data = self.upload_cfg(data, **kwargs)
|
||||
data = self._update_config_data(data, **kwargs)
|
||||
with open(self.input_meta, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
|
||||
def upload_cfg(self, data, **kwargs):
|
||||
def _update_config_data(self, data, **kwargs):
|
||||
"""
|
||||
更新配置文件中的参数。
|
||||
|
||||
|
@ -139,17 +135,17 @@ class Netrans():
|
|||
"""
|
||||
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
|
||||
if 'mean' in kwargs:
|
||||
mean = self.handle_param(kwargs['mean'], grey)
|
||||
data = self.upload_cfg_mean(data, mean)
|
||||
mean = self._format_preprocess_param(kwargs['mean'], grey)
|
||||
data = self._upload_config_mean(data, mean)
|
||||
if 'scale' in kwargs:
|
||||
scale = self.handle_param(kwargs['scale'], grey)
|
||||
data = self.upload_cfg_scale(data, scale)
|
||||
scale = self._format_preprocess_param(kwargs['scale'], grey)
|
||||
data = self._upload_config_scale(data, scale)
|
||||
if 'reverse_channel' in kwargs:
|
||||
data = self.upload_cfg_reverse_channel(data, kwargs['reverse_channel'])
|
||||
data = self._upload_config_reverse_channel(data, kwargs['reverse_channel'])
|
||||
return data
|
||||
|
||||
|
||||
def upload_cfg_mean(self, data, mean):
|
||||
def _upload_config_mean(self, data, mean):
|
||||
"""
|
||||
更新配置文件中的mean值
|
||||
|
||||
|
@ -160,7 +156,7 @@ class Netrans():
|
|||
for db in data['input_meta']['databases']:
|
||||
db['ports'][0]['preprocess']['mean'] = mean
|
||||
return data
|
||||
def upload_cfg_scale(self, data, scale):
|
||||
def _upload_config_scale(self, data, scale):
|
||||
"""
|
||||
scale
|
||||
|
||||
|
@ -172,7 +168,7 @@ class Netrans():
|
|||
db['ports'][0]['preprocess']['scale'] = scale
|
||||
return data
|
||||
|
||||
def upload_cfg_reverse_channel(self, data, reverse_channel):
|
||||
def _upload_config_reverse_channel(self, data, reverse_channel):
|
||||
"""
|
||||
更新配置文件中的reverse_channel
|
||||
|
||||
|
@ -184,8 +180,12 @@ class Netrans():
|
|||
db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
|
||||
return data
|
||||
|
||||
def handle_param(self, param, grey=False):
|
||||
def _format_preprocess_param(self, param, grey=False):
|
||||
"""
|
||||
用于 update model config.
|
||||
在模型预处理参数更新的时候,灰度图像仅有一个C,而RGB图像存在三个 channel,
|
||||
因此,用户输入的 scale 和 mean 为一个值的时候,需要将其转换成列表
|
||||
同时根据图像类型调整为 list.length() == channel
|
||||
处理参数,根据图像类型调整参数格式。
|
||||
|
||||
Args:
|
||||
|
@ -199,7 +199,7 @@ class Netrans():
|
|||
return param
|
||||
return param if isinstance(param, list) else [param] * 3
|
||||
|
||||
def read_input_meta_data(self):
|
||||
def _verify_preprocess_value(self):
|
||||
"""单元测试中用于判断是否成功修改配置文件中的参数
|
||||
|
||||
Returns:
|
||||
|
@ -223,7 +223,7 @@ class Netrans():
|
|||
func = ImportModel(self)
|
||||
func.import_network()
|
||||
|
||||
def inputmeta_gen(self):
|
||||
def _config_gen_inputmeta_file(self):
|
||||
"""
|
||||
自动生成配置文件
|
||||
"""
|
||||
|
@ -240,7 +240,9 @@ class Netrans():
|
|||
Raises:
|
||||
TypeError: 仅支持量化成 uint8, int8, int16
|
||||
"""
|
||||
if quantize_type not in ['unit8', 'int8', 'int16']:
|
||||
if quantize_type not in ['uint8', 'int8', 'int16']:
|
||||
print(quantize_type)
|
||||
print(type(quantize_type))
|
||||
raise TypeError(f"不支持的量化类型: {quantize_type},仅支持 uint8, int8, int16")
|
||||
self.quantize_type = quantize_type
|
||||
Quantize(self).quantize_network()
|
||||
|
@ -260,5 +262,5 @@ class Netrans():
|
|||
if __name__ == '__main__':
|
||||
network = '../../model_zoo/yolov4_tiny'
|
||||
yolo = Netrans(network)
|
||||
yolo.inputmeta_gen()
|
||||
yolo._config_gen_inputmeta_file()
|
||||
yolo.model2nbg("uint8")
|
Loading…
Reference in New Issue