forked from nudt_dsp/netrans
60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
|
|
import os
|
|
import sys
|
|
from utils import check_path, AttributeCopier, create_cls
|
|
import subprocess
|
|
|
|
class Config(AttributeCopier):
|
|
"""从实例化的 Netrans 中解析模型参数,并基于pnnacc 生成配置文件模板
|
|
|
|
Args:
|
|
Netrans (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
|
"""
|
|
def __init__(self, source_obj) -> None:
|
|
"""从实例化的 Netrans 中解析模型参数
|
|
|
|
Args:
|
|
source_obj (class): 实例化的Netrans类,包含 模型信息 和 Netrans 信息
|
|
|
|
"""
|
|
super().__init__(source_obj)
|
|
|
|
@check_path
|
|
def inputmeta_gen(self):
|
|
"""生成配置文件模板
|
|
|
|
Return:
|
|
None
|
|
"""
|
|
netrans_path = self.netrans
|
|
network_name = self.model_name
|
|
# 进入网络名称指定的目录
|
|
# os.chdir(network_name)
|
|
# check_env(network_name)
|
|
|
|
# 执行 pegasus 命令
|
|
cmd = f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database"
|
|
try :
|
|
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
|
except :
|
|
raise RuntimeError('config failed')
|
|
# os.chdir("..")
|
|
|
|
# def main():
|
|
|
|
# # 检查命令行参数数量是否正确
|
|
# if len(sys.argv) != 2:
|
|
# print("Enter a network name!")
|
|
# sys.exit(2)
|
|
|
|
# # 检查提供的目录是否存在
|
|
# network_name = sys.argv[1]
|
|
# # 构建 netrans 可执行文件的路径
|
|
# netrans_path =os.getenv('NETRANS_PATH')
|
|
# cla = create_cls(netrans_path, network_name)
|
|
# func = InputmetaGen(cla)
|
|
# func.inputmeta_gen()
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# main() |