import h5py
from bran.train.nets.svhn import SVHNNet
from bran.train.nets.svhn_res_attn import SVHNResidualAttention
import tqdm
import torchvision.transforms as transforms
import torchvision
import torch
import numpy as np
import cv2
import bran.shared.image_utils as ImageUtils
import boto3
from tensorboardX import SummaryWriter
from imgaug import augmenters as iaa
from bran.train.trainer import augmenter
from bran.train.mongo_dataset import MongoDataset, MalformedError
from bran.train.mongo_collection import MongoCollection
from bran.train import AverageMeter
from bran.shared.logger import log
import uuid
import traceback
import time
import sys
import os
import multiprocessing
import math
import logging
import json
import functools
import argparse
from shutil import copyfile
from datetime import datetime
from bran.train.model_eval import evaluate_model
from bran.shared.config import config, get_config
from bran.train.trainer import augmenter
# Import-time configuration of the shared ``config`` object.
# SYSLOG_IDENT is presumably the syslog identifier used by bran's logger, and
# MONGO_DB is copied from the 'dev' config so MongoCollection reads dev data.
# NOTE(review): ``log`` is imported above, before SYSLOG_IDENT is set; whether
# the logger picks up this value depends on bran.shared.logger — confirm.
config.SYSLOG_IDENT = 'bran-train'
config.MONGO_DB = get_config('dev').MONGO_DB

# Number of digit positions predicted per image; targets built by
# num_to_target are padded with class 10 up to this length.
NUM_SIZE = 3

# Per-game crop regions passed to ImageUtils.crop, looked up by --name.
# NOTE(review): the values look like fractions of the frame (origin x/y plus
# width/height) — confirm against ImageUtils.crop.
NAME_TO_CROP = {
    "apex-kill": {"x": 0.819, "y": 0.047, "width": 0.034, "height": 0.031},
    "pubg-kill": {"x": 0.859, "y": 0.03, "width": 0.048, "height": 0.035},
    "pubgmobile-kill": {"x": 0.123, "y": 0.005, "width": 0.022, "height": 0.045},
}

# Command-line interface for the training script.
# NOTE(review): main() looks up NAME_TO_CROP[--name] when Mongo frames are
# loaded, so --name should normally be one of that dict's keys.
parser = argparse.ArgumentParser(description="train")
parser.add_argument("--name", action="store", required=True,
                    help="kill|gamestate|victory|svhn")
parser.add_argument("--job_id", action="store")
parser.add_argument("--dataset", required=True,
                    help="Fortnite_kill|Fortnite_victory|Fortnite_gamestate")
parser.add_argument("--labels", nargs="+", required=False,
                    help="space seperated list of labels")
parser.add_argument("--model_dir", action="store",
                    default="/var/bran/gamesense/models")
parser.add_argument("--epochs", action="store", type=int, default=30)
parser.add_argument('--workers', action='store', type=int,
                    default=multiprocessing.cpu_count())
parser.add_argument("--batch_size", action="store", type=int, default=128)
parser.add_argument('--lr', '--learning_rate', default=3e-4, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
# Restrict to the architectures main() knows how to build, so a typo fails at
# parse time instead of after all datasets have been loaded.
parser.add_argument('--model', default='svhn',
                    choices=['svhn', 'svhn_res_attn'],
                    help='network architecture to train')


def num_to_target(n, size=NUM_SIZE):
    """Encode an integer as a fixed-length, left-aligned digit target.

    Digits are written most-significant first starting at position 0; the
    remaining positions are filled with 10, the padding class. ``n == -1``
    means "no number" and yields a target made entirely of padding.

    Examples with size=3: 5 -> [5, 10, 10], 123 -> [1, 2, 3],
    -1 -> [10, 10, 10].

    Args:
        n: integer in ``[0, 10**size)``, or -1 for "no number".
        size: number of digit positions in the target.

    Returns:
        A numpy integer array of length ``size``.

    Raises:
        ValueError: if ``n`` is below -1 or has more than ``size`` digits.
            Raising avoids silently dropping leading digits, which would
            mislabel the training example.
    """
    target = np.full(size, 10)
    if n == -1:
        return target
    if n < -1 or n >= 10 ** size:
        raise ValueError(
            "num_to_target: {!r} does not fit in {} digit positions".format(n, size))

    # Peel digits least-significant first; always emit at least one digit so
    # that 0 encodes as [0, 10, ...].
    digits = []
    remaining = n
    while True:
        remaining, digit = divmod(remaining, 10)
        digits.append(digit)
        if not remaining:
            break

    target[:len(digits)] = digits[::-1]
    return target


def get_num(digits_net, num_digits=None):
    """Decode per-position logits into predicted class indices.

    Args:
        digits_net: indexable sequence of logit tensors, one per digit
            position. Each is reduced over dim 1, so it is assumed to be
            (batch, num_classes) — TODO confirm against the model heads.
        num_digits: number of leading positions to decode. Defaults to
            NUM_SIZE, resolved at call time.

    Returns:
        A tensor of shape (batch, num_digits) holding the argmax class of
        each position.
    """
    if num_digits is None:
        num_digits = NUM_SIZE
    # torch.max(..., dim=1) returns (values, indices); keep the indices.
    return torch.stack(
        [torch.max(digits_net[i], dim=1)[1] for i in range(num_digits)],
        dim=1)


def evaluate(name, model, dataset, output_dir, args, writer):
    """Compute whole-number accuracy of ``model`` on ``dataset``.

    A sample counts as correct only when all NUM_SIZE positions decoded by
    ``get_num`` equal the corresponding label entries.

    Args:
        name: label used in log messages (e.g. 'training', 'validation').
        model: network whose output ``get_num`` can decode.
        dataset: dataset yielding ``(images, labels)`` pairs.
        output_dir: unused; kept for interface compatibility.
        args: parsed CLI namespace; ``batch_size`` and ``workers`` are read.
        writer: unused (confusion-matrix logging is disabled).

    Returns:
        Accuracy as a percentage in [0, 100], or 0.0 when ``dataset`` is empty.
    """
    log.info("Evaluating model with {} set".format(name))
    dataset_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, num_workers=args.workers,
        pin_memory=torch.cuda.is_available())

    all_preds = []
    all_labels = []

    model.eval()
    try:
        # Inference only: skip building autograd graphs to save memory/time.
        with torch.no_grad():
            for images, labels in tqdm.tqdm(dataset_loader):
                if torch.cuda.is_available():
                    images = images.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)

                preds = get_num(model(images))

                all_preds.append(preds.cpu().numpy().astype(np.uint8))
                all_labels.append(labels.cpu().numpy().astype(np.uint8))
    except Exception:
        log.exception("Error processing data")
        raise

    # np.concatenate rejects an empty list, and the ratio below would divide by 0.
    if not all_preds:
        log.info("No data in {} set; reporting 0% accuracy".format(name))
        return 0.0

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Report the number of samples, not the number of batches.
    log.info("Total data: {}".format(len(all_labels)))

    # TODO: confusion-matrix logging to ``writer`` is currently disabled.

    fully_correct = np.sum(
        np.sum(all_preds == all_labels, axis=1) == NUM_SIZE)
    accuracy_perc = float(100 * (fully_correct / len(all_labels)))

    log.info('Accuracy: {:.2f}%'.format(accuracy_perc))

    return accuracy_perc


# Hard-coded data/weight locations used by main().
# TODO: promote these to CLI flags; they are specific to one workstation.
_SVHN_H5_PATH = '/home/zihao/Data/SVHN_multi_3_grey.h5'
_CIFAR_ROOT = '/home/zihao/Data/cifar/'
_PRETRAINED_WEIGHTS = 'data/apex-kill-pretrained.pt'


def _blank_target(_):
    """Return the "no number" target: every digit position is padding class 10."""
    return torch.LongTensor([10] * NUM_SIZE)


def _record_to_target(record):
    """Map a Mongo record to a digit-target tensor.

    Records labelled 'yes' encode the integer in ``record['extra']``; any
    other label maps to the all-padding target.
    """
    num = int(record['extra']) if record['label'] == 'yes' else -1
    return torch.LongTensor(num_to_target(num))


def _to_grey(img):
    """Convert a PIL image to single-channel greyscale ('L' mode)."""
    return img.convert('L')


