317 lines
9.1 KiB
Python
317 lines
9.1 KiB
Python
import os
import shlex
import subprocess
import sys

from utils import check_path, AttributeCopier, create_cls
|
|
|
|
def check_status(result):
    """Report the outcome of a converter command executed via ``subprocess.run``.

    Prints a success banner when the command exited with status 0, otherwise
    prints the captured error output.

    Args:
        result (subprocess.CompletedProcess): return value of ``subprocess.run``
            (run with ``capture_output=True, text=True`` by the callers here).

    Returns:
        bool: ``True`` if the command exited with return code 0, else ``False``.
    """
    if result.returncode == 0:
        print("\033[31m LOAD MODEL SUCCESS \033[0m")
        return True
    # Some converters report errors on stdout only; fall back to it so the
    # printed error message is not empty when stderr is blank.
    detail = result.stderr or result.stdout
    print(f"\033[31m ERROR: {detail} \033[0m")
    return False
|
|
|
|
|
|
def import_caffe_network(name, netrans_path):
    """Build the converter command line that imports a Caffe model.

    Uses ``{name}.prototxt`` as the network definition. If ``{name}.caffemodel``
    exists it is passed as ``--weights``; otherwise the command is built without
    weights (the converter then produces a "fake" data file).

    Args:
        name (str): model name, used as the base name of every model file.
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import caffe``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.
    """
    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_prototxt_path = f"{name}.prototxt"
    model_caffemodel_path = f"{name}.caffemodel"

    print(f"=========== Converting {name} Caffe model ===========")

    # netrans_path is left unquoted on purpose: it may be a command prefix.
    # File names are quoted so names with spaces/shell metacharacters survive.
    args = [f"{netrans_path} import caffe", "--model", shlex.quote(model_prototxt_path)]
    if os.path.isfile(model_caffemodel_path):
        args += ["--weights", shlex.quote(model_caffemodel_path)]
    else:
        print("=========== fake Caffe model data file =============")
    args += [
        "--output-model", shlex.quote(model_json_path),
        "--output-data", shlex.quote(model_data_path),
    ]
    return " ".join(args)
|
|
|
|
def import_tensorflow_network(name, netrans_path):
    """Build the converter command line that imports a TensorFlow model.

    Uses ``{name}.pb`` as the model. Extra converter options are read from
    ``inputs_outputs.txt`` in the current working directory and appended
    verbatim. Judging by the file name, the options describe input/output
    nodes; TODO confirm the expected contents.

    Args:
        name (str): model name, used as the base name of every model file.
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import tensorflow``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.

    Raises:
        FileNotFoundError: if ``inputs_outputs.txt`` does not exist.
    """
    print(f"=========== Converting {name} Tensorflow model ===========")

    with open("inputs_outputs.txt", "r", encoding="utf-8") as f:
        # Put every line onto one command line: an embedded newline would make
        # the shell run the remaining options as a separate command.
        # The options are raw CLI text, so they are intentionally not quoted.
        inputs_outputs_params = " ".join(line.strip() for line in f if line.strip())

    args = [
        f"{netrans_path} import tensorflow",
        "--model", shlex.quote(f"{name}.pb"),
        "--output-data", shlex.quote(f"{name}.data"),
        "--output-model", shlex.quote(f"{name}.json"),
    ]
    if inputs_outputs_params:
        args.append(inputs_outputs_params)
    return " ".join(args)
