import argparse
import os

import torch.utils.data
import torchvision.transforms as transforms
from torch import optim
import numpy as np
import torch.nn.functional as F

from img_dataset import Image_DataSet
from models import wideresnet, resnet

from torchvision import models

# Command-line options: model label, device selection, optimisation
# hyper-parameters, learning-rate schedule and checkpointing.
parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
# The help text lists exactly the names that get_model() accepts.
parser.add_argument('--model_name', default='resnet-50',
                    help='models: densenet121, resnet-50, vgg19_bn, inception_v3')
parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--weight_decay', '--wd', default=0, type=float, metavar='W',
                    help='weight decay (default: 0)')
parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train')
parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='directory of model for saving checkpoint')
parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')

# Parsed at import time; the rest of the module reads its options from ``args``.
args = parser.parse_args()

# Number of output classes of the classifier.
NUM_CLASSES = 8

# Export the GPU selection BEFORE the first CUDA query below. The CUDA runtime
# reads CUDA_VISIBLE_DEVICES when it is initialised, and
# torch.cuda.is_available() may trigger that initialisation, so setting the
# variable afterwards can silently be ignored.
GPUID = args.gpuid
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.manual_seed(1)

# settings
model_dir = args.model_dir
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(model_dir, exist_ok=True)

torch.backends.cudnn.enabled = False

# training data and test data settings
# Worker processes and pinned memory are only requested when training on GPU.
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                  std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
transform_train = transforms.Compose([
    # NOTE(review): the ToTensor -> ToPILImage round trip presumably exists
    # because Image_DataSet yields arrays rather than PIL images - confirm.
    transforms.ToTensor(),
    transforms.ToPILImage(),
    # An int size resizes the shorter edge to 224 and keeps the aspect ratio.
    # NOTE(review): batching then relies on all images sharing one aspect
    # ratio - verify against the dataset.
    transforms.Resize(224),
    # transforms.RandomCrop((224, 224), padding=14),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalizer
])
processed_dir = './processed-data'
tongue_dataset_train = Image_DataSet(processed_dir=processed_dir, transform=transform_train)
tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_size=args.batch_size, shuffle=True,
                                                  **kwargs)


def get_model(model_name):
    """Build an ImageNet-pretrained backbone with a ``NUM_CLASSES``-way head.

    The pretrained network is created with its original classifier so the
    checkpoint loads cleanly; only afterwards is the final fully connected
    layer replaced by a freshly initialised one with ``NUM_CLASSES`` outputs.
    Passing ``num_classes`` together with ``pretrained=True`` fails because
    the checkpoint's classifier has a different shape.

    Args:
        model_name: one of ``'densenet121'``, ``'resnet-50'``, ``'vgg19_bn'``
            or ``'inception_v3'``.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'densenet121':
        model = models.densenet121(pretrained=True)
        model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_CLASSES)
    elif model_name == 'resnet-50':
        model = models.resnet50(pretrained=True)
        model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
    elif model_name == 'vgg19_bn':
        model = models.vgg19_bn(pretrained=True)
        # Index 6 is the last Linear layer of torchvision's VGG classifier.
        model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, NUM_CLASSES)
    elif model_name == 'inception_v3':
        # NOTE(review): Inception v3 expects 299x299 inputs and returns an
        # (output, aux_output) pair in training mode - callers must handle that.
        model = models.inception_v3(pretrained=True)
        model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
        aux_head = getattr(model, 'AuxLogits', None)
        if aux_head is not None:
            # The auxiliary classifier needs the same number of outputs.
            aux_head.fc = torch.nn.Linear(aux_head.fc.in_features, NUM_CLASSES)
    else:
        raise ValueError('Unsupported model: {0}'.format(model_name))
    return model.to(device)


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decayed learning rate for ``epoch`` to ``optimizer``.

    The base rate ``args.lr`` is multiplied by ``args.gamma`` once for every
    milestone in ``args.schedule`` that ``epoch`` has reached. Counting the
    reached milestones makes the result independent of the order in which
    the milestones were given.

    Args:
        optimizer: object exposing ``param_groups``; the ``'lr'`` entry of
            every group is overwritten.
        epoch: current epoch number.

    Returns:
        The learning rate that was set, as a plain ``float``.
    """
    milestones_reached = sum(1 for milestone in args.schedule if epoch >= milestone)
    epoch_lr = args.lr * args.gamma ** milestones_reached
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr
    return epoch_lr


def train(args, model, train_loader, optimizer, epoch):
    """Run one epoch of standard cross-entropy training.

    Args:
        args: parsed options; only ``args.log_interval`` is read here.
        model: network to optimise; it is switched to training mode.
        train_loader: iterable of ``(data, y)`` batches. ``len(train_loader)``
            and ``train_loader.dataset`` are used for the progress message.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, (data, y) in enumerate(train_loader):
        # Same device handling as the evaluation helpers; a no-op on CPU.
        data = data.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad()

        nat_logits = model(data)
        # F.cross_entropy already averages over the batch by default, so no
        # extra division by the batch size is applied. Dividing again would
        # shrink both the gradients and the logged loss by 1 / batch_size.
        loss = F.cross_entropy(nat_logits, y)

        loss.backward()
        optimizer.step()

        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def eval_train(model, train_loader):
    """Evaluate ``model`` on ``train_loader`` and print loss and accuracy.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        train_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        ``(train_loss, training_accuracy)``: the cross-entropy summed over
        all samples divided by the dataset size, and the fraction of
        samples classified correctly.
    """
    model.eval()
    num_samples = len(train_loader.dataset)
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' replaces the deprecated size_average=False.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= num_samples
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, num_samples,
        100. * correct / num_samples))
    training_accuracy = correct / num_samples
    return train_loss, training_accuracy


def eval_test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        test_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        ``(test_loss, test_accuracy)``: the cross-entropy summed over all
        samples divided by the dataset size, and the fraction of samples
        classified correctly.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= num_samples
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    test_accuracy = correct / num_samples
    return test_loss, test_accuracy


def main():
    """Train the classifier for ``args.epochs`` epochs.

    Each epoch sets the scheduled learning rate, runs one training pass,
    reports loss and accuracy on the training set, and every
    ``args.save_freq`` epochs writes the model and optimizer state dicts
    into ``model_dir``.
    """
    print(args)

    # NOTE: the architecture is fixed to a randomly initialised ResNet-50;
    # ``args.model_name`` is only used to label the checkpoint files.
    model = models.resnet50(pretrained=False, num_classes=NUM_CLASSES)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-08,
                           weight_decay=args.weight_decay)

    separator = '================================================================'
    last_epoch = args.epochs
    epoch = 1
    while epoch <= last_epoch:
        # step-decay the learning rate according to args.schedule
        adjust_learning_rate(optimizer, epoch)

        # one pass over the training data
        train(args, model, tongue_train_loader, optimizer, epoch)

        # report loss/accuracy on the training set
        print(separator)
        eval_train(model, tongue_train_loader)
        print(separator)

        # save checkpoint: model weights first, then optimizer state
        if epoch % args.save_freq == 0:
            checkpoints = (
                (model, 'model-{}-epoch{}.pt'),
                (optimizer, 'opt-{}-cpt_epoch{}.pt'),
            )
            for stateful, pattern in checkpoints:
                state = stateful.state_dict()
                target = os.path.join(model_dir, pattern.format(args.model_name, epoch))
                torch.save(state, target)

        epoch += 1


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
