import sys
import torch
import torch.onnx
import torchvision

def build_model(num_classes=2):
    """Return a freshly constructed ResNet-18 whose final layer emits *num_classes* logits.

    The classifier head (``fc``) is replaced with a new ``Linear`` layer so the
    architecture matches a checkpoint fine-tuned for *num_classes* outputs.
    """
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model


def export_onnx(checkpoint_path, output_path="out.onnx", num_classes=2):
    """Load a ResNet-18 state dict from *checkpoint_path* and export it to ONNX.

    Args:
        checkpoint_path: Path to a file saved with ``torch.save(model.state_dict())``.
        output_path: Destination of the exported ``.onnx`` file.
        num_classes: Output width of the replaced ``fc`` layer; must match the
            checkpoint.

    The exported graph has one input named ``"input"`` and one output named
    ``"output"``. It is traced with a fixed (1, 3, 224, 224) dummy tensor.
    """
    model = build_model(num_classes)
    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only machines. Export does not need a GPU.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Be explicit about inference mode so BatchNorm/Dropout are traced in
    # their eval behaviour, regardless of the exporter's default.
    model.eval()

    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
    )


def main(argv=None):
    """Command-line entry point: ``<script> CHECKPOINT [OUTPUT.onnx]``.

    The output path defaults to ``out.onnx`` when omitted. If the checkpoint
    argument is missing, the script exits with a usage message.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: <script> CHECKPOINT [OUTPUT.onnx]")
    output_path = argv[2] if len(argv) > 2 else "out.onnx"
    export_onnx(argv[1], output_path)


if __name__ == "__main__":
    main()