|
|
|
|
def import_onnx_network(name, netrans_path):
    """Build the converter command line that imports an ONNX model.

    Uses ``{name}.onnx`` as the model. If ``{name}_outputs.txt`` exists, its
    first line is passed as a single ``--outputs`` argument. An empty first
    line adds no ``--outputs`` option at all.

    Args:
        name (str): model name, used as the base name of every model file.
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import onnx``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.
    """
    print(f"=========== Converting {name} ONNX model ===========")

    args = [
        f"{netrans_path} import onnx",
        "--model", shlex.quote(f"{name}.onnx"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]

    outputs_file = f"{name}_outputs.txt"
    if os.path.isfile(outputs_file):
        with open(outputs_file, "r", encoding="utf-8") as file:
            outputs = file.readline().strip()
        if outputs:
            # shlex.quote keeps the value a single shell word, even if it
            # contains spaces or quote characters.
            args += ["--outputs", shlex.quote(outputs)]
    return " ".join(args)
|
|
|
|
####### TFLITE
|
|
def import_tflite_network(name, netrans_path):
    """Build the converter command line that imports a TFLite model.

    Args:
        name (str): model name, used as the base name of every model file
            (``{name}.tflite`` in, ``{name}.json`` / ``{name}.data`` out).
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import tflite``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.
    """
    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_tflite_path = f"{name}.tflite"

    print(f"=========== Converting {name} TFLite model ===========")

    # File names are quoted so names with spaces/shell metacharacters survive.
    return " ".join([
        f"{netrans_path} import tflite",
        "--model", shlex.quote(model_tflite_path),
        "--output-model", shlex.quote(model_json_path),
        "--output-data", shlex.quote(model_data_path),
    ])
|
|
|
|
|
|
def import_darknet_network(name, netrans_path):
    """Build the converter command line that imports a Darknet model.

    Uses ``{name}.cfg`` as the network definition and ``{name}.weights`` as
    the weights.

    Args:
        name (str): model name, used as the base name of every model file.
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import darknet``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.
    """
    print(f"=========== Converting {name} darknet model ===========")

    # NOTE(review): the weights flag is spelled "--weight" as in the original
    # code; confirm it matches the converter's CLI.
    return " ".join([
        f"{netrans_path} import darknet",
        "--model", shlex.quote(f"{name}.cfg"),
        "--weight", shlex.quote(f"{name}.weights"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ])
|
|
|
|
def import_pytorch_network(name, netrans_path):
    """Build the converter command line that imports a PyTorch model.

    Uses ``{name}.pt`` as the model. Extra converter options are read from
    ``input_size.txt`` in the current working directory and appended verbatim.
    Judging by the file name, the options give input sizes; TODO confirm.

    Args:
        name (str): model name, used as the base name of every model file.
        netrans_path (str): converter executable (or command prefix), invoked
            as ``<netrans_path> import pytorch``.

    Returns:
        str: shell command line, meant to be executed with
            ``subprocess.run(..., shell=True)``.

    Note:
        Exits the process with status 1 if ``input_size.txt`` is missing.
    """
    print(f"=========== Converting {name} pytorch model ===========")

    try:
        with open("input_size.txt", "r") as file:
            # Strip each line: the newlines kept by readlines() would be read
            # by the shell as command separators, splitting the command.
            # The options are raw CLI text, so they are intentionally not quoted.
            input_size_params = " ".join(line.strip() for line in file if line.strip())
    except FileNotFoundError:
        print("Error: input_size.txt not found.")
        sys.exit(1)

    args = [
        f"{netrans_path} import pytorch",
        "--model", shlex.quote(f"{name}.pt"),
        "--output-model", shlex.quote(f"{name}.json"),
        "--output-data", shlex.quote(f"{name}.data"),
    ]
    if input_size_params:
        args.append(input_size_params)
    return " ".join(args)
|
|
|
|
# 使用示例
|
|
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
|
|
class ImportModel(AttributeCopier):
    """Import a model with the pnnacc converter using settings from a Netrans instance.

    Args:
        source_obj: an instantiated Netrans object carrying model and Netrans
            information. Its attributes (``model_name``, ``netrans``,
            ``verbose``, ...) are presumably copied onto this instance by
            ``AttributeCopier``; confirm in utils.
    """

    def __init__(self, source_obj) -> None:
        """Copy model/Netrans settings from ``source_obj``.

        Args:
            source_obj: instantiated Netrans object with model and Netrans info.
        """
        super().__init__(source_obj)

    # check_path comes from utils; presumably it validates/enters the model
    # directory before the call. Confirm there.
    @check_path
    def import_network(self):
        """Detect the model format in the working directory and import it via pnnacc.

        The first existing file among ``{model_name}`` + ``.prototxt``,
        ``.pb``, ``.onnx``, ``.tflite``, ``.weights``, ``.pt`` (in that order)
        selects the importer.

        Raises:
            FileExistsError: if no supported model file is found. The type is
                kept for backward compatibility even though the condition is
                "not found".
            RuntimeError: if the converter could not be launched or exited
                with a non-zero status.
        """
        if self.verbose is True:
            print("begin load model")
            print(os.getcwd())
            print(f"{self.model_name}.weights")

        name = self.model_name
        netrans_path = self.netrans

        # Fixed priority order: the first matching model file wins.
        importers = (
            (".prototxt", import_caffe_network),
            (".pb", import_tensorflow_network),
            (".onnx", import_onnx_network),
            (".tflite", import_tflite_network),
            (".weights", import_darknet_network),
            (".pt", import_pytorch_network),
        )
        for suffix, build_cmd in importers:
            if os.path.isfile(f"{name}{suffix}"):
                cmd = build_cmd(name, netrans_path)
                break
        else:
            raise FileExistsError("Can not find suitable model files")

        try:
            result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        except (OSError, subprocess.SubprocessError) as err:
            # Narrow catch: a bare except would also turn KeyboardInterrupt
            # into RuntimeError and hide the original cause.
            raise RuntimeError("load model failed") from err

        check_status(result)
        # Honour the documented contract: a failed import raises instead of
        # only printing the error.
        if result.returncode != 0:
            raise RuntimeError("load model failed")
|
|
|
|
|
|
# def main():
|
|
# if len(sys.argv) != 2 :
|
|
# print("Input a network")
|
|
# sys.exit(-1)
|
|
|
|
# network_name = sys.argv[1]
|
|
# # check_env(network_name)
|
|
|
|
# netrans_path = os.environ['NETRANS_PATH']
|
|
# # netrans = os.path.join(netrans_path, 'pnnacc')
|
|
# clas = create_cls(netrans_path, network_name,verbose=False)
|
|
# func = ImportModel(clas)
|
|
# func.import_network()
|
|
# if __name__ == "__main__":
|
|
# main()
|