netrans/netrans_py/import_model.py

317 lines
9.1 KiB
Python

import os
import shlex
import subprocess
import sys

from utils import check_path, AttributeCopier, create_cls
def check_status(result):
    """Print whether a finished pnnacc command succeeded.

    Args:
        result (subprocess.CompletedProcess): value returned by
            ``subprocess.run``; only ``returncode`` and ``stderr`` are read.
    """
    if result.returncode != 0:
        # Non-zero exit: surface whatever the tool wrote to stderr.
        print(f"\033[31m ERROR: {result.stderr} \033[0m")
        return
    print("\033[31m LOAD MODEL SUCCESS \033[0m")
def import_caffe_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports a Caffe model.

    The network definition is ``{name}.prototxt``. ``--weights
    {name}.caffemodel`` is added only when that file exists in the current
    working directory; otherwise the weights option is omitted and a
    "fake Caffe model data file" banner is printed.

    Args:
        name (str): model name, used as the stem of every input and output file.
        netrans_path (str): pnnacc command placed at the start of the command
            line. It is inserted unquoted, so it may still be a shell
            expression (e.g. contain ``$VARS``).

    Returns:
        str: the command line, meant to be executed through a shell
        (``subprocess.run(cmd, shell=True)``).
    """
    print(f"=========== Converting {name} Caffe model ===========")
    # File names derived from ``name`` are quoted: os.path.isfile() below
    # treats ``name`` literally, so the shell must see it literally too
    # (spaces, quotes, ``$`` ...).
    args = [
        f"{netrans_path} import caffe",
        "--model", shlex.quote(f"{name}.prototxt"),
    ]
    if os.path.isfile(f"{name}.caffemodel"):
        args += ["--weights", shlex.quote(f"{name}.caffemodel")]
    else:
        print("=========== fake Caffe model data file =============")
    args += [
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]
    return " ".join(args)
def import_tensorflow_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports a TensorFlow ``.pb`` model.

    The graph is ``{name}.pb``. Extra command-line arguments are read from
    ``inputs_outputs.txt`` in the current working directory and appended
    verbatim (presumably the ``--inputs``/``--outputs`` specification, given
    the file name).

    Args:
        name (str): model name, used as the stem of every input and output file.
        netrans_path (str): pnnacc command placed at the start of the command
            line; inserted unquoted so it may still be a shell expression.

    Returns:
        str: the command line, meant to be executed through a shell.

    Raises:
        FileNotFoundError: ``inputs_outputs.txt`` does not exist.
    """
    print(f"=========== Converting {name} Tensorflow model ===========")
    # The file holds shell syntax and is spliced in unquoted on purpose.
    # Lines are joined with spaces because a raw newline inside a shell=True
    # command starts a second shell command, silently dropping the remaining
    # arguments. Trailing "\" continuations are removed for the same reason.
    with open("inputs_outputs.txt", "r", encoding="utf-8") as f:
        inputs_outputs_params = " ".join(
            line.strip().rstrip("\\").strip() for line in f if line.strip()
        )
    args = [
        f"{netrans_path} import tensorflow",
        "--model", shlex.quote(f"{name}.pb"),
        "--output-data", shlex.quote(f"{name}.data"),
        "--output-model", shlex.quote(f"{name}.json"),
        inputs_outputs_params,
    ]
    return " ".join(args)
def import_onnx_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports an ONNX model.

    The model is ``{name}.onnx``. If ``{name}_outputs.txt`` exists in the
    current working directory, its first line (stripped) is passed as a single
    ``--outputs`` argument.

    Args:
        name (str): model name, used as the stem of every input and output file.
        netrans_path (str): pnnacc command placed at the start of the command
            line; inserted unquoted so it may still be a shell expression.

    Returns:
        str: the command line, meant to be executed through a shell.
    """
    print(f"=========== Converting {name} ONNX model ===========")
    try:
        with open(f"{name}_outputs.txt", "r", encoding="utf-8") as file:
            outputs = file.readline().strip()
    except FileNotFoundError:
        outputs = None
    args = [
        f"{netrans_path} import onnx",
        "--model", shlex.quote(f"{name}.onnx"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]
    if outputs is not None:
        # shlex.quote keeps the whole line as one argument, even when it
        # contains spaces or an apostrophe. The previous hand-written
        # '...' quoting broke on an apostrophe.
        args += ["--outputs", shlex.quote(outputs)]
    return " ".join(args)
####### TFLITE
def import_tflite_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports a TFLite model.

    Args:
        name (str): model name; reads ``{name}.tflite`` and writes
            ``{name}.json`` / ``{name}.data``.
        netrans_path (str): pnnacc command placed at the start of the command
            line; inserted unquoted so it may still be a shell expression.

    Returns:
        str: the command line, meant to be executed through a shell.
    """
    print(f"=========== Converting {name} TFLite model ===========")
    # Quote name-derived paths so spaces or shell metacharacters in the model
    # name cannot split or alter the shell command.
    args = [
        f"{netrans_path} import tflite",
        "--model", shlex.quote(f"{name}.tflite"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]
    return " ".join(args)
def import_darknet_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports a Darknet model.

    Args:
        name (str): model name; reads ``{name}.cfg`` and ``{name}.weights``
            and writes ``{name}.json`` / ``{name}.data``.
        netrans_path (str): pnnacc command placed at the start of the command
            line; inserted unquoted so it may still be a shell expression.

    Returns:
        str: the command line, meant to be executed through a shell. The
        caller runs it; this function no longer carries the unreachable
        subprocess call that used to follow its ``return``.
    """
    print(f"=========== Converting {name} darknet model ===========")
    # NOTE: the flag is "--weight" (singular), unlike the Caffe importer's
    # "--weights"; kept as-is. Confirm against the pnnacc darknet importer.
    args = [
        f"{netrans_path} import darknet",
        "--model", shlex.quote(f"{name}.cfg"),
        "--weight", shlex.quote(f"{name}.weights"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]
    return " ".join(args)
def import_pytorch_network(name: str, netrans_path: str) -> str:
    """Build the pnnacc command line that imports a PyTorch ``.pt`` model.

    Extra command-line arguments are read from ``input_size.txt`` in the
    current working directory and appended verbatim (presumably the input
    size specification, given the file name).

    Args:
        name (str): model name; reads ``{name}.pt`` and writes
            ``{name}.json`` / ``{name}.data``.
        netrans_path (str): pnnacc command placed at the start of the command
            line; inserted unquoted so it may still be a shell expression.

    Returns:
        str: the command line, meant to be executed through a shell.

    Raises:
        SystemExit: with status 1 when ``input_size.txt`` does not exist
            (after printing an error message).
    """
    print(f"=========== Converting {name} pytorch model ===========")
    try:
        with open("input_size.txt", "r", encoding="utf-8") as file:
            # Join lines with spaces. The old ' '.join(readlines()) kept each
            # line's "\n", and under shell=True every line after the first
            # ran as a separate shell command. Trailing "\" continuations
            # are removed for the same reason.
            input_size_params = " ".join(
                line.strip().rstrip("\\").strip() for line in file if line.strip()
            )
    except FileNotFoundError:
        print("Error: input_size.txt not found.")
        sys.exit(1)
    args = [
        f"{netrans_path} import pytorch",
        "--model", shlex.quote(f"{name}.pt"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
        input_size_params,
    ]
    return " ".join(args)
# 使用示例
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
class ImportModel(AttributeCopier):
    """Import a model with pnnacc, using settings taken from a Netrans object.

    Args:
        source_obj (class): an instantiated Netrans object holding the model
            information and the Netrans tool information.
    """

    def __init__(self, source_obj) -> None:
        """Take over the model parameters of an instantiated Netrans object.

        Args:
            source_obj (class): instantiated Netrans object. Its attributes are
                presumably copied onto ``self`` by ``AttributeCopier`` (defined
                in utils, not shown here). ``import_network`` reads
                ``verbose``, ``model_name`` and ``netrans``.
        """
        super().__init__(source_obj)

    @check_path
    def import_network(self):
        """Detect the model format and import it with pnnacc.

        The format is chosen from the first existing ``{model_name}<ext>``
        file in the current working directory, probed in this priority order:
        Caffe ``.prototxt``, TensorFlow ``.pb``, ONNX ``.onnx``, TFLite
        ``.tflite``, Darknet ``.weights``, PyTorch ``.pt``.

        Raises:
            FileExistsError: no supported model file exists. (Semantically a
                "not found" condition; the type is kept for existing callers.)
            RuntimeError: the pnnacc command could not be launched, or it
                exited with a non-zero status.
        """
        if self.verbose:
            print("begin load model")
            print(os.getcwd())
            print(f"{self.model_name}.weights")
        name = self.model_name
        netrans_path = self.netrans
        # Ordered (extension, command builder) pairs; the first existing file wins.
        importers = (
            (".prototxt", import_caffe_network),
            (".pb", import_tensorflow_network),
            (".onnx", import_onnx_network),
            (".tflite", import_tflite_network),
            (".weights", import_darknet_network),
            (".pt", import_pytorch_network),
        )
        for extension, build_cmd in importers:
            if os.path.isfile(f"{name}{extension}"):
                cmd = build_cmd(name, netrans_path)
                break
        else:
            raise FileExistsError("Can not find suitable model files")
        try:
            result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        except (OSError, ValueError, subprocess.SubprocessError) as err:
            # Narrowed from a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit; keep the original cause attached.
            raise RuntimeError("load model failed") from err
        check_status(result)
        # subprocess.run does not raise on a non-zero exit status, so honour
        # the documented RuntimeError contract explicitly.
        if result.returncode != 0:
            raise RuntimeError(f"load model failed (exit code {result.returncode})")
# def main():
# if len(sys.argv) != 2 :
# print("Input a network")
# sys.exit(-1)
# network_name = sys.argv[1]
# # check_env(network_name)
# netrans_path = os.environ['NETRANS_PATH']
# # netrans = os.path.join(netrans_path, 'pnnacc')
# clas = create_cls(netrans_path, network_name,verbose=False)
# func = ImportModel(clas)
# func.import_network()
# if __name__ == "__main__":
# main()