diff --git a/netrans_py/README.md b/netrans_py/README.md
new file mode 100644
index 0000000..7796519
--- /dev/null
+++ b/netrans_py/README.md
@@ -0,0 +1,171 @@
+# Python api netrans_py 使用介绍
+
+netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格式。
+使用 ntrans_py 完成模型转换的步骤如下:
+1. 导入模型
+2. 生成并修改前处理配置文件 *_inputmeta.yml
+3. 量化模型
+4. 导出模型
+
+## 安装
+在使用netrans_py之前,需要安装netrans_py。
+
+设置环境变量 NETRANS_PATH 并指向该 bin 目录。
+注意: 在该项目中,项目下载目录为 `/home/nudt_dps/netrans`,在您应用的过程中,可以使用 `pwd` 来确认您的项目目录。
+
+```bash
+export NETRANS_PATH=/home/nudt_dps/netrans/bin
+```
+同时设置LD_LIBRARY_PATH(Ubuntu,其他系统根据具体情况设置):
+
+```bash
+export LD_LIBRARY_PATH=/home/nudt_dps/netrans/bin:$LD_LIBRARY_PATH
+```
+注意这一步每次使用前都需要执行,或者您可以写入 .bashrc (路径为 `~/.bashrc` )。
+
+然后进入目录 netrans_py 进行安装。
+```bash
+cd /home/nudt_dps/netrans/netrans_py
+pip3 install -e .
+```
+## netrans_py api
+### Netrans 导入api及创建实例
+创建 Netrans
+
+ 描述: 实例化 Netrans 类。
+ 代码示例:
+
+ ```py3
+ from netrans import Netrans
+ yolo_netrans = Netrans("../examples/darknet/yolov4_tiny")
+ ```
+
+ 参数
+
+| 参数名 | 类型 | 说明 |
+|:---| -- | -- |
+|model_path| str| 第一位置参数,模型文件的路径|
+|netans| str | 如果 NETRANS_PATH 没有设置,可通过该参数指定netrans的路径|
+
+输出返回:
+无。
+
+注意: 模型目录准备需要和netrans_cli一致,具体数据准备要求见[introduction](./introduction.md)。
+
+### Netrans.load_model 模型导入
+
+ 描述: 将模型转换成 pnna 支持的格式。
+ 代码示例:
+
+ ```py3
+ yolo_netrans.load_model()
+ ```
+
+ 参数:
+ 无。
+
+ 输出返回:
+ 无。
+ 在工程目录下生成 pnna 支持的模型格式,以.json结尾的模型文件和 .data结尾的权重文件。
+
+### Netrans.gen_inputmeta 预处理配置文件生成
+
+ 描述: 将模型转换成 pnna 支持的格式。
+ 代码示例:
+
+ ```py3
+ yolo_netrans.gen_inputmeta()
+ ```
+
+ 参数:
+ 无。
+
+ 输出返回:
+ 无。
+
+### Netrans.quantize 量化模型
+
+ 描述: 对模型生成量化配置文件。
+ 代码示例:
+
+ ```py3
+ yolo_netrans.quantize("uint8")
+ ```
+
+ 参数:
+
+| 参数名 | 类型 | 说明 |
+|:---| -- | -- |
+|quantize_type| str| 第一位置参数,模型量化类型,仅支持 "uint8", "int8", "int16"|
+
+ 输出返回:
+ 无。
+
+### Netrans.export 模型导出
+
+ 描述: 对模型生成量化配置文件。
+ 代码示例:
+
+ ```py3
+ yolo_netrans.export()
+ ```
+
+ 参数:
+ 无。
+
+ 输出返回:
+ 无。请在目录 “wksp/*/” 下检查是否生成nbg文件。
+
+### Netrans.model2nbg 一键生成nbg文件
+
+ 描述: 模型导入、量化、及nbg文件生产
+ 代码示例:
+
+ ```py3
+ # 无预处理
+yolo_netrans.model2nbg(quantize_type='uint8')
+ # 需要对数据进行normlize, menas为128, scale 为 0.0039
+yolo_netrans.model2nbg(quantize_type='uint8',mean=128, scale = 0.0039)
+ # 需要对数据分通道进行normlize, menas为128,127,125,scale 为 0.0039, 且reverse_channel 为 True
+yolo_netrans.model2nbg(quantize_type='uint8'mean=[128, 127, 125], scale = 0.0039, reverse_channel= True)
+ # 已经进行初始化设置
+yolo_netrans.model2nbg(quantize_type='uint8', inputmeta=True)
+
+ ```
+
+ 参数
+| 参数名 | 类型 | 说明 |
+|:---| -- | -- |
+|quantize_type| str, ["uint8", "int8", "int16" ] | 量化类型,将模型量化成该参数指定的类型 |
+|inputmeta| bool,str, [Fasle, True, "inputmeta_filepath"] | 指定 inputmeta, 默认为False。
如果为False,则会生成inputmeta模板,可使用mean、scale、reverse_channel 配合修改常用参数。
如果已有现成的 inputmeta 文件,则可通过该参数进行指定,也可使用True, 则会自动索引 model_name_inputmeta.yml |
+|mean| float, int, list | 设置预处理中 normalize 的 mean 参数 |
+|scale| float, int, list | 设置预处理中 normalize 的 scale 参数 |
+|reverse_channel | bool | 设置预处理中的 reverse_channel 参数 |
+
+
+输出返回:
+请在目录 “wksp/*/” 下检查是否生成nbg文件。
+
+## 使用实例
+
+ ```
+from nertans import Netrans
+model_path = 'example/darknet/yolov4_tiny'
+netrans_path = "netrans/bin" # 如果进行了export定义申明,这一步可以不用
+
+# 初始化netrans
+net = Netrans(model_path,netrans=netrans_path)
+# 模型载入
+net.load_model()
+# 生成 inputmeta 文件
+net.gen_inputmeta()
+# 配置预处理 normlize 的参数
+net.config(scale=1,mean=0)
+# 模型量化
+net.quantize("uint8")
+# 模型导出
+net.export()
+
+# 模型直接量化成 int16 并导出, 直接复用刚配置好的 inputmeta
+net.model2nbg(quantize_type = "int16", inputmeta=True)
+```
diff --git a/netrans_py/example.py b/netrans_py/example.py
new file mode 100644
index 0000000..478d46c
--- /dev/null
+++ b/netrans_py/example.py
@@ -0,0 +1,58 @@
+import argparse
+from netrans import Netrans
+
+def main():
+ # 创建参数解析器
+ parser = argparse.ArgumentParser(
+ description='神经网络模型转换工具',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
+ )
+
+ # 必填位置参数
+ parser.add_argument(
+ 'model_path',
+ type=str,
+ help='输入模型路径(必须参数)'
+ )
+
+ # 可选参数组
+ quant_group = parser.add_argument_group('量化参数')
+ quant_group.add_argument(
+ '-q', '--quantize_type',
+ type=str,
+ choices=['uint8', 'int8', 'int16', 'float'],
+ default='uint8',
+ metavar='TYPE',
+ help='量化类型(可选值:%(choices)s)'
+ )
+ quant_group.add_argument(
+ '-m', '--mean',
+ type=int,
+ default=0,
+ help='归一化均值(默认:%(default)s)'
+ )
+ quant_group.add_argument(
+ '-s', '--scale',
+ type=float,
+ default=1.0,
+ help='量化缩放系数(默认:%(default)s)'
+ )
+
+ # 解析参数
+ args = parser.parse_args()
+
+ # 执行模型转换
+ try:
+ model = Netrans(model_path=args.model_path)
+ model.model2nbg(
+ quantize_type=args.quantize_type,
+ mean=args.mean,
+ scale=args.scale
+ )
+ print(f"模型 {args.model_path} 转换成功")
+ except FileNotFoundError:
+ print(f"错误:模型文件 {args.model_path} 不存在")
+ exit(1)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/netrans_py/export.py b/netrans_py/export.py
new file mode 100644
index 0000000..eb30a9d
--- /dev/null
+++ b/netrans_py/export.py
@@ -0,0 +1,150 @@
+import os
+import sys
+import subprocess
+import shutil
+from utils import check_path, AttributeCopier, creat_cla
+# 检查 NETRANS_PATH 环境变量是否设置
+
+# 定义数据集文件路径
+dataset = 'dataset.txt'
+
+class Export(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+
+ @check_path
+ def export_network(self):
+
+ netrans = self.netrans
+ quantized = self.quantize_type
+ name = self.model_name
+ netrans_path = self.netrans_path
+
+ ovxgenerator = netrans + " export ovxlib"
+ # 进入模型目录
+ # os.chdir(name)
+
+ # 根据量化类型设置参数
+ if quantized == 'float':
+ type_ = 'float'
+ quantization_type = 'none_quantized'
+ generate_path = './wksp/none_quantized'
+ elif quantized == 'uint8':
+ type_ = 'quantized'
+ quantization_type = 'asymmetric_affine'
+ generate_path = './wksp/asymmetric_affine'
+ elif quantized == 'int8':
+ type_ = 'quantized'
+ quantization_type = 'dynamic_fixed_point-8'
+ generate_path = './wksp/dynamic_fixed_point-8'
+ elif quantized == 'int16':
+ type_ = 'quantized'
+ quantization_type = 'dynamic_fixed_point-16'
+ generate_path = './wksp/dynamic_fixed_point-16'
+ else:
+ print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
+ sys.exit(1)
+
+ # 创建输出目录
+ os.makedirs(generate_path, exist_ok=True)
+
+ # 构建命令
+ if quantized == 'float':
+ cmd = f"{ovxgenerator} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --dtype {type_} \
+ --pack-nbg-viplite \
+ --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
+ --target-ide-project 'linux64' \
+ --viv-sdk {netrans_path}/pnna_sdk \
+ --output-path {generate_path}/{name}_{quantization_type}"
+ else:
+ if not os.path.exists(f"{name}_{quantization_type}.quantize"):
+ print(f"\033[31m Can not find {name}_{quantization_type}.quantize \033[0m")
+ sys.exit(1)
+ if not os.path.exists(f"{name}_postprocess_file.yml"):
+ cmd = f"{ovxgenerator} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --dtype {type_} \
+ --pack-nbg-viplite \
+ --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
+ --viv-sdk {netrans_path}/pnna_sdk \
+ --model-quantize {name}_{quantization_type}.quantize \
+ --with-input-meta {name}_inputmeta.yml \
+ --target-ide-project 'linux64' \
+ --output-path {generate_path}/{quantization_type}"
+ else:
+ cmd = f"{ovxgenerator} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --dtype {type_} \
+ --pack-nbg-viplite \
+ --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
+ --viv-sdk {netrans_path}/pnna_sdk \
+ --model-quantize {name}_{quantization_type}.quantize \
+ --with-input-meta {name}_inputmeta.yml \
+ --target-ide-project 'linux64' \
+ --postprocess-file {name}_postprocess_file.yml \
+ --output-path {generate_path}/{quantization_type}"
+
+
+ # 执行命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ if result.returncode == 0:
+ print("\033[31m SUCCESS \033[0m")
+ else:
+ print(f"\033[31m ERROR ! {result.stderr} \033[0m")
+
+
+ # temp='wksp/temp'
+ # os.makedirs(temp, exist_ok=True)
+
+ src_ngb = f'{generate_path}_nbg_viplite/network_binary.nb'
+ try :
+ shutil.copy(src_ngb, generate_path)
+ except FileNotFoundError:
+ print(f"Error: {src_ngb} is not found")
+ except Exception as e :
+ print(f"a error occurred : {e}")
+
+ try:
+ shutil.rmtree(f"{generate_path}_nbg_viplite")
+ except:
+ sys.exit()
+
+ # try :
+ # shutil.move(temp, generate_path )
+ # except:
+ # sys.exit()
+ # 返回原始目录
+ # os.chdir('..')
+
+def main():
+# 检查命令行参数数量
+ if len(sys.argv) < 3:
+ print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
+ sys.exit(1)
+ # 检查网络目录是否存在
+ network_name = sys.argv[1]
+ # check_env(network_name)
+ if not os.path.exists(os.path.exists(network_name)):
+ print(f"Directory {network_name} does not exist !")
+ sys.exit(2)
+
+ netrans_path = os.environ['NETRANS_PATH']
+ # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
+ # 调用导出函数ss
+ cla = creat_cla(netrans_path, network_name, sys.argv[2])
+ func = Export(cla)
+ func.export_network()
+
+ # export_network(netrans, network_name, sys.argv[2])
+
+
+if __name__ == '__main__':
+ main()
diff --git a/netrans_py/file_model.py b/netrans_py/file_model.py
new file mode 100644
index 0000000..5c9a239
--- /dev/null
+++ b/netrans_py/file_model.py
@@ -0,0 +1,49 @@
+__all__ = ['extensions']
+
+class model_extensions:
+ def __init__(self, model, model_data, model_quantize, input_meta, output_meta):
+ self._model = model
+ self._model_data = model_data
+ self._model_quantize = model_quantize
+ self._input_meta = input_meta
+ self._output_meta = output_meta
+
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def model_data(self):
+ return self._model_data
+
+ @property
+ def model_quantize(self):
+ return self._model_quantize
+
+ @property
+ def input_meta(self):
+ return self._input_meta
+
+ @property
+ def output_meta(self):
+ return self._output_meta
+
+class file_model:
+ def __init__(self,extensions):
+ self._extensions = extensions
+
+ @property
+ def extensions(self):
+ return self._extensions
+
+x_extensions = model_extensions(
+ '.json',
+ '.data',
+ '.quantize',
+ '_inputmeta.yml',
+ '.yml'
+)
+
+_file_model = file_model(x_extensions)
+
+extensions = _file_model.extensions
diff --git a/netrans_py/gen_inputmeta.py b/netrans_py/gen_inputmeta.py
new file mode 100644
index 0000000..68979f9
--- /dev/null
+++ b/netrans_py/gen_inputmeta.py
@@ -0,0 +1,38 @@
+
+import os
+import sys
+from utils import check_path, AttributeCopier, creat_cla
+
+class InputmetaGen(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+
+ @check_path
+ def inputmeta_gen(self):
+ netrans_path = self.netrans
+ network_name = self.model_name
+ # 进入网络名称指定的目录
+ # os.chdir(network_name)
+ # check_env(network_name)
+
+ # 执行 pegasus 命令
+ os.system(f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database")
+ # os.chdir("..")
+
+def main():
+ # 检查命令行参数数量是否正确
+ if len(sys.argv) != 2:
+ print("Enter a network name!")
+ sys.exit(2)
+
+ # 检查提供的目录是否存在
+ network_name = sys.argv[1]
+ # 构建 netrans 可执行文件的路径
+ netrans_path =os.getenv('NETRANS_PATH')
+ cla = creat_cla(netrans_path, network_name)
+ func = InputmetaGen(cla)
+ func.inputmeta_gen()
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/netrans_py/import_model.py b/netrans_py/import_model.py
new file mode 100644
index 0000000..340cf57
--- /dev/null
+++ b/netrans_py/import_model.py
@@ -0,0 +1,225 @@
+import os
+import sys
+import subprocess
+from utils import check_path, AttributeCopier, creat_cla
+
+def check_status(result):
+ if result.returncode == 0:
+ print("\033[31m LOAD MODEL SUCCESS \033[0m")
+ else:
+ print(f"\033[31m ERROR: {result.stderr} \033[0m")
+
+
+def import_caffe_network(name, netrans_path):
+ # 定义转换工具的路径
+ convert_caffe =netrans_path + " import caffe"
+
+ # 定义模型文件路径
+ model_json_path = f"{name}.json"
+ model_data_path = f"{name}.data"
+ model_prototxt_path = f"{name}.prototxt"
+ model_caffemodel_path = f"{name}.caffemodel"
+
+
+ # 打印转换信息
+ print(f"=========== Converting {name} Caffe model ===========")
+
+ # 构建转换命令
+ if os.path.isfile(model_caffemodel_path):
+ cmd = f"{convert_caffe} \
+ --model {model_prototxt_path} \
+ --weights {model_caffemodel_path} \
+ --output-model {model_json_path} \
+ --output-data {model_data_path}"
+ else:
+ print("=========== fake Caffe model data file =============")
+ cmd = f"{convert_caffe} \
+ --model {model_prototxt_path} \
+ --output-model {model_json_path} \
+ --output-data {model_data_path}"
+
+ # 执行转换命令
+ # print(cmd)
+ os.system(cmd)
+
+def import_tensorflow_network(name, netrans_path):
+ # 定义转换工具的命令
+ convertf_cmd = f"{netrans_path} import tensorflow"
+
+ # 打印转换信息
+ print(f"=========== Converting {name} Tensorflow model ===========")
+
+ # 读取 inputs_outputs.txt 文件中的参数
+ with open('inputs_outputs.txt', 'r') as f:
+ inputs_outputs_params = f.read().strip()
+
+ # 构建转换命令
+ cmd = f"{convertf_cmd} \
+ --model {name}.pb \
+ --output-data {name}.data \
+ --output-model {name}.json \
+ {inputs_outputs_params}"
+
+ # 执行转换命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ check_status(result)
+
+def import_onnx_network(name, netrans_path):
+ # 定义转换工具的命令
+ convert_onnx_cmd = f"{netrans_path} import onnx"
+
+ # 打印转换信息
+ print(f"=========== Converting {name} ONNX model ===========")
+ if os.path.exists(f"{name}_outputs.txt"):
+ output_path = os.path.join(os.getcwd(), name+"_outputs.txt")
+ with open(output_path, 'r', encoding='utf-8') as file:
+ outputs = str(file.readline().strip())
+ cmd = f"{convert_onnx_cmd} \
+ --model {name}.onnx \
+ --output-model {name}.json \
+ --output-data {name}.data \
+ --outputs '{outputs}'"
+ else:
+ # 构建转换命令
+ cmd = f"{convert_onnx_cmd} \
+ --model {name}.onnx \
+ --output-model {name}.json \
+ --output-data {name}.data"
+
+ # 执行转换命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ check_status(result)
+
+####### TFLITE
+def import_tflite_network(name, netrans_path):
+ # 定义转换工具的路径或命令
+ convert_tflite = f"{netrans_path} import tflite"
+
+ # 定义模型文件路径
+ model_json_path = f"{name}.json"
+ model_data_path = f"{name}.data"
+ model_tflite_path = f"{name}.tflite"
+
+ # 打印转换信息
+ print(f"=========== Converting {name} TFLite model ===========")
+
+ # 构建转换命令
+ cmd = f"{convert_tflite} \
+ --model {model_tflite_path} \
+ --output-model {model_json_path} \
+ --output-data {model_data_path}"
+
+ # 执行转换命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ check_status(result)
+
+
+def import_darknet_network(name, netrans_path):
+ # 定义转换工具的命令
+ convert_darknet_cmd = f"{netrans_path} import darknet"
+
+ # 打印转换信息
+ print(f"=========== Converting {name} darknet model ===========")
+
+ # 构建转换命令
+ cmd = f"{convert_darknet_cmd} \
+ --model {name}.cfg \
+ --weight {name}.weights \
+ --output-model {name}.json \
+ --output-data {name}.data"
+
+ # 执行转换命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ check_status(result)
+
+def import_pytorch_network(name, netrans_path):
+ # 定义转换工具的命令
+ convert_pytorch_cmd = f"{netrans_path} import pytorch"
+
+ # 打印转换信息
+ print(f"=========== Converting {name} pytorch model ===========")
+
+ # 读取 input_size.txt 文件中的参数
+ try:
+ with open('input_size.txt', 'r') as file:
+ input_size_params = ' '.join(file.readlines())
+ except FileNotFoundError:
+ print("Error: input_size.txt not found.")
+ sys.exit(1)
+
+ # 构建转换命令
+ cmd = f"{convert_pytorch_cmd} \
+ --model {name}.pt \
+ --output-model {name}.json \
+ --output-data {name}.data \
+ {input_size_params}"
+
+ # 执行转换命令
+ # print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ check_status(result)
+
+# 使用示例
+# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
+class ImportModel(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+ # print(source_obj.__dict__)
+
+ @check_path
+ def import_network(self):
+ if self.verbose is True :
+ print("begin load model")
+ # print(self.model_path)
+ print(os.getcwd())
+ print(f"{self.model_name}.weights")
+ name = self.model_name
+ netrans_path = self.netrans
+ if os.path.isfile(f"{name}.prototxt"):
+ import_caffe_network(name, netrans_path)
+ elif os.path.isfile(f"{name}.pb"):
+ import_tensorflow_network(name, netrans_path)
+ elif os.path.isfile(f"{name}.onnx"):
+ import_onnx_network(name, netrans_path)
+ elif os.path.isfile(f"{name}.tflite"):
+ import_tflite_network(name, netrans_path)
+ elif os.path.isfile(f"{name}.weights"):
+ import_darknet_network(name, netrans_path)
+ elif os.path.isfile(f"{name}.pt"):
+ import_pytorch_network(name, netrans_path)
+ else :
+ # print(os.getcwd())
+ print("=========== can not find suitable model files ===========")
+ sys.exit(-3)
+ # os.chdir("..")
+
+
+def main():
+ if len(sys.argv) != 2 :
+ print("Input a network")
+ sys.exit(-1)
+
+ network_name = sys.argv[1]
+ # check_env(network_name)
+
+ netrans_path = os.environ['NETRANS_PATH']
+ # netrans = os.path.join(netrans_path, 'pnnacc')
+ clas = creat_cla(netrans_path, network_name,verbose=False)
+ func = ImportModel(clas)
+ func.import_network()
+if __name__ == "__main__":
+ main()
diff --git a/netrans_py/infer.py b/netrans_py/infer.py
new file mode 100644
index 0000000..d0435e2
--- /dev/null
+++ b/netrans_py/infer.py
@@ -0,0 +1,95 @@
+import os
+import sys
+import subprocess
+from utils import check_path, AttributeCopier, creat_cla
+
+class Infer(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+
+ @check_path
+ def inference_network(self):
+ netrans = self.netrans
+ quantized = self.quantize_type
+ name = self.model_name
+ # print(self.__dict__)
+
+ netrans += " inference"
+ # 进入模型目录
+
+ # 定义类型和量化类型
+ if quantized == 'float':
+ type_ = 'float32'
+ quantization_type = 'float32'
+ elif quantized == 'uint8':
+ quantization_type = 'asymmetric_affine'
+ type_ = 'quantized'
+ elif quantized == 'int8':
+ quantization_type = 'dynamic_fixed_point-8'
+ type_ = 'quantized'
+ elif quantized == 'int16':
+ quantization_type = 'dynamic_fixed_point-16'
+ type_ = 'quantized'
+ else:
+ print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
+ sys.exit(-1)
+
+ # 构建推理命令
+ inf_path = './inf'
+ cmd = f"{netrans} \
+ --dtype {type_} \
+ --batch-size 1 \
+ --model-quantize {name}_{quantization_type}.quantize \
+ --model {name}.json \
+ --model-data {name}.data \
+ --output-dir {inf_path} \
+ --with-input-meta {name}_inputmeta.yml \
+ --device CPU"
+
+ # 执行推理命令
+ if self.verbose is True:
+ print(cmd)
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ if result.returncode == 0:
+ print("\033[32m SUCCESS \033[0m")
+ else:
+ print(f"\033[31m ERROR: {result.stderr} \033[0m")
+
+ # 返回原始目录
+
+def main():
+ # 检查命令行参数数量
+ if len(sys.argv) < 3:
+ print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
+ sys.exit(-1)
+
+ # 检查网络目录是否存在
+ network_name = sys.argv[1]
+ if not os.path.exists(network_name):
+ print(f"Directory {network_name} does not exist !")
+ sys.exit(-2)
+ # print("here")
+ # 定义 netrans 路径
+ # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
+ network_name = sys.argv[1]
+ # check_env(network_name)
+
+ netrans_path = os.environ['NETRANS_PATH']
+ # netrans = os.path.join(netrans_path, 'pnnacc')
+ quantize_type = sys.argv[2]
+ cla = creat_cla(netrans_path, network_name,quantize_type,False)
+
+ # 调用量化函数
+ func = Infer(cla)
+ func.inference_network()
+
+ # 定义数据集文件路径
+ # dataset_path = './dataset.txt'
+ # 调用推理函数
+ # inference_network(network_name, sys.argv[2])
+
+if __name__ == '__main__':
+ # print("main")
+ main()
diff --git a/netrans_py/netrans.py b/netrans_py/netrans.py
new file mode 100644
index 0000000..d42673c
--- /dev/null
+++ b/netrans_py/netrans.py
@@ -0,0 +1,142 @@
+import sys, os
+import subprocess
+# import yaml
+from ruamel.yaml import YAML
+from ruamel import yaml
+import file_model
+from import_model import ImportModel
+from quantize import Quantize
+from export import Export
+from gen_inputmeta import InputmetaGen
+# from utils import check_path
+import warnings
+warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
+class Netrans():
+
+ def __init__(self, model_path, netrans=None, verbose=False):
+ self.verbose = verbose
+ self.model_path = os.path.abspath(model_path)
+ self.set_netrans(netrans)
+ _, self.model_name = os.path.split(self.model_path)
+ # self.model_name,_ = os.path.splitext(self.model_name)
+
+
+ """
+ pipe line
+ """
+ def model2nbg(self, quantize_type, inputmeta=False, **kargs):
+ self.load_model()
+ self.config(inputmeta, **kargs)
+ self.quantize(quantize_type)
+ self.export()
+
+ """
+ set netrans
+ """
+ def get_os_netrans_path(self):
+ # print(os.environ.get('NETRANS_PATH'))
+ return os.environ.get('NETRANS_PATH')
+
+ def check_netarans(self):
+ res = subprocess.run([self.netrans], text=True)
+ if res.returncode != 0:
+ print("pleace check the netrans")
+ # return False
+ sys.exit()
+ else :
+ return
+
+ def set_netrans(self, netrans_path=None):
+ if netrans_path is not None :
+ netrans_path = os.path.abspath(netrans_path)
+ else :
+ netrans_path = self.get_os_netrans_path()
+ # print(netrans_path)
+ if os.path.exists(netrans_path):
+ self.netrans = os.path.join(netrans_path, 'pnnacc')
+ self.netrans_path = netrans_path
+ else :
+ print('NETRANS_PATH NOT BEEN SETTED')
+ """
+ edit config
+ """
+ # @check_path
+ def config(self, inputmeta=False, **kargs):
+ if isinstance(inputmeta, str):
+ self.input_meta = inputmeta
+ elif isinstance(inputmeta, bool):
+ self.input_meta = os.path.join(self.model_path,'%s%s'%(self.model_name, file_model.extensions.input_meta))
+ if inputmeta is False : self.inputmeta_gen()
+ else :
+ sys.exit("check inputmeta file")
+
+ if len(kargs) == 0 : return
+ if kargs['mean']==0 and kargs['scale'] ==1 : return
+ if isinstance(kargs['mean'], list) or isinstance(kargs['scale'], (int, float)) or isinstance(kargs['reverse_channel'], bool):
+ with open(self.input_meta,'r') as f :
+ yaml = YAML()
+ data = yaml.load(f)
+ data = self.upload_cfg(data,**kargs)
+ with open(self.input_meta,'w') as f :
+ yaml = YAML()
+ yaml.dump(data, f)
+
+ def upload_cfg(self, data, channel=3, **kargs):
+ grey = config['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
+ if kargs.get('mean') is not None:
+ mean = handel_param(kargs['mean'],grey)
+ self.upload_cfg_mean(data, mean)
+ if kargs.get('scale') is not None:
+ scale = handel_param(kargs['scale'],grey)
+ self.upload_cfg_scale(data, scale)
+ if kargs.get('reverse_channel') is not None:
+ if isinstance(kargs['reverse_channel'],bool):
+ self.upload_cfg_reverse_channel(data, kargs['reverse_channel'])
+ return data
+
+ def upload_cfg_mean(self, data, mean):
+ for db in data['input_meta']['databases']:
+ db['ports'][0]['preprocess']['mean'] = mean
+ def upload_cfg_scale(self, data, scale):
+ for db in data['input_meta']['databases']:
+ db['ports'][0]['preprocess']['scale'] = scale
+ def upload_cfg_reverse_channel(self, data, reverse_channel):
+ for db in data['input_meta']['databases']:
+ db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
+
+ def load_model(self):
+ func = ImportModel(self)
+ func.import_network()
+
+ def inputmeta_gen(self):
+ func = InputmetaGen(self)
+ func.inputmeta_gen()
+
+ def quantize(self, quantize_type):
+ self.quantize_type = quantize_type
+ func = Quantize(self)
+ func.quantize_network()
+
+ def export(self, **kargs):
+ if kargs.get('quantize_type') :
+ self.quantize_type = kargs['quantize_type']
+ func = Export(self)
+ func.export_network()
+
+
+
+def handel_param(param, grey=False):
+ if grey : return param
+ else :
+ return param if isinstance(param, list) else [param]*3
+
+
+if __name__ == '__main__':
+ network = '../../model_zoo/yolov4_tiny'
+ yolo = Netrans(network)
+ yolo.inputmeta_gen()
+ # yolo.model2nb("uint8")
+ # yolo.load_model()
+ # yolo.config(mean=[0,0,0],scale=1)
+ # yolo.quantize('uint8')
+ # yolo.export()
diff --git a/netrans_py/quantize.py b/netrans_py/quantize.py
new file mode 100644
index 0000000..b634d0a
--- /dev/null
+++ b/netrans_py/quantize.py
@@ -0,0 +1,93 @@
+import os
+import sys
+from utils import check_path, AttributeCopier, creat_cla
+
+class Quantize(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+
+ @check_path
+ def quantize_network(self):
+ netrans = self.netrans
+ quantized_type = self.quantize_type
+ name = self.model_name
+ # check_env(name)
+ # print(os.getcwd())
+ netrans += " quantize"
+ # 根据量化类型设置量化参数
+ if quantized_type == 'float':
+ print("=========== do not need quantized===========")
+ return
+ elif quantized_type == 'uint8':
+ quantization_type = "asymmetric_affine"
+ elif quantized_type == 'int8':
+ quantization_type = "dynamic_fixed_point-8"
+ elif quantized_type == 'int16':
+ quantization_type = "dynamic_fixed_point-16"
+ else:
+ print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
+ return
+
+ # 输出量化信息
+ print(" =======================================================================")
+ print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
+ print(" =======================================================================")
+ current_directory = os.getcwd()
+ txt_path = current_directory+"/dataset.txt"
+ with open(txt_path, 'r', encoding='utf-8') as file:
+ num_lines = len(file.readlines())
+
+ # 移除已存在的量化文件
+ quantize_file = f"{name}_{quantization_type}.quantize"
+ if os.path.exists(quantize_file):
+ print(f"\033[31m rm {quantize_file} \033[0m")
+ os.remove(quantize_file)
+
+ # 构建并执行量化命令
+ cmd = f"{netrans} \
+ --batch-size 1 \
+ --qtype {quantized_type} \
+ --rebuild \
+ --quantizer {quantization_type.split('-')[0]} \
+ --model-quantize {quantize_file} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --with-input-meta {name}_inputmeta.yml \
+ --device CPU \
+ --algorithm kl_divergence \
+ --iterations {num_lines}"
+
+ os.system(cmd)
+
+ # 检查量化结果
+ if os.path.exists(quantize_file):
+ print("\033[31m QUANTIZED SUCCESS \033[0m")
+ else:
+ print("\033[31m ERROR ! \033[0m")
+
+
+def main():
+ # 检查命令行参数数量
+ if len(sys.argv) < 3:
+ print("Input a network name and quantized type ( uint8 / int8 / int16 )")
+ sys.exit(-1)
+
+ # 检查网络目录是否存在
+ network_name = sys.argv[1]
+
+ # 定义 netrans 路径
+ # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
+ # network_name = sys.argv[1]
+ # check_env(network_name)
+
+ netrans_path = os.environ['NETRANS_PATH']
+ # netrans = os.path.join(netrans_path, 'pnnacc')
+ quantize_type = sys.argv[2]
+ cla = creat_cla(netrans_path, network_name,quantize_type)
+
+ # 调用量化函数
+ run = Quantize(cla)
+ run.quantize_network()
+
+if __name__ == "__main__":
+ main()
diff --git a/netrans_py/quantize_hb.py b/netrans_py/quantize_hb.py
new file mode 100644
index 0000000..78eef56
--- /dev/null
+++ b/netrans_py/quantize_hb.py
@@ -0,0 +1,91 @@
+import os
+import sys
+from utils import check_path, AttributeCopier, creat_cla
+
+class Quantize(AttributeCopier):
+ def __init__(self, source_obj) -> None:
+ super().__init__(source_obj)
+
+ @check_path
+ def quantize_network(self):
+ netrans = self.netrans
+ quantized_type = self.quantize_type
+ name = self.model_name
+ # check_env(name)
+ # print(os.getcwd())
+ netrans += " quantize"
+ # 根据量化类型设置量化参数
+ if quantized_type == 'float':
+ print("=========== do not need quantized===========")
+ return
+ elif quantized_type == 'uint8':
+ quantization_type = "asymmetric_affine"
+ elif quantized_type == 'int8':
+ quantization_type = "dynamic_fixed_point-8"
+ elif quantized_type == 'int16':
+ quantization_type = "dynamic_fixed_point-16"
+ else:
+ print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
+ return
+
+ # 输出量化信息
+ print(" =======================================================================")
+ print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
+ print(" =======================================================================")
+
+ # 移除已存在的量化文件
+ quantize_file = f"{name}_{quantization_type}.quantize"
+ current_directory = os.getcwd()
+ txt_path = current_directory+"/dataset.txt"
+ with open(txt_path, 'r', encoding='utf-8') as file:
+ num_lines = len(file.readlines())
+
+
+ # 构建并执行量化命令
+ cmd = f"{netrans} \
+ --qtype {quantized_type} \
+ --hybrid \
+ --quantizer {quantization_type.split('-')[0]} \
+ --model-quantize {quantize_file} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --with-input-meta {name}_inputmeta.yml \
+ --device CPU \
+ --algorithm kl_divergence \
+ --divergence-nbins 2048 \
+ --iterations {num_lines}"
+
+ os.system(cmd)
+
+ # 检查量化结果
+ if os.path.exists(quantize_file):
+ print("\033[31m QUANTIZED SUCCESS \033[0m")
+ else:
+ print("\033[31m ERROR ! \033[0m")
+
+
+def main():
+ # 检查命令行参数数量
+ if len(sys.argv) < 3:
+ print("Input a network name and quantized type ( uint8 / int8 / int16 )")
+ sys.exit(-1)
+
+ # 检查网络目录是否存在
+ network_name = sys.argv[1]
+
+ # 定义 netrans 路径
+ # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
+ # network_name = sys.argv[1]
+ # check_env(network_name)
+
+ netrans_path = os.environ['NETRANS_PATH']
+ # netrans = os.path.join(netrans_path, 'pnnacc')
+ quantize_type = sys.argv[2]
+ cla = creat_cla(netrans_path, network_name,quantize_type)
+
+ # 调用量化函数
+ run = Quantize(cla)
+ run.quantize_network()
+
+if __name__ == "__main__":
+ main()
diff --git a/netrans_py/setup.py b/netrans_py/setup.py
new file mode 100644
index 0000000..d094cf5
--- /dev/null
+++ b/netrans_py/setup.py
@@ -0,0 +1,16 @@
+from setuptools import setup, find_packages
+
+with open("README.md", "r", encoding="utf-8") as fh:
+ long_description = fh.read()
+
+setup(
+ name="netrans",
+ version="0.1.0",
+ author="nudt_dsp",
+ url="https://gitlink.org.cn/gwg_xujiao/netrans",
+ packages=find_packages(include=["netrans_py"]),
+ package_dir={"": "."}, # 指定根目录映射关系[8](@ref)
+ install_requires=[
+ "ruamel.yaml==0.18.6"
+ ]
+)
diff --git a/netrans_py/utils.py b/netrans_py/utils.py
new file mode 100644
index 0000000..cb29427
--- /dev/null
+++ b/netrans_py/utils.py
@@ -0,0 +1,80 @@
+import sys
+import os
+# from functools import wraps
+
+# def check_path(netrans, model_path):
+# def decorator(func):
+# @wraps(func)
+# def wrapper(netrans, model_path, *args, **kargs):
+# check_dir(model_path)
+# check_netrans(netrans)
+# if os.getcwd() != model_path :
+# os.chdir(model_path)
+# return func(netrans, model_path, *args, **kargs)
+# return wrapper
+# return decorator
+
+def check_path(func):
+ def wrapper(cla, *args, **kargs):
+ check_netrans(cla.netrans)
+ if os.getcwd() != cla.model_path :
+ os.chdir(cla.model_path)
+ return func(cla, *args, **kargs)
+ return wrapper
+
+
+def check_dir(network_name):
+ if not os.path.exists(network_name):
+ print(f"Directory {network_name} does not exist !")
+ sys.exit(-1)
+ os.chdir(network_name)
+
+def check_netrans(netrans):
+ if 'NETRANS_PATH' not in os.environ :
+ return
+ if netrans != None and os.path.exists(netrans) is True:
+ return
+ print("Need to set enviroment variable NETRANS_PATH")
+ sys.exit(1)
+
+def remove_history_file(name):
+ os.chdir(name)
+ if os.path.isfile(f"{name}.json"):
+ os.remove(f"{name}.json")
+ if os.path.isfile(f"{name}.data"):
+ os.remove(f"{name}.data")
+ os.chdir('..')
+
+def check_env(name):
+ check_dir(name)
+# check_netrans()
+ # remove_history_file(name)
+
+
+class AttributeCopier:
+ def __init__(self, source_obj) -> None:
+ self.copy_attribute_name(source_obj)
+
+ def copy_attribute_name(self, source_obj):
+ for attribute_name in self._get_attribute_names(source_obj):
+ setattr(self, attribute_name, getattr(source_obj, attribute_name))
+
+ @staticmethod
+ def _get_attribute_names(source_obj):
+ return source_obj.__dict__.keys()
+
+class creat_cla(): #dataclass @netrans_params
+ def __init__(self, netrans_path, name, quantized_type = 'uint8',verbose=False) -> None:
+ self.netrans_path = netrans_path
+ self.netrans = os.path.join(self.netrans_path, 'pnnacc')
+ self.model_name=self.model_path = name
+ self.model_path = os.path.abspath(self.model_path)
+ self.verbose=verbose
+ self.quantize_type = quantized_type
+
+if __name__ == "__main__":
+ dir_name = "yolo"
+ os.mkdir(dir_name)
+ check_dir(dir_name)
+
+