forked from nudt_dsp/netrans
修复生成onnx格式错误的bug
This commit is contained in:
parent
5cd1038624
commit
955ed723ca
|
@ -1,6 +1,22 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
model = models.resnet50()
|
||||
input_tensor = torch.rand(1,3,224,224)
|
||||
model_onnx = torch.jit.trace(model, input_tensor)
|
||||
model_onnx.save('resnet50.onnx')
|
||||
|
||||
# 加载模型
|
||||
model = models.resnet50(pretrained=True)
|
||||
model.eval()
|
||||
|
||||
# 构造输入
|
||||
dummy_input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# 导出 ONNX
|
||||
torch.onnx.export(
|
||||
model,
|
||||
dummy_input,
|
||||
"resnet50.onnx",
|
||||
export_params=True,
|
||||
opset_version=11,
|
||||
do_constant_folding=True,
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
|
||||
)
|
Loading…
Reference in New Issue