forked from nudt_dsp/netrans
110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
import sys
|
|
import os
|
|
# from functools import wraps
|
|
|
|
# def check_path(netrans, model_path):
|
|
# def decorator(func):
|
|
# @wraps(func)
|
|
# def wrapper(netrans, model_path, *args, **kargs):
|
|
# check_dir(model_path)
|
|
# check_netrans(netrans)
|
|
# if os.getcwd() != model_path :
|
|
# os.chdir(model_path)
|
|
# return func(netrans, model_path, *args, **kargs)
|
|
# return wrapper
|
|
# return decorator
|
|
|
|
def check_path(func):
    """Decorator that makes sure a method runs inside the project directory.

    Before delegating to ``func`` the wrapper validates the Netrans
    configuration with ``check_netrans(cla.netrans)`` and, if the current
    working directory is not ``cla.model_path``, changes into it. The
    working directory is not switched back after ``func`` returns.

    Args:
        func: Callable whose first positional argument is an object exposing
            ``netrans`` and ``model_path`` attributes.

    Returns:
        The wrapping callable. It forwards all arguments to ``func`` and
        returns whatever ``func`` returns.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    # wraps() preserves the decorated callable's __name__, __doc__ and
    # __wrapped__, so introspection and tracebacks show the real method.
    @wraps(func)
    def wrapper(cla, *args, **kargs):
        check_netrans(cla.netrans)
        if os.getcwd() != cla.model_path:
            os.chdir(cla.model_path)
        return func(cla, *args, **kargs)

    return wrapper
|
|
|
|
|
|
def check_dir(network_name):
    """Verify that the project directory exists and change into it.

    Args:
        network_name (str): Path of the project directory.

    Raises:
        NotADirectoryError: If ``network_name`` does not exist or is not a
            directory.
    """
    # isdir rather than exists: a regular file must be rejected here with
    # the explicit message below instead of failing later inside os.chdir.
    if not os.path.isdir(network_name):
        raise NotADirectoryError(f"Directory not found: {network_name}")
    os.chdir(network_name)
|
|
|
|
def check_netrans(netrans):
    """Check that a Netrans installation can be located.

    The check passes when ``netrans`` is a path that exists on disk, or
    when the ``NETRANS_PATH`` environment variable is set.

    Args:
        netrans (str | bool | None): Path to Netrans. ``None`` or ``False``
            means "not configured"; in that case only the environment
            variable is consulted.

    Raises:
        NotADirectoryError: If neither an existing path nor the
            ``NETRANS_PATH`` environment variable is available.
    """
    # Only probe real path values. os.path.exists() also accepts integers
    # as file descriptors, and bool is an int, so passing False would test
    # descriptor 0 and wrongly report Netrans as found.
    if isinstance(netrans, (str, bytes, os.PathLike)) and os.path.exists(netrans):
        return
    if 'NETRANS_PATH' in os.environ:
        return
    raise NotADirectoryError(f"Netrans not found: {netrans}")
|
|
|
|
|
|
def remove_history_file(name):
    """Delete the ``<name>.json`` and ``<name>.data`` files of a project.

    The function changes into the directory ``name``, removes each of the
    two files if it is a regular file, and then returns to the directory
    the caller started in.

    Args:
        name (str): Project directory. The same string is used verbatim as
            the stem of the two file names that are looked up after the
            directory change.

    Raises:
        OSError: Propagated from ``os.chdir`` (e.g. the directory does not
            exist) or from ``os.remove``.
    """
    previous_dir = os.getcwd()
    os.chdir(name)
    try:
        for suffix in (".json", ".data"):
            history_file = f"{name}{suffix}"
            if os.path.isfile(history_file):
                os.remove(history_file)
    finally:
        # Restore the caller's directory rather than going to "..": the two
        # differ when name is nested or absolute, and the finally clause
        # also covers the case where a removal fails.
        os.chdir(previous_dir)
|
|
|
|
def check_env(name):
    """Validate the working environment for the project ``name``.

    At present this only delegates to ``check_dir(name)``; see that
    function for the checks performed and the errors raised.

    Args:
        name (str): Project directory path, passed through to ``check_dir``.
    """
    # NOTE: the Netrans check and the history-file cleanup are disabled here.
    check_dir(name)
|
|
|
|
|
|
class AttributeCopier:
    """Mirror the instance attributes of another object onto this one.

    Intended as a quick way to pick up the settings held by a Netrans
    object.
    """

    def __init__(self, source_obj) -> None:
        self.copy_attribute_name(source_obj)

    def copy_attribute_name(self, source_obj):
        """Copy every attribute reported for ``source_obj`` onto ``self``.

        Args:
            source_obj: Object whose attributes are read with ``getattr``
                and assigned to this instance with ``setattr``.
        """
        names = self._get_attribute_names(source_obj)
        for name in names:
            value = getattr(source_obj, name)
            setattr(self, name, value)

    @staticmethod
    def _get_attribute_names(source_obj):
        """Return the keys of ``source_obj``'s instance ``__dict__``."""
        instance_dict = source_obj.__dict__
        return instance_dict.keys()
|
|
|
|
class create_cls:  # dataclass @netrans_params
    """Lightweight stand-in for a Netrans instance, used in quick tests."""

    def __init__(self, netrans_path, name, quantized_type='uint8', verbose=False) -> None:
        """Populate the attributes of the stand-in object.

        Args:
            netrans_path: Directory joined with ``'pnnacc'`` to form
                ``self.netrans``.
            name: Stored unchanged as ``model_name``; its absolute form is
                stored as ``model_path``.
            quantized_type: Stored as ``quantize_type``.
            verbose: Stored as ``verbose``.
        """
        self.netrans_path = netrans_path
        self.netrans = os.path.join(netrans_path, 'pnnacc')
        self.model_name = name
        self.model_path = os.path.abspath(name)
        self.verbose = verbose
        self.quantize_type = quantized_type
        # Always False in this stand-in.
        self.profile = False
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# dir_name = "yolo"
|
|
# os.mkdir(dir_name)
|
|
# check_dir(dir_name)
|
|
|
|
|