import glob
import json
import logging
import os
import time
import argparse
import pprint
import tempfile

import cv2
import boto3
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import tqdm

from bran.shared.config import config, get_config
# Mutate the shared config before importing bran.shared.logger (and the other
# bran modules) below.
# NOTE(review): presumably those modules read SYSLOG_IDENT / MONGO_DB at import
# time, which is why these assignments sit between imports — confirm before
# reordering.
config.SYSLOG_IDENT = 'bran-train'
# Take the MONGO_DB value from the 'prod' config, regardless of the active one.
config.MONGO_DB = get_config('prod').MONGO_DB

from bran.shared.logger import log
from bran.train.mongo_dataset import MongoDataset, MalformedError
from bran.train.nets import model_from_config
from bran.train import get_best_weight_file
import bran.shared.image_utils as ImageUtils

# Command-line interface. --name must be one of the model names that
# ModelEvaluator knows how to build; restricting it here turns a later
# AttributeError into a clear usage error.
parser = argparse.ArgumentParser(
    description="Evaluate a trained classifier on the 'classified' and 'validation' databases.")
parser.add_argument("--name", action="store", required=True,
                    choices=["kill", "gamestate"],
                    help="model to evaluate")
parser.add_argument("--job_id", action="store",
                    help="training job id (e.g. kill_1234); the best weight file is looked up "
                         "under models/<job_id> when --model_weights is not given")
parser.add_argument("--model_weights", action="store", help="weight file to test")
parser.add_argument("--dataset", action="store",
                    help="dataset name (currently unused by this script)")

# Only show ERROR and above from these third-party loggers.
logging.getLogger("boto3").setLevel(logging.ERROR)
logging.getLogger("botocore").setLevel(logging.ERROR)
logging.getLogger("nose").setLevel(logging.ERROR)
logging.getLogger("s3transfer").setLevel(logging.ERROR)


def errors_to_json(dataset, indexes):
    """Write the selected dataset samples to a new temporary JSON file.

    Args:
        dataset: object with an ``images`` sequence whose items are dicts
            containing ``'key'``, ``'bucket'`` and ``'label'`` entries.
        indexes: iterable of positions into ``dataset.images``.

    Returns:
        Path of the created ``.json`` file. The file is not deleted; the
        caller owns it. Its content is a JSON list of
        ``{'key', 'url', 'label'}`` objects in the order of ``indexes``.
    """
    images = [dataset.images[index] for index in indexes]
    data = [
        {
            'key': image['key'],
            'url': 'https://s3-us-west-1.amazonaws.com/{}/{}'.format(image['bucket'], image['key']),
            'label': image["label"],
        }
        for image in images
    ]

    fd, path = tempfile.mkstemp('.json')
    # mkstemp hands back an already-open OS-level descriptor; wrap it so it is
    # closed instead of leaking one descriptor per call.
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(data, f)
    except BaseException:
        # Don't leave a half-written temp file behind (e.g. unserializable label).
        os.remove(path)
        raise
    return path


class ModelEvaluator(object):
    """Evaluate a ResNet-18 classifier on a Mongo-backed dataset.

    Keyword Args:
        database: database name forwarded to ``MongoDataset``.
        weights: path of the state-dict file loaded into the network.
        name: which model to evaluate, ``'kill'`` or ``'gamestate'``. It selects
            the label list, the Mongo collection and the preprocessing.

    Raises:
        ValueError: if ``name`` is not a supported model name.
    """

    # Normalization constants applied after ToTensor (the standard torchvision
    # ImageNet mean/std values).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]
    _INPUT_SIZE = (224, 224)

    def __init__(self, **kwargs):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.database = kwargs["database"]
        self.weights = kwargs['weights']

        name = kwargs['name']
        if name == 'kill':
            self.labels = ["no", "yes"]
            self.collection = "Fortnite_kill"
            # This model only looks at a fixed fractional sub-region of the frame.
            pre_steps = [
                lambda img: ImageUtils.crop(img, {"x": 0.3, "y": 0.6, "width": 0.4, "height": 0.25}),
            ]
        elif name == 'gamestate':
            self.labels = ['game', 'lobby', 'no']
            self.collection = 'Fortnite_gamestate'
            pre_steps = []
        else:
            raise ValueError("Unknown model name: {!r}".format(name))

        # Shared preprocessing: resize, convert cv2's BGR order to RGB, then
        # tensorize and normalize.
        self.transform = transforms.Compose(pre_steps + [
            lambda img: cv2.resize(img, self._INPUT_SIZE),
            lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

        # The strict load_state_dict below overwrites every parameter and
        # buffer, so downloading ImageNet weights first would be wasted work.
        self.net = torchvision.models.resnet18(pretrained=False)
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, len(self.labels))
        self.net.eval()
        self.net.to(self.device)
        # map_location lets weights saved from a GPU run load on a CPU-only host.
        self.net.load_state_dict(torch.load(self.weights, map_location=self.device))

        self.test_set = MongoDataset(
            collection=self.collection,
            database=self.database,
            labels=self.labels,
            transform=self.transform)

        self.test_loader = torch.utils.data.DataLoader(
            self.test_set,
            batch_size=64,
            num_workers=4,
            shuffle=False)

    def eval_model(self, low_prob_threshold=0.95):
        """Run the network over ``self.test_loader`` and log summary metrics.

        Logs the dataset size, the accuracy over samples that were actually
        evaluated, binary TP/TN/FP/FN counts, the indexes of misclassified
        samples, and the paths of two temp JSON files. The first lists the
        misclassified samples. The second lists samples whose highest softmax
        probability is below ``low_prob_threshold``.

        A batch that raises ``MalformedError`` or any other ``Exception`` is
        logged and skipped. ``KeyboardInterrupt`` propagates.

        Args:
            low_prob_threshold: confidence below which a sample is reported
                as low-probability (default 0.95).
        """
        TP = TN = FP = FN = 0
        correct = 0    # predictions equal to the label, for any class
        evaluated = 0  # samples from batches processed without error
        errors = []
        low_probs = []

        # Iterate manually so that an exception from one batch is caught here
        # and the loop proceeds to the next next(it) call.
        it = iter(self.test_loader)
        for _ in tqdm.tqdm(range(len(self.test_loader))):
            try:
                images, labels, indexes = next(it)
                # Inference only: skip building the autograd graph.
                with torch.no_grad():
                    outputs = self.net(images.to(self.device)).cpu()
                _, predicted = torch.max(outputs, 1)
                # np.float was an alias of the builtin float (removed in NumPy 1.24).
                probs = torch.nn.functional.softmax(outputs, 1).numpy().astype(float)

                indexes = indexes.numpy()
                predicted = predicted.numpy().astype(np.uint8)
                labels = labels.cpu().numpy().astype(np.uint8)

                # Everything is computed before any counter is touched, so a
                # failing batch never contributes partial results.
                # Binary counts treat label 1 as positive and 0 as negative;
                # other classes (e.g. index 2 for 'gamestate') are not counted.
                TP += int(np.sum((predicted == 1) & (labels == 1)))
                TN += int(np.sum((predicted == 0) & (labels == 0)))
                FP += int(np.sum((predicted == 1) & (labels == 0)))
                FN += int(np.sum((predicted == 0) & (labels == 1)))
                correct += int(np.sum(predicted == labels))
                evaluated += len(labels)

                errors.extend(indexes[predicted != labels].tolist())
                low_probs.extend(indexes[probs.max(axis=1) < low_prob_threshold].tolist())
            except KeyboardInterrupt:
                raise
            except MalformedError:
                log.error("Skipping malformed batch")
            except Exception:
                log.error("Error: ", exc_info=True)

        log.info("Total Tests: {}".format(len(self.test_set)))
        log.info("Evaluated: {}".format(evaluated))
        if evaluated:
            log.info('Accuracy: {:.2f}%'.format(correct / evaluated * 100))
        else:
            log.info('Accuracy: n/a (no samples evaluated)')
        log.info('TP: {}, TN: {}, FP: {},FN: {}'.format(TP, TN, FP, FN))
        log.info('Errors: {}'.format(errors))

        log.info('Errors Json: {}'.format(
            errors_to_json(self.test_set, errors)))
        log.info('Low Probablities Json: {}'.format(
            errors_to_json(self.test_set, low_probs)))


if __name__ == '__main__':
    args = parser.parse_args()
    if not (args.model_weights or args.job_id):
        # Without either flag the fallback below would search 'models/None'.
        parser.error("one of --model_weights or --job_id is required")
    model_weights = args.model_weights or get_best_weight_file('models/{}'.format(args.job_id))

    # Evaluate the same weights against both databases.
    train_evaluator = ModelEvaluator(
        database="classified",
        name=args.name,
        weights=model_weights
    )

    test_evaluator = ModelEvaluator(
        database="validation",
        name=args.name,
        weights=model_weights
    )

    log.info("--------------------------------------------")
    log.info("pytorch version: {}".format(torch.__version__))
    log.info("Using Model Weights: {}".format(model_weights))
    log.info("--------------------------------------------")

    log.info("Training Set:")
    train_evaluator.eval_model()

    log.info("--------------------------------------------")
    log.info("Testing Set:")
    test_evaluator.eval_model()
