# Forked from nudt_dsp/netrans (209 lines, 8.2 KiB, Python).
import pytest
|
||
from pathlib import Path
|
||
from unittest.mock import Mock, patch
|
||
from netrans import Netrans
|
||
import shutil
|
||
@pytest.fixture
def mock_model_dir(tmp_path):
    """Provide a private, per-test copy of the ``darknet/yolov4_tiny`` example.

    The example is located relative to this test module (three ``.parent`` hops
    up, then ``examples/darknet/yolov4_tiny``). It is copied into ``tmp_path`` so
    that anything a test writes into the model directory never touches the
    checked-in example.

    Returns:
        Path: ``tmp_path / "yolov4_tiny"``, the freshly copied directory.

    Raises:
        FileNotFoundError: if the example directory does not exist.
    """
    # Adjust the number of ``.parent`` hops if this test file is moved.
    repo_root = Path(__file__).parent.parent.parent
    source_dir = repo_root / "examples" / "darknet" / "yolov4_tiny"

    if not source_dir.exists():
        raise FileNotFoundError(f"示例目录未找到: {source_dir}")

    destination = tmp_path / "yolov4_tiny"
    # Preserve symlinks as links, and do not fail on links whose target is missing.
    shutil.copytree(
        str(source_dir),
        str(destination),
        symlinks=True,
        ignore_dangling_symlinks=True,
    )
    return destination
|
||
|
||
@pytest.fixture
def mock_netrans_path(tmp_path):
    """Provide a private, per-test copy of the repository's ``bin`` directory.

    The directory is located relative to this test module (three ``.parent``
    hops up, then ``bin``). It is copied into ``tmp_path`` so tests can use it
    as the ``netrans`` location passed to ``Netrans(..., netrans=...)``
    without modifying the checkout.

    Returns:
        Path: ``tmp_path / "bin"``, the freshly copied directory.

    Raises:
        FileNotFoundError: if the ``bin`` directory does not exist.
    """
    # Adjust the number of ``.parent`` hops if this test file is moved.
    source_dir = Path(__file__).parent.parent.parent / "bin"

    # Fail with an explicit message (mirroring mock_model_dir) instead of letting
    # copytree raise a bare error from deep inside the copy.
    if not source_dir.exists():
        raise FileNotFoundError(f"netrans bin 目录未找到: {source_dir}")

    netrans_path = tmp_path / "bin"
    # Preserve symlinks as links, and do not fail on links whose target is missing.
    shutil.copytree(
        src=str(source_dir),
        dst=str(netrans_path),
        symlinks=True,
        ignore_dangling_symlinks=True,
    )
    return netrans_path
|
||
|
||
@pytest.fixture
def netrans_instance(mock_model_dir, mock_netrans_path):
    """Build a ``Netrans`` bound to the per-test model copy and ``bin`` copy.

    Both locations are handed to ``Netrans`` as plain strings.
    """
    model_path = str(mock_model_dir)
    netrans_bin = str(mock_netrans_path)
    return Netrans(model_path, netrans=netrans_bin)
