Commit ca355dee authored by zhiyang.zhou's avatar zhiyang.zhou

use pretrained models

parent aa8f9999
Pipeline #191 canceled with stages
......@@ -20,7 +20,7 @@ class Image_DataSet(torch.utils.data.Dataset):
index = index % (len(imgs) - 1)
img_file = os.path.join(self.processed_dir, imgs[index])
source_image = PIL.Image.open(img_file)
source_image = PIL.Image.open(img_file).convert('RGB')
# source_image=np.array(source_image)
# image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
if self.transform is not None:
......
......@@ -10,10 +10,11 @@ import torch.nn.functional as F
from img_dataset import Image_DataSet
from models import wideresnet, resnet
from torchvision import models
parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
parser.add_argument('--model_name', default='resnet-50',
help='model name, cifar-10 models: wide_resnet-34x10, cifar-100 models: resnet-50, resnet-101,'
'resnet-152')
help='models: densenet121, resnet-50, resnet-101, inception_v, vgg19_bn')
parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('-b', '--batch_size', default=32, type=int, help='mini-batch size (default: 32)')
......@@ -58,15 +59,15 @@ tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_si
**kwargs)
def get_model(model_name, num_classes=NUM_CLASSES):
if model_name == 'wide_resnet-34x10':
return wideresnet.WideResNet(num_classes=num_classes)
def get_model(model_name):
if model_name == 'densenet121':
return models.densenet121(pretrained=True, num_classes=NUM_CLASSES).to(device)
elif model_name == 'resnet-50':
return resnet.ResNet50(num_classes=num_classes)
elif model_name == 'resnet-101':
return resnet.ResNet101(num_classes=num_classes)
elif model_name == 'resnet-152':
return resnet.ResNet152(num_classes=num_classes)
return models.resnet50(pretrained=True, num_classes=NUM_CLASSES).to(device)
elif model_name == 'vgg19_bn':
return models.vgg19_bn(pretrained=True, num_classes=NUM_CLASSES).to(device)
elif model_name == 'inception_v3':
return models.inception_v3(pretrained=True, num_classes=NUM_CLASSES).to(device)
else:
raise ValueError('Unsupport model: {0}', model_name)
......@@ -146,7 +147,8 @@ def eval_test(model, test_loader):
def main():
print(args)
model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
# model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
model = models.resnet50(pretrained=False, num_classes=NUM_CLASSES).to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
for epoch in range(1, args.epochs + 1):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment