
from bran.shared.config import config, get_config
# Mutate the shared `config` object right after importing it and before the
# other `bran` modules are imported below. NOTE(review): presumably some of
# those modules (e.g. bran.shared.logger) read these values at import time,
# which is why this sits between imports — confirm before reordering.
config.SYSLOG_IDENT = 'bran-train'
# Override MONGO_DB with the value from the 'prod' config.
config.MONGO_DB = get_config('prod').MONGO_DB

import tqdm
import torchvision.transforms as transforms
import torchvision
import torch
import numpy as np
import cv2
import bran.shared.image_utils as ImageUtils
import boto3
from tensorboardX import SummaryWriter
from imgaug import augmenters as iaa
from bran.train.trainer import augmenter
from bran.train.mongo_dataset import MongoDataset, MalformedError
from bran.train import AverageMeter
from bran.shared.logger import log
import uuid
import traceback
import time
import sys
import os
import multiprocessing
import math
import logging
import json
import functools
import argparse
from shutil import copyfile
from datetime import datetime
from bran.train.model_eval import evaluate_model

# Crop region per detector name, passed to ImageUtils.crop() by main().
# NOTE(review): the values are all in [0, 1] and look like fractions of the
# frame size (x/y origin plus width/height) — confirm against ImageUtils.crop.
NAME_TO_CROP = dict(
    kill=dict(x=0.3, y=0.6, width=0.4, height=0.25),
    victory=dict(x=0.21875, y=0, width=0.5625, height=0.5),
    gamestate=dict(x=0.5, y=0.5, width=0.5, height=0.5),
    gameinfo=dict(x=0.8, y=0, width=0.2, height=0.2),
)

# Command-line interface for the training script; parsed in the
# `if __name__ == '__main__'` guard and handed to main().
parser = argparse.ArgumentParser(description="train")
# Detector name; also selects the crop region from NAME_TO_CROP.
parser.add_argument("--name", action="store", required=True,
                    help="kill|victory|gamestate|gameinfo")
parser.add_argument("--job_id", action="store",
                    help="output sub-directory name; defaults to <name>_<timestamp>")
parser.add_argument("--labels", nargs="+", required=True,
                    help="space separated list of labels")
parser.add_argument("--model_dir", action="store",
                    default="/var/gamesense/models")
parser.add_argument("--epochs", action="store", type=int, default=30)
# Defaults to one DataLoader worker per CPU core.
parser.add_argument('--workers', action='store', type=int,
                    default=multiprocessing.cpu_count())
parser.add_argument("--batch_size", action="store", type=int, default=128)
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
                    help='momentum')


import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from bran.train.ptsemseg_utils import (
    get_interp_size,
    cascadeFeatureFusion,
    conv2DBatchNormRelu,
    residualBlockPSP,
    pyramidPooling,
)

def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy between segmentation logits and a label map.

    :param input: logits tensor of shape (N, C, H, W).
    :param target: integer class-index tensor of shape (N, Ht, Wt). Pixels
        labelled 250 are ignored.
    :param weight: optional per-class weight tensor (length C), forwarded to
        ``F.cross_entropy``.
    :param size_average: if True, return the mean over non-ignored pixels;
        otherwise return their sum.
    :return: scalar loss tensor.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    # Resize the logits (not the labels) to the label resolution. A mismatch in
    # either dimension is enough: the per-pixel flattening below requires both
    # spatial sizes to agree.
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # (N, C, H, W) -> (N*H*W, C): every pixel becomes one classification sample.
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.reshape(-1)
    # `size_average` is deprecated in F.cross_entropy; translate it to the
    # equivalent `reduction` mode.
    reduction = "mean" if size_average else "sum"
    return F.cross_entropy(
        input, target, weight=weight, reduction=reduction, ignore_index=250
    )