def _load_svhn_sets(path):
    """Load the SVHN train/valid/test splits from an HDF5 file.

    Images are transposed from (N, H, W, C) to (N, C, H, W) and scaled by
    1/255. Returns three TensorDatasets in (train, val, test) order.
    """
    splits = []
    # ``with`` guarantees the file is closed even if a read fails.
    with h5py.File(path, 'r') as h5f:
        for prefix in ('train', 'valid', 'test'):
            images = h5f['{}_dataset'.format(prefix)][:].transpose((0, 3, 1, 2)) / 255.0
            labels = h5f['{}_labels'.format(prefix)][:]
            splits.append(torch.utils.data.TensorDataset(
                torch.FloatTensor(images), torch.LongTensor(labels)))
    return tuple(splits)


def _build_fake_set(size):
    """Build ``size`` random-noise 32x32 images, all labelled "no number"."""
    return torchvision.datasets.FakeData(
        size=size, image_size=(32, 32),
        transform=transforms.ToTensor(),
        target_transform=_blank_target)


def _build_cifar_sets(root):
    """Build greyscale CIFAR-10 negatives, all labelled "no number".

    Returns (train, val, test). The official training set is split 90/10
    into train/val.
    NOTE(review): the val subset shares the training transform, so it also
    gets random crops/flips.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop((32, 32), padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        _to_grey,
        transforms.ToTensor(),
    ])
    cifar_set = torchvision.datasets.CIFAR10(
        root, train=True,
        transform=train_transform,
        target_transform=_blank_target, download=True)
    split = int(len(cifar_set) * 0.9)
    indices = list(range(len(cifar_set)))
    train_set = torch.utils.data.Subset(cifar_set, indices[:split])
    val_set = torch.utils.data.Subset(cifar_set, indices[split:])
    test_set = torchvision.datasets.CIFAR10(
        root, train=False,
        transform=transforms.Compose([_to_grey, transforms.ToTensor()]),
        target_transform=_blank_target)
    return train_set, val_set, test_set


def _build_yeller_transform(name):
    """Image pipeline for Mongo frames.

    Steps: crop the NAME_TO_CROP[name] region, resize to 32x32, apply
    ``augmenter``, convert BGR to greyscale, and produce a (1, 32, 32) tensor.
    NOTE(review): the same pipeline (including ``augmenter``) is used for the
    train, val and test sets.
    """
    return transforms.Compose([
        # Crop lookup happens per image, so an unknown ``name`` only fails
        # once a frame is actually loaded.
        lambda img: ImageUtils.crop(img, NAME_TO_CROP[name]),
        lambda img: cv2.resize(img, (32, 32)),
        augmenter,
        lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2GRAY),
        lambda img: np.reshape(img, (32, 32, 1)),
        transforms.ToTensor(),
    ])


def _build_yeller_set(collection, split_type, transform):
    """Build the Mongo-backed 'yes'/'no' labelled set for one split."""
    return MongoCollection(
        collection=collection,
        query={"label": {"$in": ['no', 'yes']}, "split_type": split_type},
        transform=transform,
        target_transform=_record_to_target,
    )


def _make_loader(dataset, args, shuffle):
    """Build a DataLoader using the CLI batch size and worker count."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, num_workers=args.workers,
        shuffle=shuffle, pin_memory=torch.cuda.is_available())


def _build_model(model_name):
    """Instantiate the network selected by --model.

    Raises:
        ValueError: for a model name this script does not know.
    """
    if model_name == 'svhn':
        return SVHNNet()
    if model_name == 'svhn_res_attn':
        return SVHNResidualAttention()
    raise ValueError(
        "Unknown --model {!r}; expected 'svhn' or 'svhn_res_attn'".format(model_name))


def _criterion(digits_net, digits_in):
    """Sum the per-position cross-entropy losses over the NUM_SIZE digit positions."""
    return sum(
        torch.nn.functional.cross_entropy(digits_net[i], digits_in[:, i])
        for i in range(NUM_SIZE))


def _count_correct(preds, labels):
    """Count the samples whose every predicted position matches its label."""
    return torch.sum(torch.sum(preds == labels, dim=1) == NUM_SIZE)


