+import os
+import sys
+import subprocess
+from utils import check_path, AttributeCopier, create_cls
+
+
+
[文档]
+
def check_status(result):
+
"""解析命令执行情况
+
+
Args:
+
result (return of subprocrss.run): subprocess.run的返回值
+
"""
+
if result.returncode == 0:
+
print("\033[31m LOAD MODEL SUCCESS \033[0m")
+
else:
+
print(f"\033[31m ERROR: {result.stderr} \033[0m")
+
+
+
+
+
[文档]
+
def import_caffe_network(name, netrans_path):
+
"""导入 caffe 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的路径
+
convert_caffe =netrans_path + " import caffe"
+
+
# 定义模型文件路径
+
model_json_path = f"{name}.json"
+
model_data_path = f"{name}.data"
+
model_prototxt_path = f"{name}.prototxt"
+
model_caffemodel_path = f"{name}.caffemodel"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} Caffe model ===========")
+
+
# 构建转换命令
+
if os.path.isfile(model_caffemodel_path):
+
cmd = f"{convert_caffe} \
+
--model {model_prototxt_path} \
+
--weights {model_caffemodel_path} \
+
--output-model {model_json_path} \
+
--output-data {model_data_path}"
+
else:
+
print("=========== fake Caffe model data file =============")
+
cmd = f"{convert_caffe} \
+
--model {model_prototxt_path} \
+
--output-model {model_json_path} \
+
--output-data {model_data_path}"
+
+
# 执行转换命令
+
# print(cmd)
+
# os.system(cmd)
+
return cmd
+
+
+
+
[文档]
+
def import_tensorflow_network(name, netrans_path):
+
"""导入 tensorflow 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的命令
+
convertf_cmd = f"{netrans_path} import tensorflow"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} Tensorflow model ===========")
+
+
# 读取 inputs_outputs.txt 文件中的参数
+
with open('inputs_outputs.txt', 'r') as f:
+
inputs_outputs_params = f.read().strip()
+
+
# 构建转换命令
+
cmd = f"{convertf_cmd} \
+
--model {name}.pb \
+
--output-data {name}.data \
+
--output-model {name}.json \
+
{inputs_outputs_params}"
+
+
# 执行转换命令
+
# print(cmd)
+
return cmd
+
+
+ # result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ # check_status(result)
+
+
+
[文档]
+
def import_onnx_network(name, netrans_path):
+
"""导入 onnx 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的命令
+
convert_onnx_cmd = f"{netrans_path} import onnx"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} ONNX model ===========")
+
if os.path.exists(f"{name}_outputs.txt"):
+
output_path = os.path.join(os.getcwd(), name+"_outputs.txt")
+
with open(output_path, 'r', encoding='utf-8') as file:
+
outputs = str(file.readline().strip())
+
+
cmd = f"{convert_onnx_cmd} \
+
--model {name}.onnx \
+
--output-model {name}.json \
+
--output-data {name}.data \
+
--outputs '{outputs}'"
+
else:
+
# 构建转换命令
+
cmd = f"{convert_onnx_cmd} \
+
--model {name}.onnx \
+
--output-model {name}.json \
+
--output-data {name}.data"
+
+
# 执行转换命令
+
# print(cmd)
+
return cmd
+
+
+ # result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ # check_status(result)
+
+####### TFLITE
+
+
[文档]
+
def import_tflite_network(name, netrans_path):
+
"""导入 tflite 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的路径或命令
+
convert_tflite = f"{netrans_path} import tflite"
+
+
# 定义模型文件路径
+
model_json_path = f"{name}.json"
+
model_data_path = f"{name}.data"
+
model_tflite_path = f"{name}.tflite"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} TFLite model ===========")
+
+
# 构建转换命令
+
cmd = f"{convert_tflite} \
+
--model {model_tflite_path} \
+
--output-model {model_json_path} \
+
--output-data {model_data_path}"
+
+
# 执行转换命令
+
# print(cmd)
+
return cmd
+
+
+ # result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ # 检查执行结果
+ # check_status(result)
+
+
+
+
[文档]
+
def import_darknet_network(name, netrans_path):
+
"""导入 darknet 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的命令
+
convert_darknet_cmd = f"{netrans_path} import darknet"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} darknet model ===========")
+
+
# 构建转换命令
+
cmd = f"{convert_darknet_cmd} \
+
--model {name}.cfg \
+
--weight {name}.weights \
+
--output-model {name}.json \
+
--output-data {name}.data"
+
+
# 执行转换命令
+
# print(cmd)
+
return cmd
+
+
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+
# 检查执行结果
+
check_status(result)
+
+
+
+
[文档]
+
def import_pytorch_network(name, netrans_path):
+
"""导入 pytorch 模型
+
+
Args:
+
name (str): 模型名字
+
netrans_path (str): 模型路径
+
+
Returns:
+
cmd (str): 生成的pnnacc 命令行, 被subprocesses执行
+
"""
+
# 定义转换工具的命令
+
convert_pytorch_cmd = f"{netrans_path} import pytorch"
+
+
# 打印转换信息
+
print(f"=========== Converting {name} pytorch model ===========")
+
+
# 读取 input_size.txt 文件中的参数
+
try:
+
with open('input_size.txt', 'r') as file:
+
input_size_params = ' '.join(file.readlines())
+
except FileNotFoundError:
+
print("Error: input_size.txt not found.")
+
sys.exit(1)
+
+
# 构建转换命令
+
cmd = f"{convert_pytorch_cmd} \
+
--model {name}.pt \
+
--output-model {name}.json \
+
--output-data {name}.data \
+
{input_size_params}"
+
+
# 执行转换命令
+
# print(cmd)
+
return cmd
+
+
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+
# 检查执行结果
+
check_status(result)
+
+
+# 使用示例
+# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
+
+
[文档]
+
class ImportModel(AttributeCopier):
+
"""从实例化的 Netrans 中解析模型参数,并基于 pnnacc 导入模型
+
+
Args:
+
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
+
"""
+
def __init__(self, source_obj) -> None:
+
"""从实例化的 Netrans 中解析模型参数
+
+
Args:
+
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
+
+
"""
+
super().__init__(source_obj)
+
# print(source_obj.__dict__)
+
+
@check_path
+
def import_network(self):
+
"""基于 pnnacc 导入模型
+
+
Raises:
+
FileExistsError: 如果不存在模型文件则会报错 FileExistsError
+
RuntimeError: 如果执行导入失败则会报 RuntimeError
+
"""
+
if self.verbose is True :
+
print("begin load model")
+
# print(self.model_path)
+
print(os.getcwd())
+
print(f"{self.model_name}.weights")
+
name = self.model_name
+
netrans_path = self.netrans
+
if os.path.isfile(f"{name}.prototxt"):
+
cmd = import_caffe_network(name, netrans_path)
+
elif os.path.isfile(f"{name}.pb"):
+
cmd = import_tensorflow_network(name, netrans_path)
+
elif os.path.isfile(f"{name}.onnx"):
+
cmd = import_onnx_network(name, netrans_path)
+
elif os.path.isfile(f"{name}.tflite"):
+
cmd = import_tflite_network(name, netrans_path)
+
elif os.path.isfile(f"{name}.weights"):
+
cmd = import_darknet_network(name, netrans_path)
+
elif os.path.isfile(f"{name}.pt"):
+
cmd = import_pytorch_network(name, netrans_path)
+
else :
+
raise FileExistsError("Can not find suitable model files")
+
try :
+
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
except :
+
raise RuntimeError("load model failed")
+
# 检查执行结果
+
check_status(result)
+
+ # os.chdir("..")
+
+
+# def main():
+# if len(sys.argv) != 2 :
+# print("Input a network")
+# sys.exit(-1)
+
+# network_name = sys.argv[1]
+# # check_env(network_name)
+
+# netrans_path = os.environ['NETRANS_PATH']
+# # netrans = os.path.join(netrans_path, 'pnnacc')
+# clas = create_cls(netrans_path, network_name,verbose=False)
+# func = ImportModel(clas)
+# func.import_network()
+# if __name__ == "__main__":
+# main()
+