forked from nudt_dsp/netrans
62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
import json
|
||
import shutil
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
|
||
from netrans import Netrans
|
||
|
||
|
||
# ------------ Real-path fixtures ------------
@pytest.fixture(scope="session")
def real_model_dir(tmp_path_factory):
    """Copy the example YOLOv4-tiny model into a session-scoped temp directory.

    The source directory is ``examples/darknet/yolov4_tiny``, resolved three
    directory levels above this test file. Working on a copy keeps whatever
    the pipeline writes into the model directory away from the repository's
    example files.

    If the example directory is missing, dependent tests are skipped rather
    than failing inside ``shutil.copytree``. This mirrors how the netrans
    fixture handles a missing toolchain.

    Returns:
        ``Path`` of the copied ``yolov4_tiny`` directory.
    """
    here = Path(__file__).resolve()
    src = here.parent.parent.parent / "examples" / "darknet" / "yolov4_tiny"
    if not src.is_dir():
        pytest.skip(f"example model directory not found: {src}")

    # mktemp() returns a fresh, empty directory. copytree needs a destination
    # that does not exist yet, so copy into a child of it.
    dst = tmp_path_factory.mktemp("real_model") / "yolov4_tiny"
    shutil.copytree(src, dst)
    return dst


@pytest.fixture(scope="session")
def real_netrans():
    """Return the location of the real netrans toolchain as a string.

    By default this is the repository's ``bin`` directory, three levels above
    this test file (a directory, not a single executable). CI or a developer
    can point at another installation by setting the ``NETRANS_PATH``
    environment variable. An empty value is ignored.

    Tests that depend on this fixture are skipped when the resolved path does
    not exist.

    Returns:
        The resolved path, converted to ``str`` for the ``Netrans`` constructor.
    """
    import os

    # NOTE(review): the variable name NETRANS_PATH is assumed to be what the
    # CI exports for the toolchain location. Confirm against the CI config.
    override = os.environ.get("NETRANS_PATH")
    if override:
        path = Path(override)
    else:
        path = Path(__file__).resolve().parent.parent.parent / "bin"

    if not path.exists():
        pytest.skip("真实 netrans 可执行文件不存在")
    return str(path)


@pytest.fixture
def netrans_real(real_model_dir, real_netrans):
    """Build a ``Netrans`` instance over the copied real model and toolchain.

    Function-scoped: each requesting test gets a fresh instance, although the
    underlying model directory is shared for the whole session.
    """
    model_dir = str(real_model_dir)
    instance = Netrans(model_dir, netrans=real_netrans)
    return instance


# ------------ Integration tests ------------
@pytest.mark.slow
def test_full_integration(netrans_real):
    """Run the pipeline end to end: load -> config -> quantize -> export.

    After each stage, assert that the file the stage is expected to produce
    exists under the instance's ``model_path``.
    """
    workdir = Path(netrans_real.model_path)
    stem = netrans_real.model_name

    # Stage 1 - load: expects <name>.json and <name>.data.
    netrans_real.load()
    for suffix in (".json", ".data"):
        assert (workdir / f"{stem}{suffix}").exists()

    # Stage 2 - config: expects <name>_inputmeta.yml.
    netrans_real.config()
    meta_yaml = workdir / f"{stem}_inputmeta.yml"
    assert meta_yaml.exists()

    # Stage 3 - quantize to uint8: expects the asymmetric_affine quantize file.
    netrans_real.quantize("uint8")
    quantize_out = workdir / f"{stem}_asymmetric_affine.quantize"
    assert quantize_out.exists()

    # Stage 4 - export with the same uint8 type: expects network_binary.nb
    # under wksp/asymmetric_affine.
    netrans_real.export(quantize_type="uint8")
    binary_out = workdir.joinpath("wksp", "asymmetric_affine", "network_binary.nb")
    assert binary_out.exists()