def _train_epoch(model, loader, optimizer, writer, step, data_time, train_time):
    """Run one optimisation pass over ``loader``.

    Per-step loss/accuracy/timing scalars are written to ``writer``, and the
    ``data_time``/``train_time`` meters are updated in place.

    Returns:
        (mean loss, mean accuracy, next global training step)
    """
    model.train()
    running_loss = AverageMeter()
    running_acc = AverageMeter()

    initial_time = time.time()
    try:
        for images, labels in tqdm.tqdm(loader):
            loaded_time = time.time()

            if torch.cuda.is_available():
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            outputs = model(images)
            preds = get_num(outputs)
            loss = _criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss.update(loss.item(), batch_size)
            running_acc.update(
                _count_correct(preds, labels).item() / float(batch_size), batch_size)

            trained_time = time.time()
            data_time.update(loaded_time - initial_time)
            train_time.update(trained_time - loaded_time)
            initial_time = time.time()

            writer.add_scalar("training/data_time", data_time.val, step)
            writer.add_scalar("training/train_time", train_time.val, step)
            writer.add_scalar("training/step_loss", running_loss.val, step)
            writer.add_scalar("training/step_acc", running_acc.val, step)
            step += 1
    except Exception:
        log.exception("Error processing data")
        raise

    return running_loss.avg, running_acc.avg, step


def _validate_epoch(model, loader, writer, step):
    """Evaluate loss/accuracy over ``loader`` without updating weights.

    Returns:
        (mean loss, mean accuracy, next global validation step)
    """
    model.eval()
    running_loss = AverageMeter()
    running_acc = AverageMeter()

    try:
        # No backward pass here, so don't keep autograd graphs alive.
        with torch.no_grad():
            for images, labels in tqdm.tqdm(loader):
                if torch.cuda.is_available():
                    images = images.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)

                outputs = model(images)
                preds = get_num(outputs)
                loss = _criterion(outputs, labels)

                batch_size = images.size(0)
                running_loss.update(loss.item(), batch_size)
                running_acc.update(
                    _count_correct(preds, labels).item() / float(batch_size), batch_size)

                writer.add_scalar("validation/step_loss", running_loss.val, step)
                writer.add_scalar("validation/step_acc", running_acc.val, step)
                step += 1
    except Exception:
        log.exception("Error processing data")
        raise

    return running_loss.avg, running_acc.avg, step


