netrans/netrans_py/config.py

60 lines
1.7 KiB
Python

import os
import sys
from utils import check_path, AttributeCopier, create_cls
import subprocess
class Config(AttributeCopier):
"""从实例化的 Netrans 中解析模型参数,并基于pnnacc 生成配置文件模板
Args:
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
"""
def __init__(self, source_obj) -> None:
"""从实例化的 Netrans 中解析模型参数
Args:
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
"""
super().__init__(source_obj)
@check_path
def inputmeta_gen(self):
"""生成配置文件模板
Return:
None
"""
netrans_path = self.netrans
network_name = self.model_name
# 进入网络名称指定的目录
# os.chdir(network_name)
# check_env(network_name)
# 执行 pegasus 命令
cmd = f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database"
try :
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
except :
raise RuntimeError('config failed')
# os.chdir("..")
# def main():
# # 检查命令行参数数量是否正确
# if len(sys.argv) != 2:
# print("Enter a network name!")
# sys.exit(2)
# # 检查提供的目录是否存在
# network_name = sys.argv[1]
# # 构建 netrans 可执行文件的路径
# netrans_path =os.getenv('NETRANS_PATH')
# cla = create_cls(netrans_path, network_name)
# func = InputmetaGen(cla)
# func.inputmeta_gen()
# if __name__ == '__main__':
# main()