forked from nudt_dsp/netrans
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
"""
Single-file pytest suite: one independent test function per framework.

Run:

    pytest test_model_conversion.py -v
    pytest test_model_conversion.py::test_caffe_conversion
"""
|
|
import os
|
|
import pytest
|
|
from pathlib import Path
|
|
from netrans import Netrans
|
|
|
|
# Base directory for the example models (presumably the repository root):
# three `.parent` steps up from this file. Resolved to an absolute path so the
# example paths below do not depend on the current working directory or on
# whether `__file__` was given as a relative path.
ROOT = Path(__file__).resolve().parent.parent.parent
|
|
|
|
# 通用转换函数
|
|
def _convert(model_dir: Path, mean: float, scale: float):
|
|
model = Netrans(model_path=str(model_dir))
|
|
model.model2nbg(
|
|
quantize_type="uint8",
|
|
mean=mean,
|
|
scale=scale,
|
|
profile=False,
|
|
)
|
|
|
|
# ---------- Independent test functions, one per framework ----------
|
|
def test_caffe_conversion():
    """Run the Caffe LeNet example through `_convert` (mean 0, scale 1.0)."""
    lenet_dir = ROOT.joinpath("examples", "caffe", "lenet_caffe")
    _convert(lenet_dir, mean=0, scale=1.0)
|
|
|
|
def test_darknet_conversion():
    """Run the Darknet YOLOv4-tiny example through `_convert` (mean 0, scale 1.0)."""
    examples = ROOT / "examples"
    _convert(examples / "darknet" / "yolov4_tiny", mean=0, scale=1.0)
|
|
|
|
def test_onnx_conversion():
    """Run the ONNX YOLOv5s example through `_convert`.

    The scale 0.003921568627 is approximately 1/255; presumably it normalizes
    8-bit pixel input to [0, 1] (confirm against the example's preprocessing).
    """
    yolov5s_dir = ROOT.joinpath("examples", "onnx", "yolov5s")
    pixel_scale = 0.003921568627  # ~= 1 / 255
    _convert(yolov5s_dir, mean=0, scale=pixel_scale)
|
|
|
|
def test_tensorflow_conversion():
    """Run the TensorFlow LeNet example through `_convert` (mean 0, scale 1.0)."""
    example_parts = ("examples", "tensorflow", "lenet")
    _convert(ROOT.joinpath(*example_parts), mean=0, scale=1.0)
|
|
|
|
def test_pytorch_conversion():
    """Export the PyTorch ResNet-50 example via its helper script, then convert it.

    The example ships ``export_resnet50_2_onnx.py``, which (judging by its
    name) writes the ONNX model that ``_convert`` then consumes; it is run
    from inside the example directory.
    """
    import subprocess
    import sys

    model_dir = ROOT / "examples" / "pytorch" / "resnet50"
    export_script = model_dir / "export_resnet50_2_onnx.py"
    assert export_script.exists(), f"{export_script} not found"

    # Use the interpreter running pytest (not whatever `python` is on PATH),
    # avoid a shell so paths with spaces work, and stop immediately if the
    # export fails instead of converting a missing or stale ONNX file.
    subprocess.run(
        [sys.executable, export_script.name],
        cwd=model_dir,
        check=True,
    )

    _convert(model_dir, mean=0, scale=1.0)