import argparse
import os

import torch.utils.data
import torchvision.transforms as transforms
from torch import optim
import numpy as np
import torch.nn.functional as F

from img_dataset import Image_DataSet
# from torch.autograd import Variable
# from graphviz import Digraph

from torchvision import models

# ---------------------------------------------------------------------------
# Command-line interface.  Options are parsed at import time into ``args``,
# which the rest of this module reads as a global.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch Tongue-diagnosis Training')

# -- model / device ---------------------------------------------------------
parser.add_argument(
    '--model_name',
    default='resnet-50',
    help='models: densenet121, resnet-50, vgg19_bn',
)
parser.add_argument(
    '--gpuid',
    default='0',
    type=str,
    help='which gpu to use',
)
parser.add_argument(
    '--no_cuda',
    action='store_true',
    default=False,
    help='disables CUDA training',
)

# -- optimisation -----------------------------------------------------------
parser.add_argument(
    '-b', '--batch_size',
    default=64,
    type=int,
    help='mini-batch size (default: 64)',
)
parser.add_argument(
    '--lr',
    type=float,
    default=0.001,
    metavar='LR',
    help='learning rate',
)
# Only consumed by the SGD optimizer that is commented out in main().
parser.add_argument(
    '--momentum',
    type=float,
    default=0.9,
    metavar='M',
    help='SGD momentum',
)
parser.add_argument(
    '--weight_decay', '--wd',
    default=0,
    type=float,
    metavar='W',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=40,
    metavar='N',
    help='number of epochs to train',
)
parser.add_argument(
    '--schedule',
    type=int,
    nargs='+',
    default=[10, 20, 30],
    help='Decrease learning rate at these epochs.',
)
parser.add_argument(
    '--gamma',
    type=float,
    default=0.5,
    help='LR is multiplied by gamma on schedule.',
)

# -- paths ------------------------------------------------------------------
parser.add_argument(
    '--model_dir',
    default='dnn_models/tongue_modes/shetai/',
    help='directory of model for saving checkpoint',
)
parser.add_argument(
    '--training_data_dir',
    default='./processed-data-shetai',
    help='directory of training data',
)
parser.add_argument(
    '--test_data_dir',
    default='./shetai-test',
    help='directory of test data',
)

# -- bookkeeping ------------------------------------------------------------
parser.add_argument(
    '--save_freq', '-s',
    default=1,
    type=int,
    metavar='N',
    help='save frequency',
)
parser.add_argument(
    '--log_interval',
    type=int,
    default=10,
    metavar='N',
    help='how many batches to wait before logging training status',
)
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    metavar='S',
    help='random seed (default: 1)',
)
parser.add_argument(
    '--pretrain',
    action='store_true',
    default=False,
    help='whether to use pretrain',
)

args = parser.parse_args()

# Width of the classification head installed by get_model().
NUM_CLASSES = 8

# GPU settings.  CUDA_VISIBLE_DEVICES is exported before the first CUDA query
# below so that the process only sees the requested device(s).
GPUID = args.gpuid
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

# Checkpoint directory.  exist_ok=True creates it when missing without the
# check-then-create race of a separate os.path.exists() test (two jobs started
# together could otherwise both try to create it and one would crash).
model_dir = args.model_dir
os.makedirs(model_dir, exist_ok=True)

# cuDNN is switched off globally for this script.
# NOTE(review): presumably a workaround for a cuDNN issue on the original
# training machine; it slows convolutions down -- confirm it is still needed.
torch.backends.cudnn.enabled = False

# training data and test data settings
# Worker processes and pinned host memory are only requested when CUDA is used.
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

# Per-channel statistics given on a 0-255 scale and rescaled by 1/255.
# Kept as a module-level name, but currently applied to NEITHER pipeline: the
# training transform has it disabled, and the evaluation transform has to see
# inputs on the same scale as training for the reported test metrics to mean
# anything.  To turn normalisation on, add it to BOTH Compose lists.
normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                  std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

# training data
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    # An int size rescales the shorter edge to 224 and keeps the aspect ratio.
    # NOTE(review): the test pipeline forces (298, 224); presumably the
    # training images already have that aspect ratio -- confirm.
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # normalizer
])
training_data_dir = args.training_data_dir
tongue_dataset_train = Image_DataSet(processed_dir=training_data_dir, transform=transform_train)
tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_size=args.batch_size, shuffle=True,
                                                  **kwargs)

