
from bran.train.model_eval import evaluate_model
from datetime import datetime
from shutil import copyfile
import argparse
import functools
import json
import logging
import math
import multiprocessing
import os
import sys
import time
import traceback
import uuid
from bran.shared.logger import log
from bran.train import AverageMeter
from bran.train.mongo_dataset import MongoDataset, MalformedError
from bran.train.trainer import augmenter
from imgaug import augmenters as iaa
from tensorboardX import SummaryWriter
import boto3
import bran.shared.image_utils as ImageUtils
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import tqdm
from bran.shared.config import config, get_config
# Import-time configuration for this training entry point: tag log output
# with the 'bran-train' syslog identity, then copy the MONGO_DB setting of the
# 'prod' config into the active config (presumably the database MongoDataset
# reads from -- confirm in bran.shared.config / mongo_dataset).
config.SYSLOG_IDENT = 'bran-train'
_prod_config = get_config('prod')
config.MONGO_DB = _prod_config.MONGO_DB
del _prod_config


# Crop region applied to every frame before resizing, keyed by --name.
# NOTE(review): the values all lie in [0, 1] and look like fractions of the
# frame size -- confirm against ImageUtils.crop. Names that are not in this
# table are still accepted; main() then passes crop_dict=None to the crop step.
NAME_TO_CROP = {
    "kill": {"x": 0.3, "y": 0.6, "width": 0.4, "height": 0.25},
    "victory": {"x": 0.21875, "y": 0, "width": 0.5625, "height": 0.5},
    "gamestate": {"x": 0.5, "y": 0.5, "width": 0.5, "height": 0.5},
    "apex-victory": {"x": 0.25, "y": 0.25, "width": 0.5, "height": 0.5}
}

parser = argparse.ArgumentParser(description="train")
# Help is derived from NAME_TO_CROP so it cannot drift out of sync with the
# table (it previously omitted "apex-victory").
parser.add_argument("--name", action="store", required=True,
                    help="model name; crop presets exist for: {}".format(
                        "|".join(NAME_TO_CROP)))
parser.add_argument("--job_id", action="store",
                    help="output sub-directory name under --model_dir "
                         "(default: <name>_<YYYYmmdd_HHMMSS>)")
parser.add_argument("--dataset", required=True,
                    help="Mongo collection to train on, e.g. "
                         "Fortnite_kill|Fortnite_victory|Fortnite_gamestate")
parser.add_argument("--labels", nargs="+", required=True,
                    help="space separated list of labels")
parser.add_argument("--model_dir", action="store",
                    default="/var/bran/gamesense/models",
                    help="root directory for job outputs (default: %(default)s)")
parser.add_argument("--epochs", action="store", type=int, default=30,
                    help="number of training epochs (default: %(default)s)")
parser.add_argument('--workers', action='store', type=int,
                    default=multiprocessing.cpu_count(),
                    help="DataLoader worker processes (default: CPU count)")
parser.add_argument("--batch_size", action="store", type=int, default=128,
                    help="mini-batch size (default: %(default)s)")
parser.add_argument('--lr', '--learning_rate', default=3e-4, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')


# Standard torchvision ImageNet normalisation statistics (RGB channel order),
# matching the ImageNet-pretrained ResNet-18 backbone used in main().
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Side length in pixels that every frame is resized to before the network.
_INPUT_SIZE = 224


def _build_transform(crop_dict, augment):
    """Build the per-image preprocessing pipeline.

    Crops with ``crop_dict`` via ``ImageUtils.crop``, resizes to
    ``_INPUT_SIZE`` x ``_INPUT_SIZE``, converts BGR to RGB, optionally applies
    ``augmenter`` and finally converts to a normalised tensor.

    Args:
        crop_dict: crop region from NAME_TO_CROP, or None for unknown names.
        augment: when True, run ``augmenter`` before tensor conversion
            (training only).

    Returns:
        A ``transforms.Compose`` callable.
    """
    steps = [
        lambda img: ImageUtils.crop(img, crop_dict),
        lambda img: cv2.resize(img, (_INPUT_SIZE, _INPUT_SIZE)),
        lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
    ]
    if augment:
        steps.append(augmenter)
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD))
    return transforms.Compose(steps)


def _make_loader(data, args, shuffle, pin_memory):
    """Wrap ``data`` in a DataLoader using the CLI batch size and worker count."""
    return torch.utils.data.DataLoader(
        data, batch_size=args.batch_size, num_workers=args.workers,
        shuffle=shuffle, pin_memory=pin_memory)


def _set_trainable(*modules):
    """Set ``requires_grad = True`` on every parameter of each of ``modules``."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad = True


def _apply_unfreeze_schedule(net, epoch):
    """Gradually unfreeze the ResNet backbone as training progresses.

    main() starts with every backbone parameter frozen and only the freshly
    created ``fc`` head trainable. ``layer4`` joins at epoch 5,
    ``layer1``-``layer3`` join at epoch 10, and every parameter is trainable
    from epoch 15.
    """
    if epoch == 5:
        log.info("UNFREEZING TOP WEIGHT")
        _set_trainable(net.layer4, net.avgpool)
    elif epoch == 10:
        log.info("UNFREEZING MORE WEIGHT")
        _set_trainable(net.layer3, net.layer2, net.layer1)
    elif epoch == 15:
        log.info("UNFREEZING ALL WEIGHT")
        _set_trainable(net)


def _train_epoch(model, loader, criterion, optimizer, writer, use_cuda,
                 step, data_time, train_time):
    """Run one optimisation pass over ``loader``.

    Writes per-step data-loading time, compute time, loss and accuracy to
    ``writer`` under ``training/*``, starting at global step ``step``.
    ``data_time`` and ``train_time`` are AverageMeters shared across epochs.

    Returns:
        ``(loss_avg, acc_avg, next_step)``, where the first two are the
        ``AverageMeter.avg`` values for this epoch.

    Raises:
        Any exception from loading or training; it is logged first.
    """
    running_loss = AverageMeter()
    running_acc = AverageMeter()
    try:
        initial_time = time.time()
        for images, targets, _ in tqdm.tqdm(loader):
            loaded_time = time.time()

            if use_cuda:
                images = images.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)

            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss.update(loss.item(), batch_size)
            running_acc.update(
                torch.sum(preds == targets).item() / float(batch_size),
                batch_size)

            trained_time = time.time()
            data_time.update(loaded_time - initial_time)
            train_time.update(trained_time - loaded_time)
            initial_time = time.time()

            writer.add_scalar("training/data_time", data_time.val, step)
            writer.add_scalar("training/train_time", train_time.val, step)
            writer.add_scalar("training/step_loss", running_loss.val, step)
            writer.add_scalar("training/step_acc", running_acc.val, step)
            step += 1
    except Exception:
        log.exception("Error processing data")
        raise
    return running_loss.avg, running_acc.avg, step


def _validate_epoch(model, loader, criterion, writer, use_cuda, step):
    """Score the model on ``loader`` without building autograd graphs.

    Writes per-step loss and accuracy to ``writer`` under ``validation/*``,
    starting at global step ``step``.

    Returns:
        ``(loss_avg, acc_avg, next_step)``.

    Raises:
        Any exception from loading or inference; it is logged first.
    """
    running_loss = AverageMeter()
    running_acc = AverageMeter()
    try:
        # no_grad: evaluation needs no gradients, and tracking them here only
        # wasted memory and time.
        with torch.no_grad():
            for images, targets, _ in tqdm.tqdm(loader):
                if use_cuda:
                    images = images.cuda(non_blocking=True)
                    targets = targets.cuda(non_blocking=True)

                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, targets)

                batch_size = images.size(0)
                running_loss.update(loss.item(), batch_size)
                running_acc.update(
                    torch.sum(preds == targets).item() / float(batch_size),
                    batch_size)

                writer.add_scalar("validation/step_loss",
                                  running_loss.val, step)
                writer.add_scalar("validation/step_acc",
                                  running_acc.val, step)
                step += 1
    except Exception:
        log.exception("Error processing data")
        raise
    return running_loss.avg, running_acc.avg, step


def main(args):
    """Fine-tune an ImageNet-pretrained ResNet-18 classifier and evaluate it.

    Trains on the ``train`` split of the Mongo collection ``args.dataset``.
    Each epoch's weights are saved to ``<model_dir>/<job_id>/epoch_<n>.pt``,
    and the best validation epoch (highest accuracy, ties broken by lower
    loss) is kept in ``weights.pt``. That checkpoint is then evaluated on the
    train/val/test splits, and the accuracies are written to ``output.json``.

    Args:
        args: parsed command-line namespace (see ``parser``).

    Raises:
        RuntimeError: if no epoch produced a validation checkpoint (e.g.
            ``--epochs 0``), so there is nothing to evaluate.
    """
    name = args.name
    job_id = args.job_id or '{}_{}'.format(
        name, datetime.now().strftime('%Y%m%d_%H%M%S'))
    dataset = args.dataset
    labels = args.labels
    use_cuda = torch.cuda.is_available()

    output_dir = os.path.join(args.model_dir, job_id)
    os.makedirs(output_dir, exist_ok=True)
    best_weights_file = os.path.join(output_dir, 'weights.pt')

    writer = SummaryWriter(output_dir)
    try:
        crop_dict = NAME_TO_CROP.get(name, None)
        train_transform = _build_transform(crop_dict, augment=True)
        test_transform = _build_transform(crop_dict, augment=False)

        train_set = MongoDataset(
            collection=dataset,
            split_type="train",
            labels=labels,
            stretch=True,
            transform=train_transform
        )

        # Inverse of get_class_weight(), scaled by 1e6, is used as the
        # per-class CrossEntropyLoss weight.
        class_weight = torch.Tensor(
            1e6 / np.array(train_set.get_class_weight()))
        if use_cuda:
            class_weight = class_weight.cuda()

        val_set = MongoDataset(
            collection=dataset,
            split_type="val",
            labels=labels,
            transform=test_transform
        )
        test_set = MongoDataset(
            collection=dataset,
            split_type="test",
            labels=labels,
            transform=test_transform
        )

        train_loader = _make_loader(train_set, args, True, use_cuda)
        val_loader = _make_loader(val_set, args, False, use_cuda)
        test_loader = _make_loader(test_set, args, False, use_cuda)

        net = torchvision.models.resnet18(pretrained=True)
        # Freeze the whole backbone; _apply_unfreeze_schedule releases it
        # in stages.
        for param in net.parameters():
            param.requires_grad = False
        # The replacement head is created after freezing, so it is trainable.
        net.fc = torch.nn.Linear(net.fc.in_features, len(labels))

        # Trace the graph on up to 10 real samples; fewer if the split is tiny.
        sample_count = min(10, len(train_set))
        if sample_count:
            writer.add_graph(net, torch.stack(
                [train_set[i][0] for i in range(sample_count)]))

        criterion = torch.nn.CrossEntropyLoss(weight=class_weight)
        # Frozen params stay in the optimizer so they start updating once the
        # unfreeze schedule flips requires_grad; until then they have no grad.
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, verbose=True)

        log.info("--------------------------------")
        log.info("pytorch version: {}".format(torch.__version__))
        log.info("pytorch cuda: {}".format(use_cuda))
        log.info("Using Model: {}".format(name))
        log.info("Using Dataset: {}".format(dataset))
        log.info("labels: {}".format(labels))
        log.info("output directory: {}".format(output_dir))
        log.info("train dataset: {} images".format(len(train_set)))
        log.info("train weight: {}".format(class_weight))
        log.info("val dataset: {} images".format(len(val_set)))
        log.info("test dataset: {} images".format(len(test_set)))
        log.info("total epochs {}".format(args.epochs))
        log.info("--------------------------------------------")

        model = torch.nn.DataParallel(net)
        if use_cuda:
            model = model.cuda()

        log.info("created model: {}".format(model))

        train_step = 0
        val_step = 0
        # -inf guarantees the first completed epoch is checkpointed even if
        # its accuracy is 0 and its loss is not comparable (NaN).
        best_accuracy = float('-inf')
        best_loss = float('inf')
        best_epoch = None

        data_time = AverageMeter()
        train_time = AverageMeter()

        for epoch in range(args.epochs):
            model.train()

            writer.add_scalar("training/learning_rate",
                              optimizer.param_groups[0]['lr'], epoch)

            _apply_unfreeze_schedule(net, epoch)

            epoch_loss, epoch_acc, train_step = _train_epoch(
                model, train_loader, criterion, optimizer, writer, use_cuda,
                train_step, data_time, train_time)

            log.info('Epoch {} Training: Loss: {:.4f} Acc: {:.4f}'.format(
                epoch, epoch_loss, epoch_acc))

            output_file = os.path.join(output_dir, "epoch_{}.pt".format(epoch))
            log.info("Saving weights: {}".format(output_file))
            torch.save(net.state_dict(), output_file)

            writer.add_scalar("training/loss", epoch_loss, epoch)
            writer.add_scalar("training/accuracy", epoch_acc, epoch)

            model.eval()
            epoch_loss, epoch_acc, val_step = _validate_epoch(
                model, val_loader, criterion, writer, use_cuda, val_step)

            scheduler.step(epoch_loss)

            log.info('Epoch {} Validation: Loss: {:.4f} Acc: {:.4f}'.format(
                epoch, epoch_loss, epoch_acc))

            if epoch_acc > best_accuracy or (
                    epoch_acc == best_accuracy and epoch_loss < best_loss):
                best_accuracy = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                log.info("Saving model with best validation accuracy")
                torch.save(net.state_dict(), best_weights_file)

            writer.add_scalar("validation/loss", epoch_loss, epoch)
            writer.add_scalar("validation/accuracy", epoch_acc, epoch)

        if best_epoch is None:
            # Otherwise torch.load below would read a missing file or a stale
            # weights.pt left over from an earlier run with the same job_id.
            raise RuntimeError(
                "No validation checkpoint was saved (epochs={}); "
                "nothing to evaluate".format(args.epochs))

        log.info("Training finished! Best Epoch: {} Acc: {}".format(
            best_epoch, best_accuracy))

        model.eval()
        net.load_state_dict(torch.load(best_weights_file))

        # Training split again, without stretch and augmentation, for eval.
        train_set_eval = MongoDataset(
            collection=dataset,
            split_type="train",
            labels=labels,
            transform=test_transform
        )
        train_loader_eval = _make_loader(train_set_eval, args, False, use_cuda)

        train_acc = evaluate_model(
            'training', model, train_loader_eval, labels, output_dir, writer)
        val_acc = evaluate_model(
            'validation', model, val_loader, labels, output_dir, writer)
        test_acc = evaluate_model(
            'test', model, test_loader, labels, output_dir, writer)

        with open(os.path.join(output_dir, 'output.json'), 'w') as f:
            json.dump({
                'train_acc': train_acc,
                'val_acc': val_acc,
                'test_acc': test_acc
            }, f)

        log.info("Run: 'python scripts/deploy_detector.py --name {} --job_id {}' to deploy"
                 .format(name, job_id))
    finally:
        # Always flush TensorBoard events, including when training crashes.
        writer.close()


if __name__ == '__main__':
    # Script entry point: parse sys.argv with the module-level parser and
    # hand the namespace to main().
    main(args=parser.parse_args())
