Commit 7c0c1b16 authored by zhiyang.zhou's avatar zhiyang.zhou

minor changes

parent 103c8ddd
......@@ -87,7 +87,9 @@ def get_model(model_name):
model.classifier = torch.nn.Linear(in_features=1024, out_features=NUM_CLASSES, bias=True)
return model.to(device)
elif model_name == 'resnet-50':
return models.resnet50(pretrained=True, num_classes=NUM_CLASSES).to(device)
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
return model.to(device)
elif model_name == 'vgg19_bn':
model = models.vgg19_bn(pretrained=True)
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment