diff --git a/netrans_py/README.md b/netrans_py/README.md
index 661d345..654ea87 100644
--- a/netrans_py/README.md
+++ b/netrans_py/README.md
@@ -1,15 +1,15 @@
-# Python api netrans_py 使用介绍
+# netrans_py 使用
-netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格式。
+netrans_py 为 Netrans 编译器的 python 调用接口。
使用 ntrans_py 完成模型转换的步骤如下:
+
1. 导入模型
2. 生成并修改前处理配置文件 *_inputmeta.yml
3. 量化模型
4. 导出模型
+## Netrans 类
-## netrans_py api
-### Netrans 导入api及创建实例
创建 Netrans
描述: 实例化 Netrans 类。
@@ -21,7 +21,7 @@ netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格
```
参数
-
+
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|model_path| str| 第一位置参数,模型文件的路径|
@@ -30,15 +30,15 @@ netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格
输出返回:
无。
-注意: 模型目录准备需要和netrans_cli一致,具体数据准备要求见[introduction](./introduction.md)。
+
-### Netrans.load_model 模型导入
+## Netrans.import 模型导入
- 描述: 将模型转换成 pnna 支持的格式。
+ 描述: 将模型转换成 Pnna 支持的格式。
代码示例:
```py3
- yolo_netrans.load_model()
+ yolo_netrans.import()
```
参数:
@@ -46,11 +46,11 @@ netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格
输出返回:
无。
- 在工程目录下生成 pnna 支持的模型格式,以.json结尾的模型文件和 .data结尾的权重文件。
+ 在工程目录下生成 Pnna 支持的模型格式,以.json结尾的模型文件和 .data结尾的权重文件。
-### Netrans.config 预处理配置文件生成
+## Netrans.config 预处理配置文件生成
- 描述: 将模型转换成 pnna 支持的格式。
+ 描述: 将模型转换成 Pnna 支持的格式。
代码示例:
```py3
@@ -58,55 +58,64 @@ netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格
```
参数:
+
+```{table}
+:widths: 20, 30, 50
+:align: left
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|inputmeta| bool,str, [Fasle, True, "inputmeta_filepath"] | 指定 inputmeta, 默认为False。
如果为False,则会生成inputmeta模板,可使用mean、scale、reverse_channel 配合修改常用参数。
如果已有现成的 inputmeta 文件,则可通过该参数进行指定,也可使用True, 则会自动索引 model_name_inputmeta.yml |
|mean| float, int, list | 设置预处理中 normalize 的 mean 参数 |
|scale| float, int, list | 设置预处理中 normalize 的 scale 参数 |
|reverse_channel | bool | 设置预处理中的 reverse_channel 参数 |
+```
输出返回:
无。
-### Netrans.quantize 量化模型
+## Netrans.quantize 模型量化
- 描述: 对模型生成量化配置文件。
- 代码示例:
+描述: 对模型生成量化配置文件。
+代码示例:
- ```py3
- yolo_netrans.quantize("uint8")
- ```
+```py3
+yolo_netrans.quantize("uint8")
+```
- 参数:
-
+参数:
+
+```{table}
+:widths: 20, 30, 50
+:align: left
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|quantize_type| str| 第一位置参数,模型量化类型,仅支持 "uint8", "int8", "int16"|
+```
- 输出返回:
- 无。
+输出返回:
+ 无。
-### Netrans.export 模型导出
+## Netrans.export 模型导出
- 描述: 对模型生成量化配置文件。
- 代码示例:
+描述: 对模型生成量化配置文件。
+代码示例:
- ```py3
- yolo_netrans.export()
- ```
+```py3
+yolo_netrans.export()
+```
- 参数:
- 无。
+参数:
+ 无。
- 输出返回:
- 无。请在目录 “wksp/*/” 下检查是否生成nbg文件。
+输出返回:
+ 无。请在目录 “wksp/*/” 下检查是否生成nbg文件。
-### Netrans.model2nbg 一键生成nbg文件
+## Netrans.model2nbg 模型生成nbg文件
- 描述: 模型导入、量化、及nbg文件生产
- 代码示例:
+描述: 模型导入、量化、及nbg文件生产
+代码示例:
- ```py3
+```py3
# 无预处理
yolo_netrans.model2nbg(quantize_type='uint8')
# 需要对数据进行normlize, menas为128, scale 为 0.0039
@@ -116,9 +125,13 @@ yolo_netrans.model2nbg(quantize_type='uint8'mean=[128, 127, 125], scale = 0.0039
# 已经进行初始化设置
yolo_netrans.model2nbg(quantize_type='uint8', inputmeta=True)
- ```
+```
- 参数
+参数
+
+```{table}
+:widths: 20, 30, 50
+:align: left
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|quantize_type| str, ["uint8", "int8", "int16" ] | 量化类型,将模型量化成该参数指定的类型 |
@@ -126,12 +139,12 @@ yolo_netrans.model2nbg(quantize_type='uint8', inputmeta=True)
|mean| float, int, list | 设置预处理中 normalize 的 mean 参数 |
|scale| float, int, list | 设置预处理中 normalize 的 scale 参数 |
|reverse_channel | bool | 设置预处理中的 reverse_channel 参数 |
-
+```
输出返回:
请在目录 “wksp/*/” 下检查是否生成nbg文件。
-## 使用实例
+## 使用示例
```py3
from nertans import Netrans
@@ -141,9 +154,7 @@ netrans_path = "netrans/bin" # 如果进行了export定义申明,这一步可
# 初始化netrans
net = Netrans(model_path,netrans=netrans_path)
# 模型载入
-net.load_model()
-# 生成 inputmeta 文件
-net.gen_inputmeta()
+net.import()
# 配置预处理 normlize 的参数
net.config(scale=1,mean=0)
# 模型量化
diff --git a/netrans_py/example.py b/netrans_py/example.py
deleted file mode 100644
index 478d46c..0000000
--- a/netrans_py/example.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import argparse
-from netrans import Netrans
-
-def main():
- # 创建参数解析器
- parser = argparse.ArgumentParser(
- description='神经网络模型转换工具',
- formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
- )
-
- # 必填位置参数
- parser.add_argument(
- 'model_path',
- type=str,
- help='输入模型路径(必须参数)'
- )
-
- # 可选参数组
- quant_group = parser.add_argument_group('量化参数')
- quant_group.add_argument(
- '-q', '--quantize_type',
- type=str,
- choices=['uint8', 'int8', 'int16', 'float'],
- default='uint8',
- metavar='TYPE',
- help='量化类型(可选值:%(choices)s)'
- )
- quant_group.add_argument(
- '-m', '--mean',
- type=int,
- default=0,
- help='归一化均值(默认:%(default)s)'
- )
- quant_group.add_argument(
- '-s', '--scale',
- type=float,
- default=1.0,
- help='量化缩放系数(默认:%(default)s)'
- )
-
- # 解析参数
- args = parser.parse_args()
-
- # 执行模型转换
- try:
- model = Netrans(model_path=args.model_path)
- model.model2nbg(
- quantize_type=args.quantize_type,
- mean=args.mean,
- scale=args.scale
- )
- print(f"模型 {args.model_path} 转换成功")
- except FileNotFoundError:
- print(f"错误:模型文件 {args.model_path} 不存在")
- exit(1)
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/netrans_py/example.py b/netrans_py/example.py
new file mode 120000
index 0000000..884c96b
--- /dev/null
+++ b/netrans_py/example.py
@@ -0,0 +1 @@
+../netrans_cli/example.py
\ No newline at end of file
diff --git a/netrans_py/export.py b/netrans_py/export.py
index eb30a9d..a9f98ac 100644
--- a/netrans_py/export.py
+++ b/netrans_py/export.py
@@ -2,7 +2,7 @@ import os
import sys
import subprocess
import shutil
-from utils import check_path, AttributeCopier, creat_cla
+from utils import check_path, AttributeCopier, create_cls
# 检查 NETRANS_PATH 环境变量是否设置
# 定义数据集文件路径
@@ -63,35 +63,33 @@ class Export(AttributeCopier):
if not os.path.exists(f"{name}_{quantization_type}.quantize"):
print(f"\033[31m Can not find {name}_{quantization_type}.quantize \033[0m")
sys.exit(1)
- if not os.path.exists(f"{name}_postprocess_file.yml"):
- cmd = f"{ovxgenerator} \
- --model {name}.json \
- --model-data {name}.data \
- --dtype {type_} \
- --pack-nbg-viplite \
- --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
- --viv-sdk {netrans_path}/pnna_sdk \
- --model-quantize {name}_{quantization_type}.quantize \
- --with-input-meta {name}_inputmeta.yml \
- --target-ide-project 'linux64' \
- --output-path {generate_path}/{quantization_type}"
- else:
- cmd = f"{ovxgenerator} \
- --model {name}.json \
- --model-data {name}.data \
- --dtype {type_} \
- --pack-nbg-viplite \
- --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
- --viv-sdk {netrans_path}/pnna_sdk \
- --model-quantize {name}_{quantization_type}.quantize \
- --with-input-meta {name}_inputmeta.yml \
- --target-ide-project 'linux64' \
- --postprocess-file {name}_postprocess_file.yml \
- --output-path {generate_path}/{quantization_type}"
-
-
- # 执行命令
- # print(cmd)
+ else :
+ if not os.path.exists(f"{name}_postprocess_file.yml"):
+ cmd = f"{ovxgenerator} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --dtype {type_} \
+ --pack-nbg-viplite \
+ --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
+ --viv-sdk {netrans_path}/pnna_sdk \
+ --model-quantize {name}_{quantization_type}.quantize \
+ --with-input-meta {name}_inputmeta.yml \
+ --target-ide-project 'linux64' \
+ --output-path {generate_path}/{quantization_type}"
+ else:
+ cmd = f"{ovxgenerator} \
+ --model {name}.json \
+ --model-data {name}.data \
+ --dtype {type_} \
+ --pack-nbg-viplite \
+ --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
+ --viv-sdk {netrans_path}/pnna_sdk \
+ --model-quantize {name}_{quantization_type}.quantize \
+ --with-input-meta {name}_inputmeta.yml \
+ --target-ide-project 'linux64' \
+ --postprocess-file {name}_postprocess_file.yml \
+ --output-path {generate_path}/{quantization_type}"
+
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
# 检查执行结果
@@ -104,25 +102,40 @@ class Export(AttributeCopier):
# temp='wksp/temp'
# os.makedirs(temp, exist_ok=True)
- src_ngb = f'{generate_path}_nbg_viplite/network_binary.nb'
- try :
- shutil.copy(src_ngb, generate_path)
- except FileNotFoundError:
- print(f"Error: {src_ngb} is not found")
- except Exception as e :
- print(f"a error occurred : {e}")
-
- try:
- shutil.rmtree(f"{generate_path}_nbg_viplite")
- except:
- sys.exit()
+ source_dir = f"{generate_path}_nbg_viplite"
+ target_dir = generate_path
+ src_ngb = f"{source_dir}/network_binary.nb"
+ if self.profile:
+ try:
+ # 如果目标路径已存在,先删除(确保移动操作能成功)
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ # 移动整个目录到目标位置
+ shutil.move(source_dir, target_dir)
+ # print(f"Successfully moved directory {source_dir} to {target_dir}")
+ except Exception as e:
+ sys.exit(1) # 非零退出码表示错误
+ # print(f"Error moving directory: {e}")
+ else:
+ try:
+ # 仅复制network_binary.nb文件
+ shutil.rmtree(generate_path)
+ os.mkdir(generate_path)
+ shutil.copy(src_ngb, generate_path)
- # try :
- # shutil.move(temp, generate_path )
- # except:
- # sys.exit()
- # 返回原始目录
- # os.chdir('..')
+ # print(f"Successfully copied {src_ngb} to {generate_path}")
+ except FileNotFoundError:
+ print(f"Error: {src_ngb} is not found")
+ except Exception as e:
+ print(f"Error occurred: {e}")
+
+ try:
+ # 清理源目录
+ shutil.rmtree(source_dir)
+ # print(f"Removed source directory {source_dir}")
+ except Exception as e:
+ # print(f"Error removing directory: {e}")
+ sys.exit(1) # 非零退出码表示错误
def main():
# 检查命令行参数数量
@@ -139,7 +152,7 @@ def main():
netrans_path = os.environ['NETRANS_PATH']
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
# 调用导出函数ss
- cla = creat_cla(netrans_path, network_name, sys.argv[2])
+ cla = create_cls(netrans_path, network_name, sys.argv[2])
func = Export(cla)
func.export_network()
diff --git a/netrans_py/netrans.py b/netrans_py/netrans.py
index 62fe560..bd3ea74 100644
--- a/netrans_py/netrans.py
+++ b/netrans_py/netrans.py
@@ -25,10 +25,10 @@ class Netrans():
pipe line
"""
def model2nbg(self, quantize_type, inputmeta=False, **kargs):
- self.load_model()
+ self.load()
self.config(inputmeta, **kargs)
- self.quantize(quantize_type)
- self.export()
+ self.quantize(quantize_type, **kargs)
+ self.export(**kargs)
"""
set netrans
@@ -104,7 +104,7 @@ class Netrans():
for db in data['input_meta']['databases']:
db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel
- def load_model(self):
+ def load(self):
func = ImportModel(self)
func.import_network()
@@ -112,7 +112,7 @@ class Netrans():
func = Config(self)
func.inputmeta_gen()
- def quantize(self, quantize_type):
+ def quantize(self, quantize_type,**kargs):
self.quantize_type = quantize_type
func = Quantize(self)
func.quantize_network()
@@ -120,6 +120,10 @@ class Netrans():
def export(self, **kargs):
if kargs.get('quantize_type') :
self.quantize_type = kargs['quantize_type']
+ if kargs.get('profile') :
+ self.profile = kargs['profile']
+ else :
+ self.profile = False
func = Export(self)
func.export_network()
@@ -136,7 +140,7 @@ if __name__ == '__main__':
yolo = Netrans(network)
yolo.inputmeta_gen()
# yolo.model2nb("uint8")
- # yolo.load_model()
+ # yolo.load()
# yolo.config(mean=[0,0,0],scale=1)
# yolo.quantize('uint8')
# yolo.export()