def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
    """Weighted sum of ``cross_entropy2d`` over one or several logit maps.

    A single tensor ``input`` is scored directly. A tuple of tensors (e.g. the
    training-mode output of a multi-branch network) is scored branch by branch
    and combined as ``sum_i scale_weight[i] * loss_i``.

    :param scale_weight: per-branch multipliers (indexable). When omitted,
        powers of 0.4 are used — 1, 0.4, 0.16, ... — which matches the
        auxiliary weighting of PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16].
        The default tensor is placed on ``target``'s device.
    """
    if isinstance(input, tuple):
        if scale_weight is None:
            branch_count = len(input)
            exponents = torch.arange(branch_count).float()
            scale_weight = torch.pow(0.4 * torch.ones(branch_count), exponents).to(
                target.device
            )

        weighted_terms = (
            scale_weight[branch_idx]
            * cross_entropy2d(
                input=branch_logits,
                target=target,
                weight=weight,
                size_average=size_average,
            )
            for branch_idx, branch_logits in enumerate(input)
        )
        # Start from 0.0 so the additions happen in the same order as a
        # running `total = total + term` loop.
        return sum(weighted_terms, 0.0)

    return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average)

class icnet(nn.Module):

    """
    Image Cascade Network
    URL: https://arxiv.org/abs/1704.08545
    References:
    1) Original Author's code: https://github.com/hszhao/ICNet
    2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet
    3) TensorFlow implementation by @hellochick: https://github.com/hellochick/ICNet-tensorflow
    """

    def __init__(
        self,
        n_classes=19,
        block_config=(3, 4, 6, 3),
        input_size=(1025, 2049),
        version=None,
        is_batchnorm=True,
    ):
        """Build the ICNet layers.

        :param n_classes: number of segmentation classes.
        :param block_config: number of residual units in each of the four
            residual stages (indexed 0..3 below).
        :param input_size: tile size used by ``tile_predict`` (first entry
            tiles the H axis, second the W axis).
        :param version: optional key into ``icnet_specs``; when given it
            overrides ``n_classes``, ``block_config`` and ``input_size``.
        :param is_batchnorm: use batch-norm in the conv blocks; convolutions
            get a bias only when batch-norm is disabled.
        """

        super(icnet, self).__init__()

        bias = not is_batchnorm

        if version is not None:
            # NOTE(review): `icnet_specs` is neither defined nor imported in
            # this file, so this branch raises NameError unless the name is
            # provided elsewhere — confirm before passing `version`.
            spec = icnet_specs[version]
            block_config = spec["block_config"]
            n_classes = spec["n_classes"]
            input_size = spec["input_size"]
        self.block_config = block_config
        self.n_classes = n_classes
        self.input_size = input_size

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(
            in_channels=3,
            k_size=3,
            n_filters=32,
            padding=1,
            stride=2,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )
        self.convbnrelu1_2 = conv2DBatchNormRelu(
            in_channels=32,
            k_size=3,
            n_filters=32,
            padding=1,
            stride=1,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )
        self.convbnrelu1_3 = conv2DBatchNormRelu(
            in_channels=32,
            k_size=3,
            n_filters=64,
            padding=1,
            stride=1,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )

        # Vanilla Residual Blocks
        self.res_block2 = residualBlockPSP(
            self.block_config[0], 64, 32, 128, 1, 1, is_batchnorm=is_batchnorm
        )
        self.res_block3_conv = residualBlockPSP(
            self.block_config[1],
            128,
            64,
            256,
            2,
            1,
            include_range="conv",
            is_batchnorm=is_batchnorm,
        )
        self.res_block3_identity = residualBlockPSP(
            self.block_config[1],
            128,
            64,
            256,
            2,
            1,
            include_range="identity",
            is_batchnorm=is_batchnorm,
        )

        # Dilated Residual Blocks
        self.res_block4 = residualBlockPSP(
            self.block_config[2], 256, 128, 512, 1, 2, is_batchnorm=is_batchnorm
        )
        self.res_block5 = residualBlockPSP(
            self.block_config[3], 512, 256, 1024, 1, 4, is_batchnorm=is_batchnorm
        )

        # Pyramid Pooling Module
        self.pyramid_pooling = pyramidPooling(
            1024, [6, 3, 2, 1], model_name="icnet", fusion_mode="sum", is_batchnorm=is_batchnorm
        )

        # Final conv layer with kernel 1 in sub4 branch
        self.conv5_4_k1 = conv2DBatchNormRelu(
            in_channels=1024,
            k_size=1,
            n_filters=256,
            padding=0,
            stride=1,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )

        # High-resolution (sub1) branch
        self.convbnrelu1_sub1 = conv2DBatchNormRelu(
            in_channels=3,
            k_size=3,
            n_filters=32,
            padding=1,
            stride=2,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )
        self.convbnrelu2_sub1 = conv2DBatchNormRelu(
            in_channels=32,
            k_size=3,
            n_filters=32,
            padding=1,
            stride=2,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )
        self.convbnrelu3_sub1 = conv2DBatchNormRelu(
            in_channels=32,
            k_size=3,
            n_filters=64,
            padding=1,
            stride=2,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )
        self.classification = nn.Conv2d(128, self.n_classes, 1, 1, 0)

        # Cascade Feature Fusion Units
        self.cff_sub24 = cascadeFeatureFusion(
            self.n_classes, 256, 256, 128, is_batchnorm=is_batchnorm
        )
        self.cff_sub12 = cascadeFeatureFusion(
            self.n_classes, 128, 64, 128, is_batchnorm=is_batchnorm
        )

        # Define auxiliary loss function
        self.loss = multi_scale_cross_entropy2d

    def forward(self, x):
        """Run the three-branch cascade on an image batch ``x`` (N, C, H, W).

        :return: in training mode, the tuple ``(sub124_cls, sub24_cls,
            sub4_cls)`` of class-logit maps, finest first (consumed by
            ``self.loss``). In eval mode, only ``sub124_cls``, further
            upsampled with ``get_interp_size(..., z_factor=4)``.
        """
        # H, W -> H/2, W/2
        x_sub2 = F.interpolate(
            x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True
        )

        # H/2, W/2 -> H/4, W/4
        x_sub2 = self.convbnrelu1_1(x_sub2)
        x_sub2 = self.convbnrelu1_2(x_sub2)
        x_sub2 = self.convbnrelu1_3(x_sub2)

        # H/4, W/4 -> H/8, W/8
        x_sub2 = F.max_pool2d(x_sub2, 3, 2, 1)

        # H/8, W/8 -> H/16, W/16
        x_sub2 = self.res_block2(x_sub2)
        x_sub2 = self.res_block3_conv(x_sub2)
        # H/16, W/16 -> H/32, W/32
        x_sub4 = F.interpolate(
            x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True
        )
        x_sub4 = self.res_block3_identity(x_sub4)

        x_sub4 = self.res_block4(x_sub4)
        x_sub4 = self.res_block5(x_sub4)

        x_sub4 = self.pyramid_pooling(x_sub4)
        x_sub4 = self.conv5_4_k1(x_sub4)

        # Full-resolution branch: three stride-2 convs straight from the input.
        x_sub1 = self.convbnrelu1_sub1(x)
        x_sub1 = self.convbnrelu2_sub1(x_sub1)
        x_sub1 = self.convbnrelu3_sub1(x_sub1)

        x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
        x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)

        x_sub12 = F.interpolate(
            x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
        )
        # (A stray, unused `self.res_block3_identity(x_sub4)` call used to sit
        # here; it only wasted compute and polluted BatchNorm running stats.)
        sub124_cls = self.classification(x_sub12)

        if self.training:
            return (sub124_cls, sub24_cls, sub4_cls)
        sub124_cls = F.interpolate(
            sub124_cls,
            size=get_interp_size(sub124_cls, z_factor=4),
            mode="bilinear",
            align_corners=True,
        )
        return sub124_cls

    def tile_predict(self, imgs, include_flip_mode=True):
        """Predict class probabilities by averaging over overlapping tiles.

        The image is covered with tiles of size ``self.input_size`` (first
        entry along H, second along W); strides are derived from the image
        size so neighbouring tiles overlap. Each pixel's probabilities are the
        mean over every tile covering it. ``self.forward`` must return a
        single logits tensor, i.e. the module should be in eval mode.

        :param imgs: torch.Tensor of shape [N, C, H, W]; the original
            docstring states BGR channel order.
        :param include_flip_mode: if True, also predict on each tile flipped
            along W and average it (un-flipped) with the plain prediction.
        :return: np.ndarray (float32) of shape [N, self.n_classes, H, W],
            normalised to sum to 1 over the class axis.
        """
        side_x, side_y = self.input_size
        n_classes = self.n_classes
        n_samples, _, h, w = imgs.shape
        device = next(self.parameters()).device

        def tile_bounds(length, side):
            n_tiles = int(length / float(side) + 1)
            stride = (length - side) / float(n_tiles)
            starts = [int(i * stride) for i in range(n_tiles)]
            # Pin the last tile to the far edge so float rounding of
            # i * stride can never leave the final rows/columns uncovered
            # (which would make `count` zero there).
            starts.append(max(length - side, 0))
            # Duplicate starts would only recompute an identical tile.
            return [(s, s + side) for s in sorted(set(starts))]

        pred = np.zeros([n_samples, n_classes, h, w])
        count = np.zeros([h, w])

        with torch.no_grad():
            for sx, ex in tile_bounds(h, side_x):
                for sy, ey in tile_bounds(w, side_y):
                    tile = imgs[:, :, sx:ex, sy:ey]
                    psub = F.softmax(self.forward(tile.to(device)), dim=1).cpu().numpy()
                    if include_flip_mode:
                        flipped = torch.flip(tile, dims=[3]).float()
                        pflip = F.softmax(self.forward(flipped.to(device)), dim=1).cpu().numpy()
                        psub = (psub + pflip[:, :, :, ::-1]) / 2.0

                    # Accumulate (not overwrite) so dividing by `count` below
                    # yields the mean over all overlapping tiles.
                    pred[:, :, sx:ex, sy:ey] += psub
                    count[sx:ex, sy:ey] += 1.0

        score = (pred / count[None, None, ...]).astype(np.float32)
        return score / np.expand_dims(score.sum(axis=1), axis=1)

def main(args):
    """Train the ICNet segmentation model for the detector selected by ``args``.

    Images are read with ``torchvision.datasets.ImageFolder`` from the relative
    directory ``images``. Every image is paired with the same fixed 4-class
    mask built below. The output directory ``<args.model_dir>/<job_id>``
    receives:

    - per-epoch checkpoints (``epoch_<n>.pt``);
    - the checkpoint with the best training accuracy (``weights.pt``);
    - TensorBoard summaries;
    - ``output.json``.

    :param args: namespace produced by the module-level ``parser``.
    """
    name = args.name
    job_id = args.job_id or '{}_{}'.format(
        name, datetime.now().strftime('%Y%m%d_%H%M%S'))

    output_dir = os.path.join(args.model_dir, job_id)
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(output_dir)

    crop_dict = NAME_TO_CROP.get(name, None)

    train_transform = transforms.Compose([
        lambda img: ImageUtils.crop(img, crop_dict),
        lambda img: cv2.resize(img, (225, 225)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    # Hand-drawn 225x225 label mask shared by every training image. Boxes are
    # [x_min, y_min, x_max, y_max] in pixels (max exclusive); pixels outside
    # all boxes are class 0. Later boxes overwrite earlier ones on overlap.
    n_classes = 4
    mapdata = np.zeros((225, 225))
    mapbox = [43, 22, 191, 171]
    timebox = [70, 175, 104, 192]
    killbox = [125, 175, 145, 192]
    for class_index, (x0, y0, x1, y1) in enumerate((mapbox, timebox, killbox), start=1):
        mapdata[y0:y1, x0:x1] = class_index

    def target_transform(_folder_class_index):
        # The ImageFolder class index is ignored: every image gets the mask.
        return torch.from_numpy(mapdata.astype(np.int64))

    train_set = torchvision.datasets.ImageFolder(
        'images', transform=train_transform,
        target_transform=target_transform, loader=cv2.imread)

    # Inverse pixel-frequency class weights, relative to the most frequent
    # class. A class with no pixels in the mask would get an infinite weight.
    pixel_counts = np.array(
        [np.sum(mapdata == i) for i in range(n_classes)], dtype=np.float64)
    class_weight = torch.Tensor(pixel_counts.max() / pixel_counts)
    if torch.cuda.is_available():
        class_weight = class_weight.cuda()

    # TODO: validation/test splits are not wired up; only training data is used.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=args.workers,
        shuffle=True, pin_memory=torch.cuda.is_available())

    net = icnet(n_classes=n_classes, input_size=(225, 225))

    criterion = net.loss
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.lr, momentum=args.momentum)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True)

    log.info("--------------------------------")
    log.info("pytorch version: {}".format(torch.__version__))
    log.info("pytorch cuda: {}".format(torch.cuda.is_available()))
    log.info("Using Model: {}".format(name))
    log.info("labels: {}".format(args.labels))
    log.info("output directory: {}".format(output_dir))
    log.info("train dataset: {} images".format(len(train_set)))
    log.info("train weight: {}".format(class_weight))
    log.info("total epochs {}".format(args.epochs))
    log.info("--------------------------------------------")

    model = torch.nn.DataParallel(net)
    if torch.cuda.is_available():
        model = model.cuda()

    train_step = 0
    best_accuracy = 0
    best_epoch = 0

    data_time = AverageMeter()
    train_time = AverageMeter()

    for epoch in range(args.epochs):
        model.train()

        writer.add_scalar("training/learning_rate",
                          optimizer.param_groups[0]['lr'], epoch)

        running_loss = AverageMeter()
        running_acc = AverageMeter()

        initial_time = time.time()
        try:
            for images, labels in tqdm.tqdm(train_loader):
                loaded_time = time.time()

                if torch.cuda.is_available():
                    images = images.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)

                # In training mode the model returns a tuple of logit maps;
                # the first one is upsampled x4 to score pixel accuracy.
                outputs = model(images)
                outputs_pred = F.interpolate(
                    outputs[0],
                    size=get_interp_size(outputs[0], z_factor=4),
                    mode="bilinear",
                    align_corners=True,
                )
                _, preds = torch.max(outputs_pred, 1)
                loss = criterion(outputs, labels, weight=class_weight)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                n_pixels = labels.numel()
                running_loss.update(loss.item(), images.size(0))
                running_acc.update(
                    torch.sum(preds == labels).item() / float(n_pixels), n_pixels)

                trained_time = time.time()
                data_time.update(loaded_time - initial_time)
                train_time.update(trained_time - loaded_time)
                initial_time = time.time()

                writer.add_scalar("training/data_time",
                                  data_time.val, train_step)
                writer.add_scalar("training/train_time",
                                  train_time.val, train_step)
                writer.add_scalar("training/step_loss",
                                  running_loss.val, train_step)
                writer.add_scalar("training/step_acc",
                                  running_acc.val, train_step)
                train_step += 1
        except Exception:
            # KeyboardInterrupt is not an Exception subclass and propagates untouched.
            log.exception("Error processing data")
            raise

        epoch_loss = running_loss.avg
        epoch_acc = running_acc.avg

        log.info('Epoch {} Training: Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, epoch_loss, epoch_acc))

        output_file = os.path.join(output_dir, "epoch_{}.pt".format(epoch))
        log.info("Saving weights: {}".format(output_file))
        torch.save(net.state_dict(), output_file)

        writer.add_scalar("training/loss", epoch_loss, epoch)
        writer.add_scalar("training/accuracy", epoch_acc, epoch)

        # No validation split yet, so the plateau scheduler is driven by the
        # training loss (previously it was created but never stepped).
        scheduler.step(epoch_loss)

        if epoch_acc > best_accuracy:
            best_accuracy = epoch_acc
            best_epoch = epoch
            log.info("Saving model with best training accuracy")
            torch.save(net.state_dict(), os.path.join(output_dir, "weights.pt"))

    log.info("Training finished! Best Epoch: {} Acc: {}".format(
        best_epoch, best_accuracy))

    # weights.pt only exists if some epoch beat the initial accuracy of 0.
    best_weights = os.path.join(output_dir, "weights.pt")
    if os.path.exists(best_weights):
        net.load_state_dict(torch.load(best_weights))
    model.eval()

    # TODO: per-split evaluation (evaluate_model) is not wired up for this
    # segmentation model; report the best training accuracy and leave the
    # validation/test slots empty instead of referencing undefined names.
    with open(os.path.join(output_dir, 'output.json'), 'w') as f:
        json.dump({
            'train_acc': best_accuracy,
            'val_acc': None,
            'test_acc': None
        }, f)

    log.info("Run: 'python scripts/deploy_detector.py --name {} --job_id {}' to deploy"
             .format(name, job_id))

    writer.close()

if __name__ == '__main__':
    # Script entry point: parse the CLI flags and start training.
    main(args=parser.parse_args())
