修复生成onnx格式错误的bug

This commit is contained in:
xujiao 2025-07-16 17:46:58 +08:00
parent 5cd1038624
commit 955ed723ca
1 changed files with 20 additions and 4 deletions

View File

@ -1,6 +1,22 @@
import torch
import torchvision.models as models
model = models.resnet50()
input_tensor = torch.rand(1,3,224,224)
model_onnx = torch.jit.trace(model, input_tensor)
model_onnx.save('resnet50.onnx')
# 加载模型
model = models.resnet50(pretrained=True)
model.eval()
# 构造输入
dummy_input = torch.randn(1, 3, 224, 224)
# 导出 ONNX
torch.onnx.export(
model,
dummy_input,
"resnet50.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)