Commit a70a18ea authored by zhiyang.zhou's avatar zhiyang.zhou

支持对其他DNN进行迁移学习

parent 623d57c6
......@@ -14,7 +14,7 @@ from torchvision import models
parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
parser.add_argument('--model_name', default='resnet-50',
help='models: densenet121, resnet-50, resnet-101, inception_v, vgg19_bn')
help='models: densenet121, resnet-50, resnet-101, inception_v3, vgg19_bn')
parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
......@@ -29,6 +29,7 @@ parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='dir
parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()
......@@ -68,13 +69,19 @@ tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_si
def get_model(model_name):
if model_name == 'densenet121':
return models.densenet121(pretrained=True, num_classes=NUM_CLASSES).to(device)
model = models.densenet121(pretrained=True)
model.classifier = torch.nn.Linear(in_features=1024, out_features=NUM_CLASSES, bias=True)
return model.to(device)
elif model_name == 'resnet-50':
return models.resnet50(pretrained=True, num_classes=NUM_CLASSES).to(device)
elif model_name == 'vgg19_bn':
return models.vgg19_bn(pretrained=True, num_classes=NUM_CLASSES).to(device)
model = models.vgg19_bn(pretrained=True)
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
return models.to(device)
elif model_name == 'inception_v3':
return models.inception_v3(pretrained=True, num_classes=NUM_CLASSES).to(device)
model = models.inception_v3(pretrained=True)
model.fc = torch.nn.Linear(in_features=2048, out_features=1000, bias=True)
return models.to(device)
else:
raise ValueError('Unsupport model: {0}', model_name)
......@@ -154,8 +161,8 @@ def eval_test(model, test_loader):
def main():
print(args)
# model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
model = models.resnet50(pretrained=False, num_classes=NUM_CLASSES).to(device)
model = get_model(args.model_name).to(device)
print(model)
# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment