import h5py
from bran.train.nets.svhn import SVHNNet
import tqdm
import torchvision.transforms as transforms
import torchvision
import torch
import numpy as np
import cv2
import bran.shared.image_utils as ImageUtils
import boto3
from tensorboardX import SummaryWriter
from imgaug import augmenters as iaa
from bran.train.trainer import augmenter
from bran.train.mongo_dataset import MongoDataset, MalformedError
from bran.train import AverageMeter
from bran.shared.logger import log
import uuid
import traceback
import time
import sys
import os
import multiprocessing
import math
import logging
import json
import functools
import argparse
from shutil import copyfile
from datetime import datetime
from bran.train.model_eval import evaluate_model
from bran.shared.config import config, get_config
# Module-level overrides applied to the shared config object at import time.
# Override the syslog identifier stored on the shared config
# (presumably consumed by bran.shared.logger -- TODO confirm).
config.SYSLOG_IDENT = 'bran-train'
# Copy the MONGO_DB setting from the 'prod' configuration onto the shared config.
config.MONGO_DB = get_config('prod').MONGO_DB


def _build_arg_parser():
    """Create the command-line parser for the training script.

    Options: --name (required), --job_id, --model_dir, --epochs, --workers,
    --batch_size and --lr/--learning_rate (stored as ``lr``).
    """
    cli = argparse.ArgumentParser(description="train")

    # Run identification and output location.
    cli.add_argument("--name", required=True,
                     help="kill|gamestate|victory|svhn")
    cli.add_argument("--job_id")
    # Disabled options, kept for reference:
    # cli.add_argument("--dataset", required=True,
    #                  help="Fortnite_kill|Fortnite_victory|Fortnite_gamestate")
    # cli.add_argument("--labels", nargs="+", required=True,
    #                  help="space separated list of labels")
    cli.add_argument("--model_dir", default="/var/bran/gamesense/models")

    # Training hyper-parameters.
    cli.add_argument("--epochs", type=int, default=30)
    cli.add_argument("--workers", type=int,
                     default=multiprocessing.cpu_count())
    cli.add_argument("--batch_size", type=int, default=128)
    cli.add_argument("--lr", "--learning_rate", dest="lr", metavar="LR",
                     type=float, default=3e-4, help="initial learning rate")
    return cli


parser = _build_arg_parser()


# Locations of the on-disk datasets used by this script.
_SVHN_H5_PATH = '/home/zihao/Data/SVHN_multi_3_grey.h5'
_CIFAR_ROOT = '/home/zihao/Data/cifar/'

# Number of digit heads the loss and accuracy are computed over.
_NUM_DIGITS = 3
# Target given to every digit slot of the FakeData / CIFAR samples
# (presumably the "no digit" class -- TODO confirm against SVHNNet).
_NO_DIGIT_TARGET = [10, 10, 10]


def _fake_image_to_tensor(image):
    """Convert a FakeData PIL image into a (1, 32, 32) float tensor."""
    return torch.FloatTensor(np.array(image).reshape((1, 32, 32)))


def _to_greyscale(image):
    """Convert a PIL image to single-channel ('L') mode."""
    return image.convert('L')


def _no_digit_target(_):
    """Replace a dataset's own target with the fixed no-digit label triple."""
    return torch.LongTensor(_NO_DIGIT_TARGET)


def _load_svhn_splits(path):
    """Read the train/val/test splits from the HDF5 file at *path*.

    Returns a dict mapping 'train' / 'val' / 'test' to a TensorDataset of
    (float images, long labels). Axes 1 and 3 of the stored image arrays are
    swapped before conversion, exactly as the data is laid out for the net.
    """
    splits = {}
    # The context manager closes the file even if a dataset is missing.
    with h5py.File(path, 'r') as h5f:
        for split, prefix in (('train', 'train'), ('val', 'valid'), ('test', 'test')):
            images = torch.FloatTensor(
                h5f['{}_dataset'.format(prefix)][:].swapaxes(1, 3))
            labels = torch.LongTensor(h5f['{}_labels'.format(prefix)][:])
            splits[split] = torch.utils.data.TensorDataset(images, labels)
    return splits


def _build_datasets():
    """Build the (train, val, test) datasets.

    The training set is the SVHN training split concatenated with 30000
    FakeData images and 50000 greyscale CIFAR10 images, both labelled with
    the fixed no-digit target.
    """
    splits = _load_svhn_splits(_SVHN_H5_PATH)
    svhn_train = splits['train']

    # NOTE(review): these pixels stay in the PIL 0-255 range while the CIFAR
    # branch goes through ToTensor (0-1) -- confirm which range SVHN uses.
    fake_set = torchvision.datasets.FakeData(
        size=30000, image_size=(1, 32, 32),
        transform=_fake_image_to_tensor,
        target_transform=_no_digit_target)
    cifar_set = torchvision.datasets.CIFAR10(
        _CIFAR_ROOT, train=True,
        transform=transforms.Compose([
            _to_greyscale,
            transforms.RandomCrop((32, 32)),
            transforms.ToTensor(),
        ]),
        target_transform=_no_digit_target, download=True)

    # Sanity check: all three sources must yield identically shaped samples.
    print(svhn_train[0][0].shape, svhn_train[0][1].shape)
    print(fake_set[0][0].shape, fake_set[0][1].shape)
    print(cifar_set[0][0].shape, cifar_set[0][1].shape)

    train_set = torch.utils.data.ConcatDataset([
        svhn_train,
        fake_set,
        torch.utils.data.Subset(cifar_set, list(range(50000)))
    ])
    return train_set, splits['val'], splits['test']


def _digit_loss(outputs, labels):
    """Sum of the per-digit cross-entropy losses over the digit heads."""
    return sum(
        torch.nn.functional.cross_entropy(outputs[i], labels[:, i])
        for i in range(_NUM_DIGITS))


def _predict_digits(outputs):
    """Arg-max class of each digit head, stacked to (batch, _NUM_DIGITS)."""
    return torch.stack(
        [torch.max(outputs[i], dim=1)[1] for i in range(_NUM_DIGITS)], dim=1)


def _sequence_accuracy(preds, labels):
    """Fraction of samples in the batch whose digits are ALL predicted right.

    The match count is compared against the number of label columns; the
    previous hard-coded 4 could never be reached with three digits, so the
    reported accuracy was always zero.
    """
    exact = torch.sum(preds == labels, dim=1) == labels.size(1)
    return torch.sum(exact).item() / float(labels.size(0))


def _to_device(images, labels, use_cuda):
    """Move a batch to the GPU when CUDA is in use; otherwise return as is."""
    if use_cuda:
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
    return images, labels


def _train_epoch(model, loader, optimizer, writer, use_cuda, step,
                 data_time, train_time):
    """Run one optimisation pass over *loader*.

    Logs per-step timing, loss and accuracy to *writer* starting at global
    step *step*. Returns (mean loss, mean accuracy, next global step).
    Any exception is logged and re-raised.
    """
    model.train()
    running_loss = AverageMeter()
    running_acc = AverageMeter()

    initial_time = time.time()
    try:
        for images, labels in tqdm.tqdm(loader):
            loaded_time = time.time()
            images, labels = _to_device(images, labels, use_cuda)

            outputs = model(images)
            preds = _predict_digits(outputs)
            loss = _digit_loss(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss.update(loss.item(), batch_size)
            running_acc.update(_sequence_accuracy(preds, labels), batch_size)

            trained_time = time.time()
            data_time.update(loaded_time - initial_time)
            train_time.update(trained_time - loaded_time)
            # Restart the clock so the next data_time covers only loading.
            initial_time = time.time()

            writer.add_scalar("training/data_time", data_time.val, step)
            writer.add_scalar("training/train_time", train_time.val, step)
            writer.add_scalar("training/step_loss", running_loss.val, step)
            writer.add_scalar("training/step_acc", running_acc.val, step)
            step += 1
    except Exception:
        log.exception("Error processing data")
        raise

    return running_loss.avg, running_acc.avg, step


def _validate_epoch(model, loader, writer, use_cuda, step):
    """Evaluate *model* on *loader* without tracking gradients.

    Logs per-step loss and accuracy to *writer* starting at global step
    *step*. Returns (mean loss, mean accuracy, next global step).
    Any exception is logged and re-raised.
    """
    model.eval()
    running_loss = AverageMeter()
    running_acc = AverageMeter()

    try:
        # No backward pass happens here, so skip building autograd graphs.
        with torch.no_grad():
            for images, labels in tqdm.tqdm(loader):
                images, labels = _to_device(images, labels, use_cuda)

                outputs = model(images)
                preds = _predict_digits(outputs)
                loss = _digit_loss(outputs, labels)

                batch_size = images.size(0)
                running_loss.update(loss.item(), batch_size)
                running_acc.update(
                    _sequence_accuracy(preds, labels), batch_size)

                writer.add_scalar("validation/step_loss",
                                  running_loss.val, step)
                writer.add_scalar("validation/step_acc",
                                  running_acc.val, step)
                step += 1
    except Exception:
        log.exception("Error processing data")
        raise

    return running_loss.avg, running_acc.avg, step


def _train(args, name, output_dir, writer):
    """Build data, model and optimiser, then run the train/validate loop.

    Writes ``epoch_<n>.pt`` after every epoch and ``weights.pt`` whenever the
    validation accuracy improves, both inside *output_dir*.
    """
    use_cuda = torch.cuda.is_available()
    train_set, val_set, test_set = _build_datasets()

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.workers, pin_memory=use_cuda)
    train_loader = torch.utils.data.DataLoader(
        train_set, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_set, shuffle=False, **loader_kwargs)

    net = SVHNNet()
    writer.add_graph(net, torch.stack([train_set[i][0] for i in range(10)]))

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True)

    log.info("--------------------------------")
    log.info("pytorch version: {}".format(torch.__version__))
    log.info("pytorch cuda: {}".format(use_cuda))
    log.info("Using Model: {}".format(name))
    log.info("output directory: {}".format(output_dir))
    log.info("train dataset: {} images".format(len(train_set)))
    log.info("val dataset: {} images".format(len(val_set)))
    log.info("test dataset: {} images".format(len(test_set)))
    log.info("total epochs {}".format(args.epochs))
    log.info("--------------------------------------------")

    # `model` is the parallel wrapper used for forward passes; `net` is the
    # underlying module whose state dict is saved.
    model = torch.nn.DataParallel(net)
    if use_cuda:
        model = model.cuda()

    log.info("created model: {}".format(model))

    train_step = 0
    val_step = 0
    best_accuracy = 0
    best_epoch = 0
    best_weights = "{}/weights.pt".format(output_dir)

    data_time = AverageMeter()
    train_time = AverageMeter()

    for epoch in range(args.epochs):
        writer.add_scalar("training/learning_rate",
                          optimizer.param_groups[0]['lr'], epoch)

        epoch_loss, epoch_acc, train_step = _train_epoch(
            model, train_loader, optimizer, writer, use_cuda, train_step,
            data_time, train_time)

        log.info('Epoch {} Training: Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, epoch_loss, epoch_acc))

        output_file = "{}/epoch_{}.pt".format(output_dir, epoch)
        log.info("Saving weights: {}".format(output_file))
        torch.save(net.state_dict(), output_file)

        writer.add_scalar("training/loss", epoch_loss, epoch)
        writer.add_scalar("training/accuracy", epoch_acc, epoch)

        epoch_loss, epoch_acc, val_step = _validate_epoch(
            model, val_loader, writer, use_cuda, val_step)

        # The plateau scheduler is driven by the validation loss.
        scheduler.step(epoch_loss)

        log.info('Epoch {} Validation: Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, epoch_loss, epoch_acc))

        if epoch_acc > best_accuracy:
            best_accuracy = epoch_acc
            best_epoch = epoch
            log.info("Saving model with best validation accuracy")
            torch.save(net.state_dict(), best_weights)

        writer.add_scalar("validation/loss", epoch_loss, epoch)
        writer.add_scalar("validation/accuracy", epoch_acc, epoch)

    log.info("Training finished! Best Epoch: {} Acc: {}".format(
        best_epoch, best_accuracy))

    model.eval()
    # weights.pt only exists if some epoch beat the initial accuracy of 0.
    if os.path.isfile(best_weights):
        net.load_state_dict(torch.load(best_weights))
    else:
        log.info("No best weights found at {}; keeping last epoch weights"
                 .format(best_weights))

    # TODO: final evaluation is currently disabled. It used evaluate_model()
    # on the train/val/test loaders, dumped the accuracies to output.json in
    # output_dir and logged the scripts/deploy_detector.py command to run.


def main(args):
    """Train the SVHN digit model and record the run.

    Args:
        args: parsed command-line namespace providing ``name``, ``job_id``,
            ``model_dir``, ``epochs``, ``workers``, ``batch_size`` and ``lr``.

    Checkpoints and TensorBoard events are written to
    ``<model_dir>/<job_id>``; when no job id is given one is derived from the
    model name and the current timestamp.
    """
    name = args.name
    job_id = args.job_id or '{}_{}'.format(
        name, datetime.now().strftime('%Y%m%d_%H%M%S'))

    output_dir = "{}/{}".format(args.model_dir, job_id)
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(output_dir)
    try:
        _train(args, name, output_dir, writer)
    finally:
        # Flush and close the event file even when training fails.
        writer.close()


if __name__ == '__main__':
    # Script entry point: parse the command line, then start training.
    cli_args = parser.parse_args()
    main(cli_args)