def _train_and_evaluate(args, job_id, output_dir, writer):
    """Build the data and model, train, keep the best checkpoint, and report accuracies."""
    name = args.name

    # The SVHN/CIFAR train/val sets are built but can be toggled in or out of
    # the mixes below by (un)commenting them.
    svhn_train_set, svhn_val_set, svhn_test_set = _load_svhn_sets(_SVHN_H5_PATH)
    cifar_train_set, cifar_val_set, cifar_test_set = _build_cifar_sets(_CIFAR_ROOT)

    yeller_transform = _build_yeller_transform(name)
    yeller_train_set = _build_yeller_set(args.dataset, 'train', yeller_transform)
    yeller_val_set = _build_yeller_set(args.dataset, 'val', yeller_transform)
    yeller_test_set = _build_yeller_set(args.dataset, 'test', yeller_transform)

    train_set = torch.utils.data.ConcatDataset([
        # svhn_train_set,
        _build_fake_set(200),
        # cifar_train_set,
        yeller_train_set,
    ])
    val_set = torch.utils.data.ConcatDataset([
        # svhn_val_set,
        _build_fake_set(2000),
        cifar_val_set,
        yeller_val_set,
    ])
    test_set = torch.utils.data.ConcatDataset([
        svhn_test_set,
        _build_fake_set(2000),
        cifar_test_set,
        yeller_test_set,
    ])

    train_loader = _make_loader(train_set, args, shuffle=True)
    val_loader = _make_loader(val_set, args, shuffle=False)

    net = _build_model(args.model)
    writer.add_graph(net, torch.stack([train_set[i][0] for i in range(10)]))

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True)

    log.info("--------------------------------")
    log.info("pytorch version: {}".format(torch.__version__))
    log.info("pytorch cuda: {}".format(torch.cuda.is_available()))
    log.info("Using Model: {}".format(name))
    log.info("output directory: {}".format(output_dir))
    log.info("train dataset: {} images".format(len(train_set)))
    log.info("val dataset: {} images".format(len(val_set)))
    log.info("test dataset: {} images".format(len(test_set)))
    log.info("total epochs {}".format(args.epochs))
    log.info("--------------------------------------------")

    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
    # load_state_dict copies the values into the module's own tensors.
    net.load_state_dict(torch.load(_PRETRAINED_WEIGHTS, map_location='cpu'))
    model = torch.nn.DataParallel(net)
    if torch.cuda.is_available():
        model = model.cuda()

    log.info("created model: {}".format(model))

    train_step = 0
    val_step = 0
    best_accuracy = 0
    # Start at +inf so the first epoch is always checkpointed, even at 0% accuracy.
    best_loss = math.inf
    best_epoch = None
    weights_file = os.path.join(output_dir, 'weights.pt')

    data_time = AverageMeter()
    train_time = AverageMeter()

    for epoch in range(args.epochs):
        writer.add_scalar("training/learning_rate",
                          optimizer.param_groups[0]['lr'], epoch)

        epoch_loss, epoch_acc, train_step = _train_epoch(
            model, train_loader, optimizer, writer, train_step, data_time, train_time)

        log.info('Epoch {} Training: Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, epoch_loss, epoch_acc))

        output_file = os.path.join(output_dir, "epoch_{}.pt".format(epoch))
        log.info("Saving weights: {}".format(output_file))
        torch.save(net.state_dict(), output_file)

        writer.add_scalar("training/loss", epoch_loss, epoch)
        writer.add_scalar("training/accuracy", epoch_acc, epoch)

        epoch_loss, epoch_acc, val_step = _validate_epoch(
            model, val_loader, writer, val_step)

        scheduler.step(epoch_loss)

        log.info('Epoch {} Validation: Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, epoch_loss, epoch_acc))

        if epoch_acc > best_accuracy or (epoch_acc == best_accuracy and epoch_loss < best_loss):
            best_accuracy = epoch_acc
            best_loss = epoch_loss
            best_epoch = epoch
            log.info("Saving model with best validation accuracy {} {}".format(
                epoch_acc, epoch_loss))
            torch.save(net.state_dict(), weights_file)

        writer.add_scalar("validation/loss", epoch_loss, epoch)
        writer.add_scalar("validation/accuracy", epoch_acc, epoch)

    log.info("Training finished! Best Epoch: {} Acc: {}".format(
        best_epoch, best_accuracy))

    model.eval()
    if best_epoch is not None:
        net.load_state_dict(torch.load(weights_file, map_location='cpu'))
    else:
        # e.g. --epochs 0: no checkpoint exists, so evaluate the current weights.
        log.info("No best checkpoint was saved; evaluating current weights")

    train_acc = evaluate(
        'training', model, yeller_train_set, output_dir, args, writer)
    val_acc = evaluate(
        'validation', model, yeller_val_set, output_dir, args, writer)
    test_acc = evaluate(
        'test', model, yeller_test_set, output_dir, args, writer)

    with open(os.path.join(output_dir, 'output.json'), 'w') as f:
        json.dump({
            'train_acc': train_acc,
            'val_acc': val_acc,
            'test_acc': test_acc
        }, f)

    log.info("Run: 'python scripts/deploy_detector.py --name {} --job_id {}' to deploy"
             .format(name, job_id))


def main(args):
    """Train a digit recogniser and write checkpoints plus metrics to the model dir.

    Outputs go to ``<model_dir>/<job_id>``: per-epoch ``epoch_N.pt``, the
    best-validation ``weights.pt``, ``output.json`` with accuracies, and
    TensorBoard event files. ``job_id`` defaults to ``<name>_<timestamp>``.
    """
    job_id = args.job_id or '{}_{}'.format(
        args.name, datetime.now().strftime('%Y%m%d_%H%M%S'))

    output_dir = os.path.join(args.model_dir, job_id)
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(output_dir)
    try:
        _train_and_evaluate(args, job_id, output_dir, writer)
    finally:
        # Close the event writer even if training fails part-way.
        writer.close()


if __name__ == '__main__':
    # Script entry point: parse sys.argv with the module-level parser and train.
    cli_args = parser.parse_args()
    main(cli_args)
