netrans/test/netrans_py/test_netrans.py

209 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from netrans import Netrans
import shutil
@pytest.fixture
def mock_model_dir(tmp_path):
    """Provide a private copy of the ``yolov4_tiny`` example model.

    The real ``examples/darknet/yolov4_tiny`` directory of the repository is
    copied into the per-test ``tmp_path`` so that anything the tests generate
    stays out of the checkout.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Path of the copied model directory (``tmp_path / "yolov4_tiny"``).

    Raises:
        FileNotFoundError: if the example directory is missing.
    """
    # This file lives under <repo>/test/netrans_py/, so three parents up is
    # the repository root that holds ``examples/``.
    repo_root = Path(__file__).parent.parent.parent
    source_dir = repo_root / "examples" / "darknet/yolov4_tiny"
    if not source_dir.exists():
        raise FileNotFoundError(f"示例目录未找到: {source_dir}")

    destination = tmp_path / "yolov4_tiny"
    shutil.copytree(
        src=str(source_dir),
        dst=str(destination),
        symlinks=True,
        ignore_dangling_symlinks=True,
    )
    return destination
@pytest.fixture
def mock_netrans_path(tmp_path):
    """Provide a private copy of the repository's ``bin`` directory.

    The directory is copied into the per-test ``tmp_path``. The
    ``netrans_instance`` fixture passes the copy as ``Netrans(..., netrans=...)``.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Path of the copied directory (``tmp_path / "bin"``).

    Raises:
        FileNotFoundError: if the ``bin`` directory is missing.
    """
    # This file lives under <repo>/test/netrans_py/, so three parents up is
    # the repository root that holds ``bin/``.
    source_dir = Path(__file__).parent.parent.parent / "bin"
    # Fail early with a descriptive message, matching mock_model_dir, instead
    # of surfacing copytree's generic "No such file or directory" error.
    if not source_dir.exists():
        raise FileNotFoundError(f"netrans bin directory not found: {source_dir}")

    netrans_path = tmp_path / "bin"
    shutil.copytree(
        src=str(source_dir),
        dst=str(netrans_path),
        symlinks=True,
        ignore_dangling_symlinks=True,
    )
    return netrans_path
@pytest.fixture
def netrans_instance(mock_model_dir, mock_netrans_path):
    """Return a ``Netrans`` bound to the temporary model and tool copies."""
    model_dir = str(mock_model_dir)
    tool_dir = str(mock_netrans_path)
    return Netrans(model_dir, netrans=tool_dir)
class TestNetransInitialization:
    """Tests for the ``Netrans`` wrapper.

    Despite the class name, this suite covers more than construction:
    ``load``, ``config``, ``quantize`` and ``export`` are exercised as well.
    """

    def test_init_with_default_netrans(self, mock_model_dir):
        # With no explicit ``netrans`` argument the tool path must be found on
        # its own (presumably via NETRANS_PATH exported from .bashrc).
        net = Netrans(str(mock_model_dir))
        assert net.model_path == str(mock_model_dir)
        assert hasattr(net, 'netrans_path')

    def test_init_with_custom_netrans(self, mock_model_dir, mock_netrans_path):
        # An explicitly passed ``netrans`` directory is stored as given.
        net = Netrans(str(mock_model_dir), netrans=str(mock_netrans_path))
        assert net.netrans_path == str(mock_netrans_path)

    def test_init_with_invalid_model_path(self):
        # A model path that does not exist must be rejected.
        with pytest.raises(FileNotFoundError):
            Netrans("invalid/path")

    def test_load_success(self, netrans_instance):
        # load() (model import) must invoke subprocess.run exactly once.
        with patch('subprocess.run') as mock_run:
            netrans_instance.load()
            mock_run.assert_called_once()
        # NOTE(review): subprocess.run is mocked above, so these files are not
        # produced by that call; the assertions only hold if they exist
        # independently (e.g. already present in the copied example) — confirm.
        output_dir = Path(netrans_instance.model_path)
        assert (output_dir / f"{netrans_instance.model_name}.json").exists()
        assert (output_dir / f"{netrans_instance.model_name}.data").exists()

    def test_load_failure(self, netrans_instance):
        # Any failure of the underlying process must surface as RuntimeError.
        with patch('subprocess.run', side_effect=Exception("Process error")):
            with pytest.raises(RuntimeError):
                netrans_instance.load()

    # NOTE(review): the draft below cannot be enabled as written, because
    # ``netrans_instance`` is a fixture and does not exist when decorator
    # arguments are evaluated at collection time.
    # @pytest.mark.parametrize("inputmeta_param, expected", [
    #     (False, "generate_template"),
    #     (True, "auto_detect"),
    #     (netrans_instance.model_path/f"{netrans_instance.model_name}_inputmeta.yml", "use_custom")
    # ])
    # def test_config_gen(self, netrans_instance, inputmeta_param, expected):
    #     with patch.object(Netrans, '_handle_inputmeta') as mock_method:
    #         netrans_instance.config(inputmeta=inputmeta_param)
    #         mock_method.assert_called_with(expected)

    def test_config_gen_inputmeta(self, netrans_instance):
        # config() with no arguments runs the tool once and is expected to
        # leave ``<model_name>_inputmeta.yml`` in the model directory.
        # NOTE(review): as in test_load_success, the file check runs against a
        # mocked subprocess — confirm the file exists independently.
        with patch('subprocess.run') as mock_run:
            netrans_instance.config()
            mock_run.assert_called_once()
        inputmeta = Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml"
        assert inputmeta.exists()

    def test_config_auto_find_inputmeta(self, netrans_instance):
        # config(True) should pick up the existing <model_name>_inputmeta.yml.
        with patch('subprocess.run'):
            netrans_instance.config(True)
        expected = str(Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml")
        assert netrans_instance.input_meta == expected

    def test_config_use_costume_inputmeta(self, netrans_instance):
        # An explicitly supplied inputmeta file is used as-is.
        # ("costume" means "custom"; the name is kept so test IDs stay stable.)
        inputmeta = str(Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_inputmeta.yml")
        cp_file = inputmeta + '.tmp.yml'
        shutil.copy2(inputmeta, cp_file)
        netrans_instance.config(inputmeta=cp_file)
        assert netrans_instance.input_meta == cp_file

    def test_config_with_invalid_model_path(self, netrans_instance):
        # A non-existent inputmeta path must be rejected.
        # TODO: also cover config(False) failing when subprocess.run raises
        # (expected RuntimeError).
        # NOTE(review): FileNotFoundError would be the conventional type for a
        # missing file; this asserts FileExistsError — confirm that is what
        # Netrans.config raises.
        inputmeta = "invalid/path"
        # Only the call under test sits inside pytest.raises, so nothing else
        # can accidentally satisfy the expectation.
        with pytest.raises(FileExistsError):
            netrans_instance.config(inputmeta=inputmeta)

    def test_config_parameter_combinations(self, netrans_instance):
        # Each (scale, mean) pair is tried with both channel orders, and the
        # stored preprocess values must round-trip unchanged.
        test_params = {
            'scale': [[0, 1, 0.5], [0.1, 0.2, 0.3]],
            'mean': [[0, 0, 128], [125, 127, 128]],
            'reverse_channel': [True, False],
        }
        for scale, mean in zip(test_params['scale'], test_params['mean']):
            for reverse in test_params['reverse_channel']:
                netrans_instance.config(
                    scale=scale,
                    mean=mean,
                    reverse_channel=reverse,
                )
                data = netrans_instance._verify_preprocess_value()
                assert data['scale'] == scale
                assert data['mean'] == mean
                assert data['reverse_channel'] == reverse

    def test_valid_quantize_types(self, netrans_instance):
        # Every supported quantization type is accepted and recorded.
        for qtype in ["uint8", "int8", "int16"]:
            netrans_instance.quantize(qtype)
            assert netrans_instance.quantize_type == qtype

    def test_invalid_quantize_type(self, netrans_instance):
        # Unsupported types are rejected. NOTE(review): ValueError would be
        # the conventional type for a bad value — confirm against quantize().
        with pytest.raises(TypeError):
            netrans_instance.quantize("float32")

    def test_quantize(self, netrans_instance):
        # Not mocked: uint8 quantization should write the
        # ``<model_name>_asymmetric_affine.quantize`` file.
        netrans_instance.quantize('uint8')
        assert (Path(netrans_instance.model_path) / f"{netrans_instance.model_name}_asymmetric_affine.quantize").exists()

    def test_export(self, netrans_instance):
        # Not mocked: export() receives the quantize type directly (no prior
        # quantize() call) and should produce the network binary under wksp/.
        netrans_instance.export(quantize_type='uint8')
        assert (Path(netrans_instance.model_path) / "wksp/asymmetric_affine/network_binary.nb").exists()
# class TestExportMethod:
# def test_export_flow(self, netrans_instance):
# with patch.multiple(Netrans,
# _validate_quant_config=Mock(),
# _compile_model=Mock()) as mocks:
# netrans_instance.export()
# mocks['_validate_quant_config'].assert_called_once()
# mocks['_compile_model'].assert_called_once()
# class TestModel2NBG:
# @pytest.mark.parametrize("params", [
# {'quantize_type': 'uint8'},
# {'quantize_type': 'int8', 'mean': 128, 'scale': 0.0039},
# {'quantize_type': 'int16', 'mean': [128,127,125], 'scale': 0.0039, 'reverse_channel': True},
# {'quantize_type': 'uint8', 'inputmeta': True}
# ])
# def test_full_workflow(self, netrans_instance, params):
# with patch.multiple(Netrans,
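# load=Mock(),  # NOTE: was "import=Mock()", which can never compile because `import` is a Python keyword; the import step is Netrans.load()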
# config=Mock(),
# quantize=Mock(),
# export=Mock()) as mocks:
# netrans_instance.model2nbg(**params)
# if 'inputmeta' not in params or params['inputmeta'] is not True:
# mocks['config'].assert_called_with(
# mean=params.get('mean'),
# scale=params.get('scale'),
# reverse_channel=params.get('reverse_channel'),
# inputmeta=params.get('inputmeta', False)
# )
# mocks['quantize'].assert_called_with(params['quantize_type'])
# mocks['export'].assert_called_once()
# Suggested code-coverage configuration (contents for pytest.ini).
# --cov must name the imported package ``netrans``; the previous value
# "nertans" was a typo and would have measured no code at all.
"""
[pytest]
addopts = --cov=netrans --cov-report=term-missing
"""