# test data
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.Resize((298, 224)),
    transforms.ToTensor(),
    # normalizer  (disabled to match transform_train; previously the model was
    # trained on un-normalised inputs but evaluated on normalised ones)
])
test_data_dir = args.test_data_dir
tongue_dataset_test = Image_DataSet(processed_dir=test_data_dir, transform=transform_test)
tongue_test_loader = torch.utils.data.DataLoader(tongue_dataset_test, batch_size=args.batch_size, shuffle=False,
                                                 **kwargs)


def get_model(model_name):
    """Build a torchvision backbone whose final layer outputs NUM_CLASSES logits.

    Args:
        model_name: one of 'densenet121', 'resnet-50' or 'vgg19_bn'.

    Returns:
        The model, moved to the module-level ``device``.  torchvision's
        pretrained weights are requested when ``args.pretrain`` is set; the
        replaced classification layer is always freshly initialised.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == 'densenet121':
        model = models.densenet121(pretrained=args.pretrain)
        model.classifier = torch.nn.Linear(in_features=1024, out_features=NUM_CLASSES, bias=True)
    elif model_name == 'resnet-50':
        model = models.resnet50(pretrained=args.pretrain)
        model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
    elif model_name == 'vgg19_bn':
        model = models.vgg19_bn(pretrained=args.pretrain)
        # Only the last Linear of the VGG classifier stack is swapped out.
        model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
    else:
        # The template must be formatted explicitly: passing model_name as a
        # second positional argument to ValueError leaves '{0}' uninterpolated.
        raise ValueError('Unsupport model: {0}'.format(model_name))
    return model.to(device)


def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decay schedule from ``args`` to every parameter group.

    Milestone number ``k`` (1-based position in ``args.schedule``) sets the
    rate to ``args.lr * args.gamma ** k`` once ``epoch`` has reached it; the
    last milestone in list order that has been reached wins.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch number.

    Returns:
        The learning rate that was written to the optimizer.
    """
    epoch_lr = args.lr
    for step, milestone in enumerate(args.schedule, start=1):
        if epoch >= milestone:
            epoch_lr = args.lr * np.power(args.gamma, step)

    for group in optimizer.param_groups:
        group['lr'] = epoch_lr
    return epoch_lr


def train(args, model, train_loader, optimizer, epoch):
    """Run one epoch of cross-entropy training over ``train_loader``.

    Args:
        args: parsed options; only ``log_interval`` is read here.
        model: network to optimise; switched to training mode.
        train_loader: iterable of ``(data, label)`` batches that also exposes
            ``dataset`` and ``__len__`` (used for progress output).
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress line.
    """
    model.train()
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)

    for batch_idx, (data, y) in enumerate(train_loader):
        # Same device handling as the eval functions; on CPU this is a no-op.
        data = data.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad()

        nat_logits = model(data)
        # F.cross_entropy already averages over the batch (default
        # reduction='mean').  The previous extra 1/batch_size factor
        # normalised twice, shrinking the loss and its gradients by the batch
        # size and making the logged value incomparable with eval_train().
        loss = F.cross_entropy(nat_logits, y)

        loss.backward()
        optimizer.step()

        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches, loss.item()))


def eval_train(model, train_loader):
    """Evaluate ``model`` on the training loader and print a summary line.

    Args:
        model: network to evaluate; switched to eval mode.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``dataset`` (its length is the sample count used for averaging).

    Returns:
        ``(average_loss, accuracy)`` where accuracy is a fraction in [0, 1].
    """
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the division below yields a true
            # per-sample mean even when the last batch is smaller.
            # reduction='sum' replaces the deprecated size_average=False.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(train_loader.dataset)
    train_loss /= num_samples
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, num_samples,
        100. * correct / num_samples))
    training_accuracy = correct / num_samples
    return train_loss, training_accuracy


