update netrans_cli

This commit is contained in:
xujiao 2025-04-09 10:45:36 +08:00
parent 29b52e7aad
commit df7d9473f9
2 changed files with 60 additions and 1 deletions

60
netrans_cli/example.py Executable file
View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
import argparse
from netrans import Netrans
def main():
# 创建参数解析器
parser = argparse.ArgumentParser(
description='神经网络模型转换工具',
formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
)
# 必填位置参数
parser.add_argument(
'model_path',
type=str,
help='输入模型路径(必须参数)'
)
# 可选参数组
quant_group = parser.add_argument_group('量化参数')
quant_group.add_argument(
'-q', '--quantize_type',
type=str,
choices=['uint8', 'int8', 'int16', 'float'],
default='uint8',
metavar='TYPE',
help='量化类型(可选值:%(choices)s'
)
quant_group.add_argument(
'-m', '--mean',
type=int,
default=0,
help='归一化均值(默认:%(default)s'
)
quant_group.add_argument(
'-s', '--scale',
type=float,
default=1.0,
help='量化缩放系数(默认:%(default)s'
)
# 解析参数
args = parser.parse_args()
# 执行模型转换
try:
model = Netrans(model_path=args.model_path)
model.model2nbg(
quantize_type=args.quantize_type,
mean=args.mean,
scale=args.scale
)
print(f"模型 {args.model_path} 转换成功")
except FileNotFoundError:
print(f"错误:模型文件 {args.model_path} 不存在")
exit(1)
if __name__ == "__main__":
main()

View File

@ -82,7 +82,6 @@ function export_network()
--viv-sdk ${NETRANS_PATH}/pnna_sdk \
--output-path ${generate_path}/${NAME}_${quantization_type}"
fi
echo $cmd
if [${VERIFY}='TRUE']; then
echo $cmd
fi