Commit 96e31bcc authored by zhiyang.zhou's avatar zhiyang.zhou

minor changes

parent a70a18ea
......@@ -77,11 +77,11 @@ def get_model(model_name):
elif model_name == 'vgg19_bn':
model = models.vgg19_bn(pretrained=True)
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
return models.to(device)
return model.to(device)
elif model_name == 'inception_v3':
model = models.inception_v3(pretrained=True)
model.fc = torch.nn.Linear(in_features=2048, out_features=1000, bias=True)
return models.to(device)
return model.to(device)
else:
raise ValueError('Unsupport model: {0}', model_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment