netrans/netrans_py/netrans.py

273 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys, os
import subprocess
import warnings
from ruamel.yaml import YAML
from ruamel import yaml
import file_model
from import_model import ImportModel
from quantize import Quantize
from export import Export
from config import Config
from utils import check_path
# 忽略 ruamel.yaml 的安全加载警告
warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
class Netrans():
"""Netrans Python API用于模型转换和量化操作。
提供模型加载、配置、量化和导出等功能。
"""
def __init__(self, model_path, netrans=None, verbose=False):
"""
初始化Netrans
Args:
model_path (str) : 要进行编译转换的模型工程目录.
netrans (str) : 在没有安装 Netrans 的情况下,指定 Netrans 路径。默认为 None。
verbose (bool, optional): 是否启用详细模式。默认为 False。
Returns :
None
"""
self.verbose = verbose
if not os.path.exists(model_path):
raise FileNotFoundError(f"Directory not found: {model_path}")
self.model_path = os.path.abspath(model_path)
self.model_name = os.path.basename(self.model_path)
self._set_netrans_path(netrans)
def model2nbg(self, quantize_type, inputmeta=False, **kargs):
"""
模型快速转换成NBG
Args:
quantize_type (_type_): 量化类型,支持 uint8, int8, int16。
inputmeta (bool, optional): 是否进行参数配置。默认为 False。
**kwargs: 其他可选参数。
"""
self.load()
self.config(inputmeta, **kargs)
self.quantize(quantize_type, **kargs)
self.export(**kargs)
def _get_os_netrans_path(self):
"""
获取系统环境变量中的 NETRANS_PATH。
Returns:
str: 如果存在 NETRANS_PATH则返回路径否则返回 None
"""
return os.environ.get('NETRANS_PATH')
def _set_netrans_path(self, netrans_path=None):
"""
设置 Netrans 路径。
如果未设置环境变量 NETRANS_PATH则可以通过此参数指定。
Args:
netrans_path (str, optional): 如果未设置环境变量 NETRANS_PATH则可以通过此参数指定。
"""
if netrans_path is not None :
netrans_path = os.path.abspath(netrans_path)
else :
netrans_path = self._get_os_netrans_path()
if not os.path.exists(netrans_path):
raise FileExistsError('未找到 Netrans 路径,请设置 NETRANS_PATH 或指定 netrans_path 参数')
self.netrans = os.path.join(netrans_path, 'pnnacc')
self.netrans_path = netrans_path
def config(self, inputmeta=False, **kwargs):
"""
用户处理inputmate的入口,和shell一致所以叫config.
根据用户的实际场景,设置inputmeta参数swith对应的分支
False: 生成inputmeta
True:使用原本的inputmeta
str:使用指定的inputmeta
Args:
inputmeta (bool or str, optional): 是否更新模型转换配置参数。
- 如果为 False则自动生成配置文件。
- 如果为字符串,则直接使用指定的配置文件路径。
**kwargs: 其他可选参数,如 mean、scale、reverse_channel 等。
Raises:
FileNotFoundError: 没有找到指定的模型转换配置文件,请重新生成
FileExistsError: 没有找到指定的模型转换配置文件,请重新生成
"""
self.input_meta = os.path.join(self.model_path,'%s%s'%(self.model_name, file_model.extensions.input_meta))
if isinstance(inputmeta, str):
self.input_meta = inputmeta
elif isinstance(inputmeta, bool):
if inputmeta is False :
self._config_gen_inputmeta_file()
else :
raise ValueError("inputmeta 参数无效,请设置为 False 或指定配置文件路径")
if not os.path.exists(self.input_meta):
raise FileExistsError(f"未找到配置文件: {self.input_meta}")
if kwargs:
self._update_config(**kwargs)
def _update_config(self, **kwargs):
"""
如果用户通过kwargs[配置预处理参数,则调用该函数更新配置文件中的参数。
包括文件读写和更新
Args:
kwargs (dict): 包含需要更新的参数,如 mean、scale、reverse_channel 等。
"""
with open(self.input_meta, 'r') as f:
yaml = YAML()
data = yaml.load(f)
data = self._update_config_data(data, **kwargs)
with open(self.input_meta, 'w') as f:
yaml.dump(data, f)
def _update_config_data(self, data, **kwargs):
"""
更新配置文件中的参数。
Args:
data (dict): 加载的配置文件内容。
**kwargs: 需要更新的参数。
"""
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params']['preproc_type'] == 'IMAGE_GRAY'
if 'mean' in kwargs:
mean = self._format_preprocess_param(kwargs['mean'], grey)
data = self._upload_config_mean(data, mean)
if 'scale' in kwargs:
scale = self._format_preprocess_param(kwargs['scale'], grey)
data = self._upload_config_scale(data, scale)
if 'reverse_channel' in kwargs:
data = self._upload_config_reverse_channel(data, kwargs['reverse_channel'])
return data
def _upload_config_mean(self, data, mean):
"""
更新配置文件中的mean值
Args:
data (yaml): yaml.load 加载的配置文件
mean (list): 需要更新的mean值
"""
for db in data['input_meta']['databases']:
db['ports'][0]['preprocess']['mean'] = mean
return data
def _upload_config_scale(self, data, scale):
"""
scale
Args:
data (yaml): yaml.load 加载的配置文件
scale (list): 需要更新的 scale 值
"""
for db in data['input_meta']['databases']:
db['ports'][0]['preprocess']['scale'] = scale
return data
def _upload_config_reverse_channel(self, data, reverse_channel):
"""
更新配置文件中的reverse_channel
Args:
data (yaml): yaml.load 加载的配置文件
reverse_channel (bool): 需要更新的reverse_channel
"""
for db in data['input_meta']['databases']:
db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
return data
def _format_preprocess_param(self, param, grey=False):
"""
用于 update model config.
在模型预处理参数更新的时候,灰度图像仅有一个C,而RGB图像存在三个 channel,
因此,用户输入的 scale 和 mean 为一个值的时候,需要将其转换成列表
同时根据图像类型调整为 list.length() == channel
处理参数,根据图像类型调整参数格式。
Args:
param: 参数值,可以是单个值或列表。
grey (bool, optional): 是否为灰度图像。默认为 False。
Returns:
list: 处理后的参数值。
"""
ch = 1 if grey else 3
if isinstance(param, (int, float)):
return [float(param)] * ch
if isinstance(param, (list, tuple)):
if len(param) != ch:
raise ValueError(
f"灰度图需 1 个值RGB 图需 3 个值,"
f"当前通道数={ch},但提供 {len(param)} 个值"
)
return [float(v) for v in param]
raise TypeError("mean / scale 必须是数字或 list/tuple")
def _verify_preprocess_value(self):
"""单元测试中用于判断是否成功修改配置文件中的参数
Returns:
dict : 获取配置文件中的参数
"""
with open(self.input_meta,'r') as f :
yaml = YAML()
data = yaml.load(f)
res = {}
for db in data['input_meta']['databases']:
res['scale'] = db['ports'][0]['preprocess']['scale']
res['mean'] = db['ports'][0]['preprocess']['mean']
res['reverse_channel'] = db['ports'][0]['preprocess']['reverse_channel']
return res
def load(self):
"""
加载模型
"""
func = ImportModel(self)
func.import_network()
def _config_gen_inputmeta_file(self):
"""
自动生成配置文件
"""
func = Config(self)
func.inputmeta_gen()
def quantize(self, quantize_type,**kargs):
"""
量化模型
Args:
quantize_type (_type_): 量化类型,支持 uint8, int8, int16
Raises:
TypeError: 仅支持量化成 uint8, int8, int16
"""
if quantize_type not in ['uint8', 'int8', 'int16']:
raise TypeError(f"不支持的量化类型: {quantize_type},仅支持 uint8, int8, int16")
self.quantize_type = quantize_type
Quantize(self).quantize_network()
def export(self, **kwargs):
"""模型导出
"""
if 'quantize_type' in kwargs:
self.quantize_type = kwargs['quantize_type']
if 'profile' in kwargs:
self.profile = kwargs['profile']
else:
self.profile = False
Export(self).export_network()
# 示例用法
if __name__ == '__main__':
network = '../../model_zoo/yolov4_tiny'
yolo = Netrans(network)
yolo._config_gen_inputmeta_file()
yolo.model2nbg("uint8")