def eval_test(model, test_loader):
    """Evaluate ``model`` on the test loader and print a summary line.

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``dataset`` (its length is the sample count used for averaging).

    Returns:
        ``(average_loss, accuracy)`` where accuracy is a fraction in [0, 1].
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the division below yields a true
            # per-sample mean even when the last batch is smaller.
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    test_accuracy = correct / num_samples
    return test_loss, test_accuracy


def make_dot(var, params=None):
    """Build a Graphviz ``Digraph`` of the autograd graph that produced ``var``.

    Tensors reachable through a node's ``saved_tensors`` are drawn orange,
    nodes exposing a ``variable`` attribute are drawn light blue and labelled
    with the tensor size (plus its name when known), and every other node is
    labelled with its type name.

    Args:
        var: output tensor; traversal starts at ``var.grad_fn``.
        params: optional dict of (name, tensor) used to name the light-blue
            nodes.  Tensors missing from the dict are left unnamed.

    Returns:
        The populated ``graphviz.Digraph``.

    Raises:
        ImportError: if the ``graphviz`` package is not installed.
        TypeError: if ``params`` contains a value that is not a tensor.
    """
    # Imported here rather than at module level so that training does not
    # require graphviz; only this visualisation helper does.
    from graphviz import Digraph

    param_map = {}
    if params is not None:
        # torch.is_tensor replaces the old isinstance(..., Variable) check
        # (Variable is not imported in this module), and iterating values()
        # replaces values()[0], which is not subscriptable on Python 3.
        if not all(torch.is_tensor(v) for v in params.values()):
            raise TypeError('params must map names to tensors')
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + ', '.join('%d' % v for v in size) + ')'

    def add_nodes(var):
        if var in seen:
            return
        if torch.is_tensor(var):
            dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
        elif hasattr(var, 'variable'):
            u = var.variable
            # .get() keeps the graph renderable when a leaf is not in params.
            name = param_map.get(id(u), '')
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(var)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)
        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.edge(str(id(t)), str(id(var)))
                add_nodes(t)

    add_nodes(var.grad_fn)
    return dot


def print_model_arch(model):
    """Visualise one forward pass of ``model`` and print its parameter counts.

    A random 1x3x298x224 input is pushed through the model, the resulting
    autograd graph is rendered with make_dot() and opened via ``view()``.
    Then, for every parameter tensor, its shape and element count are printed,
    followed by the grand total.

    Args:
        model: network to inspect; expected to already live on ``device``.
    """
    dummy_input = torch.randn(1, 3, 298, 224).to(device)
    graph = make_dot(model(dummy_input))
    graph.view()

    total_count = 0
    for param in list(model.parameters()):
        shape = list(param.size())
        # "shape of this layer"
        print("该层的结构：" + str(shape))
        layer_count = 1
        for dim in shape:
            layer_count *= dim
        # "number of parameters in this layer"
        print("该层参数和：" + str(layer_count))
        total_count += layer_count
    # "total number of parameters"
    print("总参数数量和：" + str(total_count))


def main():
    """Train the selected model, evaluating and checkpointing along the way.

    Each epoch: apply the step-decay learning-rate schedule, run one training
    pass, report loss/accuracy on the training and test loaders, and every
    ``args.save_freq`` epochs write the model and optimizer state dicts into
    ``model_dir``.
    """
    print(args)
    model = get_model(args.model_name).to(device)
    print(model)

    # Adam is the active optimizer; --momentum is therefore unused here.
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay,
    )

    for epoch in range(1, args.epochs + 1):
        # Step-decay schedule, applied before the epoch starts.
        adjust_learning_rate(optimizer, epoch)

        # One pass over the training set.
        train(args, model, tongue_train_loader, optimizer, epoch)

        # Metrics on both splits, framed by separator lines.
        print('================================================================')
        eval_train(model, tongue_train_loader)
        eval_test(model, tongue_test_loader)
        print('================================================================')

        # Checkpoint only on multiples of the save frequency.
        if epoch % args.save_freq != 0:
            continue
        model_path = os.path.join(model_dir, 'model-{}-epoch{}.pt'.format(args.model_name, epoch))
        torch.save(model.state_dict(), model_path)
        optimizer_path = os.path.join(model_dir, 'opt-{}-cpt_epoch{}.pt'.format(args.model_name, epoch))
        torch.save(optimizer.state_dict(), optimizer_path)


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
