273 lines
9.7 KiB
Python
273 lines
9.7 KiB
Python
import sys, os
|
||
import subprocess
|
||
import warnings
|
||
|
||
from ruamel.yaml import YAML
|
||
from ruamel import yaml
|
||
import file_model
|
||
from import_model import ImportModel
|
||
from quantize import Quantize
|
||
from export import Export
|
||
from config import Config
|
||
from utils import check_path
|
||
|
||
# 忽略 ruamel.yaml 的安全加载警告
|
||
warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
|
||
class Netrans():
|
||
"""Netrans Python API,用于模型转换和量化操作。
|
||
|
||
提供模型加载、配置、量化和导出等功能。
|
||
"""
|
||
|
||
def __init__(self, model_path, netrans=None, verbose=False):
|
||
"""
|
||
初始化Netrans
|
||
|
||
Args:
|
||
model_path (str) : 要进行编译转换的模型工程目录.
|
||
netrans (str) : 在没有安装 Netrans 的情况下,指定 Netrans 路径。默认为 None。
|
||
verbose (bool, optional): 是否启用详细模式。默认为 False。
|
||
|
||
Returns :
|
||
None
|
||
"""
|
||
self.verbose = verbose
|
||
if not os.path.exists(model_path):
|
||
raise FileNotFoundError(f"Directory not found: {model_path}")
|
||
self.model_path = os.path.abspath(model_path)
|
||
self.model_name = os.path.basename(self.model_path)
|
||
self._set_netrans_path(netrans)
|
||
|
||
def model2nbg(self, quantize_type, inputmeta=False, **kargs):
|
||
"""
|
||
模型快速转换成NBG
|
||
|
||
Args:
|
||
quantize_type (_type_): 量化类型,支持 uint8, int8, int16。
|
||
inputmeta (bool, optional): 是否进行参数配置。默认为 False。
|
||
**kwargs: 其他可选参数。
|
||
"""
|
||
self.load()
|
||
self.config(inputmeta, **kargs)
|
||
self.quantize(quantize_type, **kargs)
|
||
self.export(**kargs)
|
||
|
||
|
||
def _get_os_netrans_path(self):
|
||
"""
|
||
获取系统环境变量中的 NETRANS_PATH。
|
||
|
||
Returns:
|
||
str: 如果存在 NETRANS_PATH,则返回路径;否则返回 None
|
||
"""
|
||
return os.environ.get('NETRANS_PATH')
|
||
|
||
def _set_netrans_path(self, netrans_path=None):
|
||
"""
|
||
设置 Netrans 路径。
|
||
如果未设置环境变量 NETRANS_PATH,则可以通过此参数指定。
|
||
|
||
Args:
|
||
netrans_path (str, optional): 如果未设置环境变量 NETRANS_PATH,则可以通过此参数指定。
|
||
"""
|
||
if netrans_path is not None :
|
||
netrans_path = os.path.abspath(netrans_path)
|
||
else :
|
||
netrans_path = self._get_os_netrans_path()
|
||
if not os.path.exists(netrans_path):
|
||
raise FileExistsError('未找到 Netrans 路径,请设置 NETRANS_PATH 或指定 netrans_path 参数')
|
||
self.netrans = os.path.join(netrans_path, 'pnnacc')
|
||
self.netrans_path = netrans_path
|
||
|
||
def config(self, inputmeta=False, **kwargs):
|
||
"""
|
||
用户处理inputmate的入口,和shell一致所以叫config.
|
||
根据用户的实际场景,设置inputmeta参数swith对应的分支
|
||
False: 生成inputmeta
|
||
True:使用原本的inputmeta
|
||
str:使用指定的inputmeta
|
||
|
||
Args:
|
||
inputmeta (bool or str, optional): 是否更新模型转换配置参数。
|
||
- 如果为 False,则自动生成配置文件。
|
||
- 如果为字符串,则直接使用指定的配置文件路径。
|
||
**kwargs: 其他可选参数,如 mean、scale、reverse_channel 等。
|
||
Raises:
|
||
FileNotFoundError: 没有找到指定的模型转换配置文件,请重新生成
|
||
FileExistsError: 没有找到指定的模型转换配置文件,请重新生成
|
||
"""
|
||
self.input_meta = os.path.join(self.model_path,'%s%s'%(self.model_name, file_model.extensions.input_meta))
|
||
if isinstance(inputmeta, str):
|
||
self.input_meta = inputmeta
|
||
elif isinstance(inputmeta, bool):
|
||
if inputmeta is False :
|
||
self._config_gen_inputmeta_file()
|
||
else :
|
||
raise ValueError("inputmeta 参数无效,请设置为 False 或指定配置文件路径")
|
||
if not os.path.exists(self.input_meta):
|
||
raise FileExistsError(f"未找到配置文件: {self.input_meta}")
|
||
if kwargs:
|
||
self._update_config(**kwargs)
|
||
|
||
def _update_config(self, **kwargs):
|
||
"""
|
||
如果用户通过kwargs[配置预处理参数,则调用该函数更新配置文件中的参数。
|
||
包括文件读写和更新
|
||
|
||
Args:
|
||
kwargs (dict): 包含需要更新的参数,如 mean、scale、reverse_channel 等。
|
||
"""
|
||
with open(self.input_meta, 'r') as f:
|
||
yaml = YAML()
|
||
data = yaml.load(f)
|
||
data = self._update_config_data(data, **kwargs)
|
||
with open(self.input_meta, 'w') as f:
|
||
yaml.dump(data, f)
|
||
|
||
|
||
def _update_config_data(self, data, **kwargs):
|
||
"""
|
||
更新配置文件中的参数。
|
||
|
||
Args:
|
||
data (dict): 加载的配置文件内容。
|
||
**kwargs: 需要更新的参数。
|
||
"""
|
||
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params']['preproc_type'] == 'IMAGE_GRAY'
|
||
if 'mean' in kwargs:
|
||
mean = self._format_preprocess_param(kwargs['mean'], grey)
|
||
data = self._upload_config_mean(data, mean)
|
||
if 'scale' in kwargs:
|
||
scale = self._format_preprocess_param(kwargs['scale'], grey)
|
||
data = self._upload_config_scale(data, scale)
|
||
if 'reverse_channel' in kwargs:
|
||
data = self._upload_config_reverse_channel(data, kwargs['reverse_channel'])
|
||
return data
|
||
|
||
|
||
def _upload_config_mean(self, data, mean):
|
||
"""
|
||
更新配置文件中的mean值
|
||
|
||
Args:
|
||
data (yaml): yaml.load 加载的配置文件
|
||
mean (list): 需要更新的mean值
|
||
"""
|
||
for db in data['input_meta']['databases']:
|
||
db['ports'][0]['preprocess']['mean'] = mean
|
||
return data
|
||
def _upload_config_scale(self, data, scale):
|
||
"""
|
||
scale
|
||
|
||
Args:
|
||
data (yaml): yaml.load 加载的配置文件
|
||
scale (list): 需要更新的 scale 值
|
||
"""
|
||
for db in data['input_meta']['databases']:
|
||
db['ports'][0]['preprocess']['scale'] = scale
|
||
return data
|
||
|
||
def _upload_config_reverse_channel(self, data, reverse_channel):
|
||
"""
|
||
更新配置文件中的reverse_channel
|
||
|
||
Args:
|
||
data (yaml): yaml.load 加载的配置文件
|
||
reverse_channel (bool): 需要更新的reverse_channel
|
||
"""
|
||
for db in data['input_meta']['databases']:
|
||
db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
|
||
return data
|
||
|
||
def _format_preprocess_param(self, param, grey=False):
|
||
"""
|
||
用于 update model config.
|
||
在模型预处理参数更新的时候,灰度图像仅有一个C,而RGB图像存在三个 channel,
|
||
因此,用户输入的 scale 和 mean 为一个值的时候,需要将其转换成列表
|
||
同时根据图像类型调整为 list.length() == channel
|
||
处理参数,根据图像类型调整参数格式。
|
||
|
||
Args:
|
||
param: 参数值,可以是单个值或列表。
|
||
grey (bool, optional): 是否为灰度图像。默认为 False。
|
||
|
||
Returns:
|
||
list: 处理后的参数值。
|
||
"""
|
||
ch = 1 if grey else 3
|
||
if isinstance(param, (int, float)):
|
||
return [float(param)] * ch
|
||
if isinstance(param, (list, tuple)):
|
||
if len(param) != ch:
|
||
raise ValueError(
|
||
f"灰度图需 1 个值,RGB 图需 3 个值,"
|
||
f"当前通道数={ch},但提供 {len(param)} 个值"
|
||
)
|
||
return [float(v) for v in param]
|
||
raise TypeError("mean / scale 必须是数字或 list/tuple")
|
||
|
||
def _verify_preprocess_value(self):
|
||
"""单元测试中用于判断是否成功修改配置文件中的参数
|
||
|
||
Returns:
|
||
dict : 获取配置文件中的参数
|
||
"""
|
||
with open(self.input_meta,'r') as f :
|
||
yaml = YAML()
|
||
data = yaml.load(f)
|
||
res = {}
|
||
for db in data['input_meta']['databases']:
|
||
res['scale'] = db['ports'][0]['preprocess']['scale']
|
||
res['mean'] = db['ports'][0]['preprocess']['mean']
|
||
res['reverse_channel'] = db['ports'][0]['preprocess']['reverse_channel']
|
||
return res
|
||
|
||
|
||
def load(self):
|
||
"""
|
||
加载模型
|
||
"""
|
||
func = ImportModel(self)
|
||
func.import_network()
|
||
|
||
def _config_gen_inputmeta_file(self):
|
||
"""
|
||
自动生成配置文件
|
||
"""
|
||
func = Config(self)
|
||
func.inputmeta_gen()
|
||
|
||
|
||
def quantize(self, quantize_type,**kargs):
|
||
"""
|
||
量化模型
|
||
|
||
Args:
|
||
quantize_type (_type_): 量化类型,支持 uint8, int8, int16
|
||
|
||
Raises:
|
||
TypeError: 仅支持量化成 uint8, int8, int16
|
||
"""
|
||
if quantize_type not in ['uint8', 'int8', 'int16']:
|
||
raise TypeError(f"不支持的量化类型: {quantize_type},仅支持 uint8, int8, int16")
|
||
self.quantize_type = quantize_type
|
||
Quantize(self).quantize_network()
|
||
|
||
def export(self, **kwargs):
|
||
"""模型导出
|
||
"""
|
||
if 'quantize_type' in kwargs:
|
||
self.quantize_type = kwargs['quantize_type']
|
||
if 'profile' in kwargs:
|
||
self.profile = kwargs['profile']
|
||
else:
|
||
self.profile = False
|
||
Export(self).export_network()
|
||
|
||
# 示例用法
|
||
if __name__ == '__main__':
|
||
network = '../../model_zoo/yolov4_tiny'
|
||
yolo = Netrans(network)
|
||
yolo._config_gen_inputmeta_file()
|
||
yolo.model2nbg("uint8") |