forked from nudt_dsp/netrans
新增netrans注释
This commit is contained in:
parent
3989aa42b9
commit
80e995b9a9
|
@ -2,13 +2,30 @@
|
|||
import os
|
||||
import sys
|
||||
from utils import check_path, AttributeCopier, create_cls
|
||||
import subprocess
|
||||
|
||||
class Config(AttributeCopier):
|
||||
"""从实例化的 Netrans 中解析模型参数,并基于pnnacc 生成配置文件模板
|
||||
|
||||
Args:
|
||||
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
"""
|
||||
def __init__(self, source_obj) -> None:
|
||||
"""从实例化的 Netrans 中解析模型参数
|
||||
|
||||
Args:
|
||||
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
|
||||
"""
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def inputmeta_gen(self):
|
||||
"""生成配置文件模板
|
||||
|
||||
Return:
|
||||
None
|
||||
"""
|
||||
netrans_path = self.netrans
|
||||
network_name = self.model_name
|
||||
# 进入网络名称指定的目录
|
||||
|
@ -16,23 +33,28 @@ class Config(AttributeCopier):
|
|||
# check_env(network_name)
|
||||
|
||||
# 执行 pegasus 命令
|
||||
os.system(f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database")
|
||||
cmd = f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database"
|
||||
try :
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
except :
|
||||
raise RuntimeError('config failed')
|
||||
# os.chdir("..")
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量是否正确
|
||||
if len(sys.argv) != 2:
|
||||
print("Enter a network name!")
|
||||
sys.exit(2)
|
||||
# def main():
|
||||
|
||||
# # 检查命令行参数数量是否正确
|
||||
# if len(sys.argv) != 2:
|
||||
# print("Enter a network name!")
|
||||
# sys.exit(2)
|
||||
|
||||
# 检查提供的目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
# 构建 netrans 可执行文件的路径
|
||||
netrans_path =os.getenv('NETRANS_PATH')
|
||||
cla = create_cls(netrans_path, network_name)
|
||||
func = InputmetaGen(cla)
|
||||
func.inputmeta_gen()
|
||||
# # 检查提供的目录是否存在
|
||||
# network_name = sys.argv[1]
|
||||
# # 构建 netrans 可执行文件的路径
|
||||
# netrans_path =os.getenv('NETRANS_PATH')
|
||||
# cla = create_cls(netrans_path, network_name)
|
||||
# func = InputmetaGen(cla)
|
||||
# func.inputmeta_gen()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
# if __name__ == '__main__':
|
||||
# main()
|
|
@ -1,95 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from utils import check_path, AttributeCopier, creat_cla
|
||||
|
||||
class Infer(AttributeCopier):
|
||||
def __init__(self, source_obj) -> None:
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def inference_network(self):
|
||||
netrans = self.netrans
|
||||
quantized = self.quantize_type
|
||||
name = self.model_name
|
||||
# print(self.__dict__)
|
||||
|
||||
netrans += " dump"
|
||||
# 进入模型目录
|
||||
|
||||
# 定义类型和量化类型
|
||||
if quantized == 'float':
|
||||
type_ = 'float32'
|
||||
quantization_type = 'float32'
|
||||
elif quantized == 'uint8':
|
||||
quantization_type = 'asymmetric_affine'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int8':
|
||||
quantization_type = 'dynamic_fixed_point-8'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int16':
|
||||
quantization_type = 'dynamic_fixed_point-16'
|
||||
type_ = 'quantized'
|
||||
else:
|
||||
print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
|
||||
sys.exit(-1)
|
||||
|
||||
# 构建推理命令
|
||||
inf_path = './inf'
|
||||
cmd = f"{netrans} \
|
||||
--dtype {type_} \
|
||||
--batch-size 1 \
|
||||
--model-quantize {name}_{quantization_type}.quantize \
|
||||
--model {name}.json \
|
||||
--model-data {name}.data \
|
||||
--output-dir {inf_path} \
|
||||
--with-input-meta {name}_inputmeta.yml \
|
||||
--device CPU"
|
||||
|
||||
# 执行推理命令
|
||||
if self.verbose is True:
|
||||
print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
if result.returncode == 0:
|
||||
print("\033[32m SUCCESS \033[0m")
|
||||
else:
|
||||
print(f"\033[31m ERROR: {result.stderr} \033[0m")
|
||||
|
||||
# 返回原始目录
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量
|
||||
if len(sys.argv) < 3:
|
||||
print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
|
||||
sys.exit(-1)
|
||||
|
||||
# 检查网络目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
if not os.path.exists(network_name):
|
||||
print(f"Directory {network_name} does not exist !")
|
||||
sys.exit(-2)
|
||||
# print("here")
|
||||
# 定义 netrans 路径
|
||||
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
quantize_type = sys.argv[2]
|
||||
cla = creat_cla(netrans_path, network_name,quantize_type,False)
|
||||
|
||||
# 调用量化函数
|
||||
func = Infer(cla)
|
||||
func.inference_network()
|
||||
|
||||
# 定义数据集文件路径
|
||||
# dataset_path = './dataset.txt'
|
||||
# 调用推理函数
|
||||
# inference_network(network_name, sys.argv[2])
|
||||
|
||||
if __name__ == '__main__':
|
||||
# print("main")
|
||||
main()
|
|
@ -1 +0,0 @@
|
|||
../netrans_cli/example.py
|
|
@ -9,11 +9,24 @@ from utils import check_path, AttributeCopier, create_cls
|
|||
dataset = 'dataset.txt'
|
||||
|
||||
class Export(AttributeCopier):
|
||||
"""从实例化的 Netrans 中解析模型参数,并基于 pnnacc 导出模型ngb文件
|
||||
|
||||
Args:
|
||||
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
"""
|
||||
def __init__(self, source_obj) -> None:
|
||||
"""从实例化的 Netrans 中解析模型参数
|
||||
|
||||
Args:
|
||||
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
|
||||
"""
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def export_network(self):
|
||||
"""基于 pnnacc 导出模型
|
||||
"""
|
||||
|
||||
netrans = self.netrans
|
||||
quantized = self.quantize_type
|
||||
|
|
|
@ -4,6 +4,11 @@ import subprocess
|
|||
from utils import check_path, AttributeCopier, create_cls
|
||||
|
||||
def check_status(result):
|
||||
"""解析命令执行情况
|
||||
|
||||
Args:
|
||||
result (return of subprocrss.run): subprocess.run的返回值
|
||||
"""
|
||||
if result.returncode == 0:
|
||||
print("\033[31m LOAD MODEL SUCCESS \033[0m")
|
||||
else:
|
||||
|
@ -11,6 +16,15 @@ def check_status(result):
|
|||
|
||||
|
||||
def import_caffe_network(name, netrans_path):
|
||||
"""导入 caffe 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的路径
|
||||
convert_caffe =netrans_path + " import caffe"
|
||||
|
||||
|
@ -20,7 +34,6 @@ def import_caffe_network(name, netrans_path):
|
|||
model_prototxt_path = f"{name}.prototxt"
|
||||
model_caffemodel_path = f"{name}.caffemodel"
|
||||
|
||||
|
||||
# 打印转换信息
|
||||
print(f"=========== Converting {name} Caffe model ===========")
|
||||
|
||||
|
@ -40,9 +53,19 @@ def import_caffe_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
os.system(cmd)
|
||||
# os.system(cmd)
|
||||
return cmd
|
||||
|
||||
def import_tensorflow_network(name, netrans_path):
|
||||
"""导入 tensorflow 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的命令
|
||||
convertf_cmd = f"{netrans_path} import tensorflow"
|
||||
|
||||
|
@ -62,12 +85,23 @@ def import_tensorflow_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
return cmd
|
||||
|
||||
# result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
check_status(result)
|
||||
# check_status(result)
|
||||
|
||||
def import_onnx_network(name, netrans_path):
|
||||
"""导入 onnx 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的命令
|
||||
convert_onnx_cmd = f"{netrans_path} import onnx"
|
||||
|
||||
|
@ -77,6 +111,7 @@ def import_onnx_network(name, netrans_path):
|
|||
output_path = os.path.join(os.getcwd(), name+"_outputs.txt")
|
||||
with open(output_path, 'r', encoding='utf-8') as file:
|
||||
outputs = str(file.readline().strip())
|
||||
|
||||
cmd = f"{convert_onnx_cmd} \
|
||||
--model {name}.onnx \
|
||||
--output-model {name}.json \
|
||||
|
@ -91,13 +126,24 @@ def import_onnx_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
return cmd
|
||||
|
||||
# result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
check_status(result)
|
||||
# check_status(result)
|
||||
|
||||
####### TFLITE
|
||||
def import_tflite_network(name, netrans_path):
|
||||
"""导入 tflite 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的路径或命令
|
||||
convert_tflite = f"{netrans_path} import tflite"
|
||||
|
||||
|
@ -117,13 +163,24 @@ def import_tflite_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
return cmd
|
||||
|
||||
# result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
check_status(result)
|
||||
# check_status(result)
|
||||
|
||||
|
||||
def import_darknet_network(name, netrans_path):
|
||||
"""导入 darknet 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的命令
|
||||
convert_darknet_cmd = f"{netrans_path} import darknet"
|
||||
|
||||
|
@ -139,12 +196,23 @@ def import_darknet_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
return cmd
|
||||
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
check_status(result)
|
||||
|
||||
def import_pytorch_network(name, netrans_path):
|
||||
"""导入 pytorch 模型
|
||||
|
||||
Args:
|
||||
name (str): 模型名字
|
||||
netrans_path (str): 模型路径
|
||||
|
||||
Returns:
|
||||
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
|
||||
"""
|
||||
# 定义转换工具的命令
|
||||
convert_pytorch_cmd = f"{netrans_path} import pytorch"
|
||||
|
||||
|
@ -168,6 +236,8 @@ def import_pytorch_network(name, netrans_path):
|
|||
|
||||
# 执行转换命令
|
||||
# print(cmd)
|
||||
return cmd
|
||||
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
|
@ -176,12 +246,29 @@ def import_pytorch_network(name, netrans_path):
|
|||
# 使用示例
|
||||
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
|
||||
class ImportModel(AttributeCopier):
|
||||
def __init__(self, source_obj) -> None:
|
||||
"""从实例化的 Netrans 中解析模型参数,并基于 pnnacc 导入模型
|
||||
|
||||
Args:
|
||||
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
"""
|
||||
def __init__(self, source_obj) -> None:
|
||||
"""从实例化的 Netrans 中解析模型参数
|
||||
|
||||
Args:
|
||||
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
||||
|
||||
"""
|
||||
super().__init__(source_obj)
|
||||
# print(source_obj.__dict__)
|
||||
|
||||
@check_path
|
||||
def import_network(self):
|
||||
"""基于 pnnacc 导入模型
|
||||
|
||||
Raises:
|
||||
FileExistsError: 如果不存在模型文件则会报错 FileExistsError
|
||||
RuntimeError: 如果执行导入失败则会报 RuntimeError
|
||||
"""
|
||||
if self.verbose is True :
|
||||
print("begin load model")
|
||||
# print(self.model_path)
|
||||
|
@ -190,36 +277,40 @@ class ImportModel(AttributeCopier):
|
|||
name = self.model_name
|
||||
netrans_path = self.netrans
|
||||
if os.path.isfile(f"{name}.prototxt"):
|
||||
import_caffe_network(name, netrans_path)
|
||||
cmd = import_caffe_network(name, netrans_path)
|
||||
elif os.path.isfile(f"{name}.pb"):
|
||||
import_tensorflow_network(name, netrans_path)
|
||||
cmd = import_tensorflow_network(name, netrans_path)
|
||||
elif os.path.isfile(f"{name}.onnx"):
|
||||
import_onnx_network(name, netrans_path)
|
||||
cmd = import_onnx_network(name, netrans_path)
|
||||
elif os.path.isfile(f"{name}.tflite"):
|
||||
import_tflite_network(name, netrans_path)
|
||||
cmd = import_tflite_network(name, netrans_path)
|
||||
elif os.path.isfile(f"{name}.weights"):
|
||||
import_darknet_network(name, netrans_path)
|
||||
cmd = import_darknet_network(name, netrans_path)
|
||||
elif os.path.isfile(f"{name}.pt"):
|
||||
import_pytorch_network(name, netrans_path)
|
||||
cmd = import_pytorch_network(name, netrans_path)
|
||||
else :
|
||||
# print(os.getcwd())
|
||||
print("=========== can not find suitable model files ===========")
|
||||
sys.exit(-3)
|
||||
# os.chdir("..")
|
||||
raise FileExistsError("Can not find suitable model files")
|
||||
try :
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
except :
|
||||
raise RuntimeError("load model failed")
|
||||
# 检查执行结果
|
||||
check_status(result)
|
||||
# os.chdir("..")
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2 :
|
||||
print("Input a network")
|
||||
sys.exit(-1)
|
||||
# def main():
|
||||
# if len(sys.argv) != 2 :
|
||||
# print("Input a network")
|
||||
# sys.exit(-1)
|
||||
|
||||
network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
# network_name = sys.argv[1]
|
||||
# # check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
clas = create_cls(netrans_path, network_name,verbose=False)
|
||||
func = ImportModel(clas)
|
||||
func.import_network()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# netrans_path = os.environ['NETRANS_PATH']
|
||||
# # netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
# clas = create_cls(netrans_path, network_name,verbose=False)
|
||||
# func = ImportModel(clas)
|
||||
# func.import_network()
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from utils import check_path, AttributeCopier, create_cls
|
||||
|
||||
class Infer(AttributeCopier):
|
||||
def __init__(self, source_obj) -> None:
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def inference_network(self):
|
||||
netrans = self.netrans
|
||||
quantized = self.quantize_type
|
||||
name = self.model_name
|
||||
# print(self.__dict__)
|
||||
|
||||
netrans += " inference"
|
||||
# 进入模型目录
|
||||
|
||||
# 定义类型和量化类型
|
||||
if quantized == 'float':
|
||||
type_ = 'float32'
|
||||
quantization_type = 'float32'
|
||||
elif quantized == 'uint8':
|
||||
quantization_type = 'asymmetric_affine'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int8':
|
||||
quantization_type = 'dynamic_fixed_point-8'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int16':
|
||||
quantization_type = 'dynamic_fixed_point-16'
|
||||
type_ = 'quantized'
|
||||
else:
|
||||
print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
|
||||
sys.exit(-1)
|
||||
|
||||
# 构建推理命令
|
||||
inf_path = './inf'
|
||||
cmd = f"{netrans} \
|
||||
--dtype {type_} \
|
||||
--batch-size 1 \
|
||||
--model-quantize {name}_{quantization_type}.quantize \
|
||||
--model {name}.json \
|
||||
--model-data {name}.data \
|
||||
--output-dir {inf_path} \
|
||||
--with-input-meta {name}_inputmeta.yml \
|
||||
--device CPU"
|
||||
|
||||
# 执行推理命令
|
||||
if self.verbose is True:
|
||||
print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
if result.returncode == 0:
|
||||
print("\033[32m SUCCESS \033[0m")
|
||||
else:
|
||||
print(f"\033[31m ERROR: {result.stderr} \033[0m")
|
||||
|
||||
# 返回原始目录
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量
|
||||
if len(sys.argv) < 3:
|
||||
print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
|
||||
sys.exit(-1)
|
||||
|
||||
# 检查网络目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
if not os.path.exists(network_name):
|
||||
print(f"Directory {network_name} does not exist !")
|
||||
sys.exit(-2)
|
||||
# print("here")
|
||||
# 定义 netrans 路径
|
||||
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
quantize_type = sys.argv[2]
|
||||
cla = create_cls(netrans_path, network_name,quantize_type,False)
|
||||
|
||||
# 调用量化函数
|
||||
func = Infer(cla)
|
||||
func.inference_network()
|
||||
|
||||
# 定义数据集文件路径
|
||||
# dataset_path = './dataset.txt'
|
||||
# 调用推理函数
|
||||
# inference_network(network_name, sys.argv[2])
|
||||
|
||||
if __name__ == '__main__':
|
||||
# print("main")
|
||||
main()
|
|
@ -1,6 +1,7 @@
|
|||
import sys, os
|
||||
import subprocess
|
||||
# import yaml
|
||||
import warnings
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel import yaml
|
||||
import file_model
|
||||
|
@ -8,139 +9,256 @@ from import_model import ImportModel
|
|||
from quantize import Quantize
|
||||
from export import Export
|
||||
from config import Config
|
||||
# from utils import check_path
|
||||
import warnings
|
||||
from utils import check_path
|
||||
|
||||
# 忽略 ruamel.yaml 的安全加载警告
|
||||
warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
|
||||
class Netrans():
|
||||
"""Netrans Python API,用于模型转换和量化操作。
|
||||
|
||||
提供模型加载、配置、量化和导出等功能。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, netrans=None, verbose=False):
|
||||
self.verbose = verbose
|
||||
self.model_path = os.path.abspath(model_path)
|
||||
self.set_netrans(netrans)
|
||||
_, self.model_name = os.path.split(self.model_path)
|
||||
# self.model_name,_ = os.path.splitext(self.model_name)
|
||||
|
||||
"""
|
||||
初始化Netrans
|
||||
|
||||
Args:
|
||||
model_path (str) : 要进行编译转换的模型工程目录.
|
||||
netrans (str) : 在没有安装 Netrans 的情况下,指定 Netrans 路径。默认为 None。
|
||||
verbose (bool, optional): 是否启用详细模式。默认为 False。
|
||||
|
||||
Returns :
|
||||
None
|
||||
"""
|
||||
self.verbose = verbose
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Directory not found: {model_path}")
|
||||
self.model_path = os.path.abspath(model_path)
|
||||
self.model_name = os.path.basename(self.model_path)
|
||||
self.set_netrans(netrans)
|
||||
|
||||
"""
|
||||
pipe line
|
||||
"""
|
||||
def model2nbg(self, quantize_type, inputmeta=False, **kargs):
|
||||
"""
|
||||
模型快速转换成NBG
|
||||
|
||||
Args:
|
||||
quantize_type (_type_): 量化类型,支持 uint8, int8, int16。
|
||||
inputmeta (bool, optional): 是否进行参数配置。默认为 False。
|
||||
**kwargs: 其他可选参数。
|
||||
"""
|
||||
self.load()
|
||||
self.config(inputmeta, **kargs)
|
||||
self.quantize(quantize_type, **kargs)
|
||||
self.export(**kargs)
|
||||
|
||||
"""
|
||||
set netrans
|
||||
"""
|
||||
|
||||
def get_os_netrans_path(self):
|
||||
# print(os.environ.get('NETRANS_PATH'))
|
||||
"""
|
||||
获取系统环境变量中的 NETRANS_PATH。
|
||||
|
||||
Returns:
|
||||
str: 如果存在 NETRANS_PATH,则返回路径;否则返回 None
|
||||
"""
|
||||
return os.environ.get('NETRANS_PATH')
|
||||
|
||||
def check_netarans(self):
|
||||
res = subprocess.run([self.netrans], text=True)
|
||||
if res.returncode != 0:
|
||||
print("pleace check the netrans")
|
||||
# return False
|
||||
sys.exit()
|
||||
else :
|
||||
return
|
||||
|
||||
def set_netrans(self, netrans_path=None):
|
||||
"""
|
||||
设置 Netrans 路径。
|
||||
|
||||
Args:
|
||||
netrans_path (str, optional): 如果未设置环境变量 NETRANS_PATH,则可以通过此参数指定。
|
||||
"""
|
||||
if netrans_path is not None :
|
||||
netrans_path = os.path.abspath(netrans_path)
|
||||
else :
|
||||
netrans_path = self.get_os_netrans_path()
|
||||
# print(netrans_path)
|
||||
if os.path.exists(netrans_path):
|
||||
self.netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
self.netrans_path = netrans_path
|
||||
else :
|
||||
print('NETRANS_PATH NOT BEEN SETTED')
|
||||
"""
|
||||
edit config
|
||||
"""
|
||||
# @check_path
|
||||
def config(self, inputmeta=False, **kargs):
|
||||
if not os.path.exists(netrans_path):
|
||||
raise FileExistsError('未找到 Netrans 路径,请设置 NETRANS_PATH 或指定 netrans_path 参数')
|
||||
self.netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
self.netrans_path = netrans_path
|
||||
|
||||
def config(self, inputmeta=False, **kwargs):
|
||||
"""
|
||||
配置模型转换参数
|
||||
|
||||
Args:
|
||||
inputmeta (bool or str, optional): 是否更新模型转换配置参数。
|
||||
- 如果为 False,则自动生成配置文件。
|
||||
- 如果为字符串,则直接使用指定的配置文件路径。
|
||||
**kwargs: 其他可选参数,如 mean、scale、reverse_channel 等。
|
||||
Raises:
|
||||
FileNotFoundError: 没有找到指定的模型转换配置文件,请重新生成
|
||||
FileExistsError: 没有找到指定的模型转换配置文件,请重新生成
|
||||
"""
|
||||
self.input_meta = os.path.join(self.model_path,'%s%s'%(self.model_name, file_model.extensions.input_meta))
|
||||
if isinstance(inputmeta, str):
|
||||
self.input_meta = inputmeta
|
||||
elif isinstance(inputmeta, bool):
|
||||
self.input_meta = os.path.join(self.model_path,'%s%s'%(self.model_name, file_model.extensions.input_meta))
|
||||
if inputmeta is False : self.inputmeta_gen()
|
||||
if inputmeta is False :
|
||||
self.inputmeta_gen()
|
||||
else :
|
||||
sys.exit("check inputmeta file")
|
||||
raise ValueError("inputmeta 参数无效,请设置为 False 或指定配置文件路径")
|
||||
if not os.path.exists(self.input_meta):
|
||||
raise FileExistsError(f"未找到配置文件: {self.input_meta}")
|
||||
if kwargs:
|
||||
self.update_config(**kwargs)
|
||||
# if len(kargs) == 0 : return
|
||||
# if kargs['mean']==0 and kargs['scale'] ==1 : return
|
||||
# if isinstance(kargs['mean'], list) or isinstance(kargs['scale'], (int, float)) or isinstance(kargs['reverse_channel'], bool):
|
||||
# with open(self.input_meta,'r') as f :
|
||||
# yaml = YAML()
|
||||
# data = yaml.load(f)
|
||||
# data = self.upload_cfg(data ,**kargs)
|
||||
# with open(self.input_meta,'w') as f :
|
||||
# yaml = YAML()
|
||||
# yaml.dump(data, f)
|
||||
|
||||
if len(kargs) == 0 : return
|
||||
if kargs['mean']==0 and kargs['scale'] ==1 : return
|
||||
if isinstance(kargs['mean'], list) or isinstance(kargs['scale'], (int, float)) or isinstance(kargs['reverse_channel'], bool):
|
||||
with open(self.input_meta,'r') as f :
|
||||
yaml = YAML()
|
||||
data = yaml.load(f)
|
||||
data = self.upload_cfg(data,**kargs)
|
||||
with open(self.input_meta,'w') as f :
|
||||
yaml = YAML()
|
||||
yaml.dump(data, f)
|
||||
def update_config(self, **kwargs):
|
||||
"""
|
||||
更新配置文件中的参数。
|
||||
|
||||
def upload_cfg(self, data, channel=3, **kargs):
|
||||
grey = config['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
|
||||
if kargs.get('mean') is not None:
|
||||
mean = handel_param(kargs['mean'],grey)
|
||||
self.upload_cfg_mean(data, mean)
|
||||
if kargs.get('scale') is not None:
|
||||
scale = handel_param(kargs['scale'],grey)
|
||||
self.upload_cfg_scale(data, scale)
|
||||
if kargs.get('reverse_channel') is not None:
|
||||
if isinstance(kargs['reverse_channel'],bool):
|
||||
self.upload_cfg_reverse_channel(data, kargs['reverse_channel'])
|
||||
Args:
|
||||
kwargs (dict): 包含需要更新的参数,如 mean、scale、reverse_channel 等。
|
||||
"""
|
||||
with open(self.input_meta, 'r') as f:
|
||||
yaml = YAML()
|
||||
data = yaml.load(f)
|
||||
data = self.upload_cfg(data, **kwargs)
|
||||
with open(self.input_meta, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
|
||||
def upload_cfg(self, data, **kwargs):
|
||||
"""
|
||||
更新配置文件中的参数。
|
||||
|
||||
Args:
|
||||
data (dict): 加载的配置文件内容。
|
||||
**kwargs: 需要更新的参数。
|
||||
"""
|
||||
grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
|
||||
if 'mean' in kwargs:
|
||||
mean = self.handle_param(kwargs['mean'], grey)
|
||||
data = self.upload_cfg_mean(data, mean)
|
||||
if 'scale' in kwargs:
|
||||
scale = self.handle_param(kwargs['scale'], grey)
|
||||
data = self.upload_cfg_scale(data, scale)
|
||||
if 'reverse_channel' in kwargs:
|
||||
data = self.upload_cfg_reverse_channel(data, kwargs['reverse_channel'])
|
||||
return data
|
||||
|
||||
|
||||
def upload_cfg_mean(self, data, mean):
|
||||
"""
|
||||
更新配置文件中的mean值
|
||||
|
||||
Args:
|
||||
data (yaml): yaml.load 加载的配置文件
|
||||
mean (list): 需要更新的mean值
|
||||
"""
|
||||
for db in data['input_meta']['databases']:
|
||||
db['ports'][0]['preprocess']['mean'] = mean
|
||||
return data
|
||||
def upload_cfg_scale(self, data, scale):
|
||||
"""
|
||||
scale
|
||||
|
||||
Args:
|
||||
data (yaml): yaml.load 加载的配置文件
|
||||
scale (list): 需要更新的 scale 值
|
||||
"""
|
||||
for db in data['input_meta']['databases']:
|
||||
db['ports'][0]['preprocess']['scale'] = scale
|
||||
return data
|
||||
|
||||
def upload_cfg_reverse_channel(self, data, reverse_channel):
|
||||
"""
|
||||
更新配置文件中的reverse_channel
|
||||
|
||||
Args:
|
||||
data (yaml): yaml.load 加载的配置文件
|
||||
reverse_channel (bool): 需要更新的reverse_channel
|
||||
"""
|
||||
for db in data['input_meta']['databases']:
|
||||
db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
|
||||
return data
|
||||
|
||||
def handle_param(self, param, grey=False):
|
||||
"""
|
||||
处理参数,根据图像类型调整参数格式。
|
||||
|
||||
Args:
|
||||
param: 参数值,可以是单个值或列表。
|
||||
grey (bool, optional): 是否为灰度图像。默认为 False。
|
||||
|
||||
Returns:
|
||||
list: 处理后的参数值。
|
||||
"""
|
||||
if grey:
|
||||
return param
|
||||
return param if isinstance(param, list) else [param] * 3
|
||||
|
||||
def read_input_meta_data(self):
|
||||
"""单元测试中用于判断是否成功修改配置文件中的参数
|
||||
|
||||
Returns:
|
||||
dict : 获取配置文件中的参数
|
||||
"""
|
||||
with open(self.input_meta,'r') as f :
|
||||
yaml = YAML()
|
||||
data = yaml.load(f)
|
||||
res = {}
|
||||
for db in data['input_meta']['databases']:
|
||||
res['scale'] = db['ports'][0]['preprocess']['scale']
|
||||
res['mean'] = db['ports'][0]['preprocess']['mean']
|
||||
res['reverse_channel'] = db['ports'][0]['preprocess']['reverse_channel']
|
||||
return res
|
||||
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
加载模型
|
||||
"""
|
||||
func = ImportModel(self)
|
||||
func.import_network()
|
||||
|
||||
def inputmeta_gen(self):
|
||||
"""
|
||||
自动生成配置文件
|
||||
"""
|
||||
func = Config(self)
|
||||
func.inputmeta_gen()
|
||||
|
||||
def quantize(self, quantize_type,**kargs):
|
||||
"""
|
||||
量化模型
|
||||
|
||||
Args:
|
||||
quantize_type (_type_): 量化类型,支持 uint8, int8, int16
|
||||
|
||||
Raises:
|
||||
TypeError: 仅支持量化成 uint8, int8, int16
|
||||
"""
|
||||
if quantize_type not in ['unit8', 'int8', 'int16']:
|
||||
raise TypeError(f"不支持的量化类型: {quantize_type},仅支持 uint8, int8, int16")
|
||||
self.quantize_type = quantize_type
|
||||
func = Quantize(self)
|
||||
func.quantize_network()
|
||||
Quantize(self).quantize_network()
|
||||
|
||||
def export(self, **kargs):
|
||||
if kargs.get('quantize_type') :
|
||||
self.quantize_type = kargs['quantize_type']
|
||||
if kargs.get('profile') :
|
||||
self.profile = kargs['profile']
|
||||
else :
|
||||
def export(self, **kwargs):
|
||||
"""模型导出
|
||||
"""
|
||||
if 'quantize_type' in kwargs:
|
||||
self.quantize_type = kwargs['quantize_type']
|
||||
if 'profile' in kwargs:
|
||||
self.profile = kwargs['profile']
|
||||
else:
|
||||
self.profile = False
|
||||
func = Export(self)
|
||||
func.export_network()
|
||||
|
||||
|
||||
|
||||
def handel_param(param, grey=False):
|
||||
if grey : return param
|
||||
else :
|
||||
return param if isinstance(param, list) else [param]*3
|
||||
|
||||
Export(self).export_network()
|
||||
|
||||
# 示例用法
|
||||
if __name__ == '__main__':
|
||||
network = '../../model_zoo/yolov4_tiny'
|
||||
yolo = Netrans(network)
|
||||
yolo.inputmeta_gen()
|
||||
# yolo.model2nb("uint8")
|
||||
# yolo.load()
|
||||
# yolo.config(mean=[0,0,0],scale=1)
|
||||
# yolo.quantize('uint8')
|
||||
# yolo.export()
|
||||
yolo.model2nbg("uint8")
|
|
@ -3,11 +3,23 @@ import sys
|
|||
from utils import check_path, AttributeCopier, create_cls
|
||||
|
||||
class Quantize(AttributeCopier):
|
||||
"""
|
||||
解析 Netrans 参数,基于 pnnacc 量化模型
|
||||
Args:
|
||||
cla (class): 实例化以后的 Netrans 类,需要解析里面包含的参数
|
||||
"""
|
||||
def __init__(self, source_obj) -> None:
|
||||
"""
|
||||
从 Netrans 类中获取模型信息
|
||||
Args:
|
||||
source_obj (class): 实例化以后的 Netrans 类,需要解析里面包含的参数
|
||||
"""
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def quantize_network(self):
|
||||
"""基于 pnnacc 量化模型
|
||||
"""
|
||||
netrans = self.netrans
|
||||
quantized_type = self.quantize_type
|
||||
name = self.model_name
|
||||
|
@ -66,28 +78,28 @@ class Quantize(AttributeCopier):
|
|||
print("\033[31m ERROR ! \033[0m")
|
||||
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量
|
||||
if len(sys.argv) < 3:
|
||||
print("Input a network name and quantized type ( uint8 / int8 / int16 )")
|
||||
sys.exit(-1)
|
||||
# def main():
|
||||
# # 检查命令行参数数量
|
||||
# if len(sys.argv) < 3:
|
||||
# print("Input a network name and quantized type ( uint8 / int8 / int16 )")
|
||||
# sys.exit(-1)
|
||||
|
||||
# 检查网络目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
# # 检查网络目录是否存在
|
||||
# network_name = sys.argv[1]
|
||||
|
||||
# 定义 netrans 路径
|
||||
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
# network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
# # 定义 netrans 路径
|
||||
# # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
# # network_name = sys.argv[1]
|
||||
# # check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
quantize_type = sys.argv[2]
|
||||
cla = create_cls(netrans_path, network_name,quantize_type)
|
||||
# netrans_path = os.environ['NETRANS_PATH']
|
||||
# # netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
# quantize_type = sys.argv[2]
|
||||
# cla = create_cls(netrans_path, network_name,quantize_type)
|
||||
|
||||
# 调用量化函数
|
||||
run = Quantize(cla)
|
||||
run.quantize_network()
|
||||
# # 调用量化函数
|
||||
# run = Quantize(cla)
|
||||
# run.quantize_network()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
|
|
@ -1,91 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
from utils import check_path, AttributeCopier, create_cls
|
||||
|
||||
class Quantize(AttributeCopier):
|
||||
def __init__(self, source_obj) -> None:
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def quantize_network(self):
|
||||
netrans = self.netrans
|
||||
quantized_type = self.quantize_type
|
||||
name = self.model_name
|
||||
# check_env(name)
|
||||
# print(os.getcwd())
|
||||
netrans += " quantize"
|
||||
# 根据量化类型设置量化参数
|
||||
if quantized_type == 'float':
|
||||
print("=========== do not need quantized===========")
|
||||
return
|
||||
elif quantized_type == 'uint8':
|
||||
quantization_type = "asymmetric_affine"
|
||||
elif quantized_type == 'int8':
|
||||
quantization_type = "dynamic_fixed_point-8"
|
||||
elif quantized_type == 'int16':
|
||||
quantization_type = "dynamic_fixed_point-16"
|
||||
else:
|
||||
print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
|
||||
return
|
||||
|
||||
# 输出量化信息
|
||||
print(" =======================================================================")
|
||||
print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
|
||||
print(" =======================================================================")
|
||||
|
||||
# 移除已存在的量化文件
|
||||
quantize_file = f"{name}_{quantization_type}.quantize"
|
||||
current_directory = os.getcwd()
|
||||
txt_path = current_directory+"/dataset.txt"
|
||||
with open(txt_path, 'r', encoding='utf-8') as file:
|
||||
num_lines = len(file.readlines())
|
||||
|
||||
|
||||
# 构建并执行量化命令
|
||||
cmd = f"{netrans} \
|
||||
--qtype {quantized_type} \
|
||||
--hybrid \
|
||||
--quantizer {quantization_type.split('-')[0]} \
|
||||
--model-quantize {quantize_file} \
|
||||
--model {name}.json \
|
||||
--model-data {name}.data \
|
||||
--with-input-meta {name}_inputmeta.yml \
|
||||
--device CPU \
|
||||
--algorithm kl_divergence \
|
||||
--divergence-nbins 2048 \
|
||||
--iterations {num_lines}"
|
||||
|
||||
os.system(cmd)
|
||||
|
||||
# 检查量化结果
|
||||
if os.path.exists(quantize_file):
|
||||
print("\033[31m QUANTIZED SUCCESS \033[0m")
|
||||
else:
|
||||
print("\033[31m ERROR ! \033[0m")
|
||||
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量
|
||||
if len(sys.argv) < 3:
|
||||
print("Input a network name and quantized type ( uint8 / int8 / int16 )")
|
||||
sys.exit(-1)
|
||||
|
||||
# 检查网络目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
|
||||
# 定义 netrans 路径
|
||||
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
# network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
quantize_type = sys.argv[2]
|
||||
cla = create_cls(netrans_path, network_name,quantize_type)
|
||||
|
||||
# 调用量化函数
|
||||
run = Quantize(cla)
|
||||
run.quantize_network()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -15,6 +15,9 @@ import os
|
|||
# return decorator
|
||||
|
||||
def check_path(func):
|
||||
""" 装饰器, 确保在工程目录运行 nertans
|
||||
|
||||
"""
|
||||
def wrapper(cla, *args, **kargs):
|
||||
check_netrans(cla.netrans)
|
||||
if os.getcwd() != cla.model_path :
|
||||
|
@ -24,18 +27,39 @@ def check_path(func):
|
|||
|
||||
|
||||
def check_dir(network_name):
|
||||
"""判断工程目录是否存在
|
||||
|
||||
Args:
|
||||
network_name (str): 工程目录路径
|
||||
|
||||
Raises:
|
||||
NotADirectoryError: 没有那个工程目录
|
||||
"""
|
||||
if not os.path.exists(network_name):
|
||||
print(f"Directory {network_name} does not exist !")
|
||||
sys.exit(-1)
|
||||
raise NotADirectoryError(
|
||||
f"Directory not found: {network_name}"
|
||||
)
|
||||
# print(f"Directory {network_name} does not exist !")
|
||||
# sys.exit(-1)
|
||||
os.chdir(network_name)
|
||||
|
||||
def check_netrans(netrans):
|
||||
if 'NETRANS_PATH' not in os.environ :
|
||||
return
|
||||
"""判断 netrans 是否配置成功
|
||||
|
||||
Args:
|
||||
netrans (str, bool): _netrans 路径, 如果没有配置(默认为False)会去环境变量里找
|
||||
|
||||
Raises:
|
||||
NotADirectoryError: 找不到 Netrans 会返回 NotADirectoryError
|
||||
"""
|
||||
if netrans != None and os.path.exists(netrans) is True:
|
||||
return
|
||||
print("Need to set enviroment variable NETRANS_PATH")
|
||||
sys.exit(1)
|
||||
if 'NETRANS_PATH' in os.environ :
|
||||
return
|
||||
raise NotADirectoryError(
|
||||
f"Netrans not found: {netrans}"
|
||||
)
|
||||
|
||||
|
||||
def remove_history_file(name):
|
||||
os.chdir(name)
|
||||
|
@ -52,6 +76,8 @@ def check_env(name):
|
|||
|
||||
|
||||
class AttributeCopier:
|
||||
"""快速解析复制 Netrans 信息
|
||||
"""
|
||||
def __init__(self, source_obj) -> None:
|
||||
self.copy_attribute_name(source_obj)
|
||||
|
||||
|
@ -64,6 +90,7 @@ class AttributeCopier:
|
|||
return source_obj.__dict__.keys()
|
||||
|
||||
class create_cls(): #dataclass @netrans_params
|
||||
"""快速测试时候模拟实例化Netrans"""
|
||||
def __init__(self, netrans_path, name, quantized_type = 'uint8',verbose=False) -> None:
|
||||
self.netrans_path = netrans_path
|
||||
self.netrans = os.path.join(self.netrans_path, 'pnnacc')
|
||||
|
@ -72,9 +99,9 @@ class create_cls(): #dataclass @netrans_params
|
|||
self.verbose=verbose
|
||||
self.quantize_type = quantized_type
|
||||
|
||||
if __name__ == "__main__":
|
||||
dir_name = "yolo"
|
||||
os.mkdir(dir_name)
|
||||
check_dir(dir_name)
|
||||
# if __name__ == "__main__":
|
||||
# dir_name = "yolo"
|
||||
# os.mkdir(dir_name)
|
||||
# check_dir(dir_name)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue