# netrans/netrans_py/quantize.py  (106 lines, 3.6 KiB, Python)

import os
import sys
from utils import check_path, AttributeCopier, create_cls
class Quantize(AttributeCopier):
    """Quantize a model with pnnacc, using parameters taken from a Netrans object.

    Args:
        source_obj (class): An instantiated Netrans object; the parameters it
            carries are parsed by this class.
    """

    # Maps the user-facing quantize type to the pnnacc quantization scheme.
    # The part before '-' is passed as --quantizer; the full string is used
    # in the name of the generated .quantize file.
    _QUANTIZATION_SCHEMES = {
        "uint8": "asymmetric_affine",
        "int8": "dynamic_fixed_point-8",
        "int16": "dynamic_fixed_point-16",
    }

    def __init__(self, source_obj) -> None:
        """Fetch the model information from a Netrans object.

        Args:
            source_obj (class): An instantiated Netrans object; the parameters
                it carries are parsed here.
        """
        # NOTE(review): presumably AttributeCopier copies source_obj's
        # attributes (netrans, quantize_type, model_name) onto self - confirm
        # against utils.AttributeCopier.
        super().__init__(source_obj)

    @check_path
    def quantize_network(self):
        """Quantize the model with pnnacc.

        Reads ``self.netrans`` (the command prefix), ``self.quantize_type``
        and ``self.model_name``. For the types ``uint8``, ``int8`` and
        ``int16`` it removes any existing ``<name>_<scheme>.quantize`` file,
        runs ``<netrans> quantize ...`` through the shell, and reports
        success or failure depending on whether the quantize file exists
        afterwards. For ``float`` nothing is run, and for any other value an
        error message is printed; in both cases the method returns early.

        The number of lines in ``dataset.txt`` (looked up in the current
        working directory) is passed as ``--iterations``.

        Returns:
            None

        Raises:
            FileNotFoundError: If ``dataset.txt`` does not exist in the
                current working directory.
        """
        # Function-scope import keeps this block self-contained.
        import shlex

        netrans = self.netrans
        quantized_type = self.quantize_type
        name = self.model_name
        netrans += " quantize"

        # Select the quantization parameters for the requested type.
        if quantized_type == 'float':
            print("=========== do not need quantized===========")
            return
        quantization_type = self._QUANTIZATION_SCHEMES.get(quantized_type)
        if quantization_type is None:
            print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
            return

        # Announce what is about to be quantized.
        print(" =======================================================================")
        print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
        print(" =======================================================================")

        # One iteration per dataset line; count lazily instead of loading
        # the whole file into memory.
        txt_path = os.path.join(os.getcwd(), "dataset.txt")
        with open(txt_path, 'r', encoding='utf-8') as file:
            num_lines = sum(1 for _ in file)

        # Remove a quantize file left over from a previous run, so that its
        # existence after the command is a reliable success indicator.
        quantize_file = f"{name}_{quantization_type}.quantize"
        if os.path.exists(quantize_file):
            print(f"\033[31m rm {quantize_file} \033[0m")
            os.remove(quantize_file)

        # Build and run the quantize command. File names derived from the
        # model name are shell-quoted so that spaces or shell metacharacters
        # cannot split or alter the command; simple names are left unchanged.
        model_json = f"{name}.json"
        model_data = f"{name}.data"
        input_meta = f"{name}_inputmeta.yml"
        quantizer = quantization_type.split('-')[0]
        cmd = " ".join([
            netrans,
            "--batch-size 1",
            f"--qtype {quantized_type}",
            "--rebuild",
            f"--quantizer {quantizer}",
            f"--model-quantize {shlex.quote(quantize_file)}",
            f"--model {shlex.quote(model_json)}",
            f"--model-data {shlex.quote(model_data)}",
            f"--with-input-meta {shlex.quote(input_meta)}",
            "--device CPU",
            "--algorithm kl_divergence",
            f"--iterations {num_lines}",
        ])
        os.system(cmd)

        # Check the quantization result.
        if os.path.exists(quantize_file):
            print("\033[31m QUANTIZED SUCCESS \033[0m")
        else:
            print("\033[31m ERROR ! \033[0m")
# def main():
#     # Check the number of command-line arguments
#     if len(sys.argv) < 3:
#         print("Input a network name and quantized type ( uint8 / int8 / int16 )")
#         sys.exit(-1)
#     # Check whether the network directory exists
#     network_name = sys.argv[1]
#     # Define the netrans path
#     # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
#     # network_name = sys.argv[1]
#     # check_env(network_name)
#     netrans_path = os.environ['NETRANS_PATH']
#     # netrans = os.path.join(netrans_path, 'pnnacc')
#     quantize_type = sys.argv[2]
#     cla = create_cls(netrans_path, network_name,quantize_type)
#     # Call the quantization function
#     run = Quantize(cla)
#     run.quantize_network()
# if __name__ == "__main__":
#     main()