|
||
|
||
class TestNetransInitialization:
    """Tests for constructing ``Netrans`` and for its load/config/quantize/export steps.

    Most tests use the ``netrans_instance`` fixture, which operates on a private
    per-test copy of the example model directory, so tests do not share files.
    """

    def test_init_with_default_netrans(self, mock_model_dir):
        # Without an explicit ``netrans`` argument, Netrans must still end up with a
        # ``netrans_path`` attribute (per the original note: found via NETRANS_PATH
        # from .bashrc).
        net = Netrans(str(mock_model_dir))
        assert net.model_path == str(mock_model_dir)
        assert hasattr(net, 'netrans_path')

    def test_init_with_custom_netrans(self, mock_model_dir, mock_netrans_path):
        # An explicitly passed ``netrans`` location is stored verbatim.
        net = Netrans(str(mock_model_dir), netrans=str(mock_netrans_path))
        assert net.netrans_path == str(mock_netrans_path)

    def test_init_with_invalid_model_path(self):
        # A model path that does not exist must be rejected at construction time.
        with pytest.raises(FileNotFoundError):
            Netrans("invalid/path")

    def test_load_success(self, netrans_instance):
        # Model import: load() is expected to invoke subprocess.run exactly once.
        with patch('subprocess.run') as mock_run:
            netrans_instance.load()
            mock_run.assert_called_once()

        # NOTE(review): subprocess.run is mocked above, so unless load() writes these
        # files without going through subprocess.run, they must already exist in the
        # copied example directory for this to pass. Confirm that is intended.
        output_dir = Path(netrans_instance.model_path)
        assert (output_dir / f"{netrans_instance.model_name}.json").exists()
        assert (output_dir / f"{netrans_instance.model_name}.data").exists()

    def test_load_failure(self, netrans_instance):
        # Any exception from the external process must surface as RuntimeError.
        with patch('subprocess.run', side_effect=Exception("Process error")):
            with pytest.raises(RuntimeError):
                netrans_instance.load()

    # Disabled: a parametrized check that config() dispatches to _handle_inputmeta
    # per inputmeta mode. It cannot run as written.
    # - The parametrize list references the ``netrans_instance`` fixture, which does
    #   not exist when the class body is evaluated.
    # - ``model_path`` is a str, so ``model_path / ...`` would fail as well.
    # @pytest.mark.parametrize("inputmeta_param, expected", [
    #     (False, "generate_template"),
    #     (True, "auto_detect"),
    #     (netrans_instance.model_path/f"{netrans_instance.model_name}_inputmeta.yml", "use_custom")
    # ])
    # def test_config_gen(self, netrans_instance, inputmeta_param, expected):
    #     with patch.object(Netrans, '_handle_inputmeta') as mock_method:
    #         netrans_instance.config(inputmeta=inputmeta_param)
    #         mock_method.assert_called_with(expected)

    def test_config_gen_inputmeta(self, netrans_instance):
        # config() with no arguments is expected to call subprocess.run once and
        # leave <model_name>_inputmeta.yml in the model directory.
        with patch('subprocess.run') as mock_run:
            netrans_instance.config()
            mock_run.assert_called_once()

        # NOTE(review): same caveat as test_load_success. With subprocess.run mocked,
        # this file is not produced by the external tool.
        inputmeta = Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml"
        assert inputmeta.exists()

    def test_config_auto_find_inputmeta(self, netrans_instance):
        # config(True) is expected to pick up <model_name>_inputmeta.yml from the
        # model directory and record its path in ``input_meta``.
        with patch('subprocess.run'):
            netrans_instance.config(True)

        expected = str(Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml")
        assert netrans_instance.input_meta == expected

    def test_config_use_costume_inputmeta(self, netrans_instance):
        # ("costume" is a typo for "custom"; the name is kept to preserve the test id.)
        # An explicitly passed inputmeta file path is recorded as-is.
        # NOTE(review): this relies on <model_name>_inputmeta.yml already existing in
        # the copied example directory; no fixture generates it.
        inputmeta = str(Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml")
        cp_file = inputmeta + '.tmp.yml'
        shutil.copy2(inputmeta, cp_file)

        netrans_instance.config(inputmeta=cp_file)
        assert netrans_instance.input_meta == cp_file

    def test_config_with_invalid_model_path(self, netrans_instance):
        # Despite its name, this checks that config() rejects an inputmeta path that
        # does not exist.
        # NOTE(review): FileNotFoundError is the conventional exception for a missing
        # file. FileExistsError is kept on the assumption that Netrans.config raises
        # it; confirm.
        inputmeta = "invalid/path"
        with pytest.raises(FileExistsError):
            netrans_instance.config(inputmeta=inputmeta)

    def test_config_parameter_combinations(self, netrans_instance):
        # Each (scale, mean) pair is tried with both channel orders. The values
        # returned by _verify_preprocess_value() must match what was passed in.
        test_params = {
            'scale': [[0, 1, 0.5], [0.1, 0.2, 0.3]],
            'mean': [[0, 0, 128], [125, 127, 128]],
            'reverse_channel': [True, False],
        }
        for scale, mean in zip(test_params['scale'], test_params['mean']):
            for reverse in test_params['reverse_channel']:
                netrans_instance.config(
                    scale=scale,
                    mean=mean,
                    reverse_channel=reverse,
                )
                data = netrans_instance._verify_preprocess_value()
                assert data['scale'] == scale
                assert data['mean'] == mean
                assert data['reverse_channel'] == reverse

    def test_valid_quantize_types(self, netrans_instance):
        # Every supported quantization type is accepted and recorded.
        for qtype in ("uint8", "int8", "int16"):
            netrans_instance.quantize(qtype)
            assert netrans_instance.quantize_type == qtype

    def test_invalid_quantize_type(self, netrans_instance):
        # NOTE(review): ValueError is the conventional exception for a bad value of
        # the right type. TypeError is kept on the assumption that Netrans.quantize
        # raises it; confirm.
        with pytest.raises(TypeError):
            netrans_instance.quantize("float32")

    def test_quantize(self, netrans_instance):
        # uint8 quantization is expected to write
        # <model_name>_asymmetric_affine.quantize.
        netrans_instance.quantize('uint8')
        assert (Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_asymmetric_affine.quantize").exists()

    def test_export(self, netrans_instance):
        # export() is called without a prior quantize() step. Presumably export
        # handles quantization itself for the given type; confirm.
        netrans_instance.export(quantize_type='uint8')
        assert (Path(netrans_instance.model_path) / "wksp/asymmetric_affine/network_binary.nb").exists()
|
||
|
||
|
||
# class TestExportMethod:
|
||
# def test_export_flow(self, netrans_instance):
|
||
# with patch.multiple(Netrans,
|
||
# _validate_quant_config=Mock(),
|
||
# _compile_model=Mock()) as mocks:
|
||
# netrans_instance.export()
|
||
# mocks['_validate_quant_config'].assert_called_once()
|
||
# mocks['_compile_model'].assert_called_once()
|
||
|
||
# class TestModel2NBG:
|
||
# @pytest.mark.parametrize("params", [
|
||
# {'quantize_type': 'uint8'},
|
||
# {'quantize_type': 'int8', 'mean': 128, 'scale': 0.0039},
|
||
# {'quantize_type': 'int16', 'mean': [128,127,125], 'scale': 0.0039, 'reverse_channel': True},
|
||
# {'quantize_type': 'uint8', 'inputmeta': True}
|
||
# ])
|
||
# def test_full_workflow(self, netrans_instance, params):
|
||
# with patch.multiple(Netrans,
|
||
#                     load=Mock(),  # was `import=`, a reserved keyword; the import step is Netrans.load()
|
||
# config=Mock(),
|
||
# quantize=Mock(),
|
||
# export=Mock()) as mocks:
|
||
# netrans_instance.model2nbg(**params)
|
||
|
||
# if 'inputmeta' not in params or params['inputmeta'] is not True:
|
||
# mocks['config'].assert_called_with(
|
||
# mean=params.get('mean'),
|
||
# scale=params.get('scale'),
|
||
# reverse_channel=params.get('reverse_channel'),
|
||
# inputmeta=params.get('inputmeta', False)
|
||
# )
|
||
|
||
# mocks['quantize'].assert_called_with(params['quantize_type'])
|
||
# mocks['export'].assert_called_once()
|
||
|
||
# Code-coverage configuration (belongs in pytest.ini)
"""
[pytest]
addopts = --cov=netrans --cov-report=term-missing
"""