update netrans_cli
This commit is contained in:
parent
29b52e7aad
commit
df7d9473f9
|
@ -0,0 +1,60 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
from netrans import Netrans
|
||||
|
||||
def main():
|
||||
# 创建参数解析器
|
||||
parser = argparse.ArgumentParser(
|
||||
description='神经网络模型转换工具',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
|
||||
)
|
||||
|
||||
# 必填位置参数
|
||||
parser.add_argument(
|
||||
'model_path',
|
||||
type=str,
|
||||
help='输入模型路径(必须参数)'
|
||||
)
|
||||
|
||||
# 可选参数组
|
||||
quant_group = parser.add_argument_group('量化参数')
|
||||
quant_group.add_argument(
|
||||
'-q', '--quantize_type',
|
||||
type=str,
|
||||
choices=['uint8', 'int8', 'int16', 'float'],
|
||||
default='uint8',
|
||||
metavar='TYPE',
|
||||
help='量化类型(可选值:%(choices)s)'
|
||||
)
|
||||
quant_group.add_argument(
|
||||
'-m', '--mean',
|
||||
type=int,
|
||||
default=0,
|
||||
help='归一化均值(默认:%(default)s)'
|
||||
)
|
||||
quant_group.add_argument(
|
||||
'-s', '--scale',
|
||||
type=float,
|
||||
default=1.0,
|
||||
help='量化缩放系数(默认:%(default)s)'
|
||||
)
|
||||
|
||||
# 解析参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 执行模型转换
|
||||
try:
|
||||
model = Netrans(model_path=args.model_path)
|
||||
model.model2nbg(
|
||||
quantize_type=args.quantize_type,
|
||||
mean=args.mean,
|
||||
scale=args.scale
|
||||
)
|
||||
print(f"模型 {args.model_path} 转换成功")
|
||||
except FileNotFoundError:
|
||||
print(f"错误:模型文件 {args.model_path} 不存在")
|
||||
exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -82,7 +82,6 @@ function export_network()
|
|||
--viv-sdk ${NETRANS_PATH}/pnna_sdk \
|
||||
--output-path ${generate_path}/${NAME}_${quantization_type}"
|
||||
fi
|
||||
echo $cmd
|
||||
if [${VERIFY}='TRUE']; then
|
||||
echo $cmd
|
||||
fi
|
||||
|
|
Loading…
Reference in New Issue