Commit fcc44504 authored by zhiyang.zhou's avatar zhiyang.zhou

minor changes

parent 5bc3eadd
......@@ -4,19 +4,19 @@ import os
import torch.utils.data
import torchvision.transforms as transforms
from torch import optim
from torchvision import datasets
import numpy as np
import torch.nn.functional as F
from img_dataset import Image_DataSet
from models import wideresnet, resnet
parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
parser.add_argument('--model_name', default='wide_resnet-34x10',
parser.add_argument('--model_name', default='resnet-50',
help='model name, cifar-10 models: wide_resnet-34x10, cifar-100 models: resnet-50, resnet-101,'
'resnet-152')
parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
parser.add_argument('-b', '--batch_size', default=2, type=int, help='mini-batch size (default: 64)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
......@@ -46,8 +46,8 @@ normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomResizedCrop((224, 168)),
transforms.RandomCrop((224, 168), padding=14),
transforms.RandomResizedCrop((224, 224)),
# transforms.RandomCrop((224, 224), padding=14),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalizer
......@@ -93,7 +93,7 @@ def train(args, model, train_loader, optimizer, epoch):
# calculate robust loss
model.train()
nat_logits = model(data)
cur_batch_size = len(data)
cur_batch_size = len(y)
loss = (1.0 / cur_batch_size) * F.cross_entropy(nat_logits, y)
loss.backward()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment