from bran.shared.logger import log
from bran.train import AverageMeter
from sklearn.metrics import confusion_matrix, accuracy_score
import itertools
import json
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import tempfile
import torch
import tqdm
import os
import copy

def errors_to_json(dataset, indexes, output_file):
    """Write selected records from a loader's underlying dataset to a JSON file.

    Args:
        dataset: loader-like object whose ``.dataset.images`` is indexable by
            each entry of ``indexes`` (it reaches through the loader to the
            wrapped dataset).
        indexes: iterable of positions into ``dataset.dataset.images``.
        output_file: path of the JSON file to (over)write.

    Each selected record is deep-copied, so the dataset itself is never
    modified, and any ``'_id'`` key is dropped from the copy. The original
    note calls this a workaround for Mongo BSON ids. The resulting list is
    serialized as a single JSON array.

    Returns:
        ``output_file``, unchanged, so callers can log where the data went.
    """
    records = dataset.dataset.images  # reach through the loader to its dataset

    def _without_mongo_id(record):
        # Copy first so stripping the key never touches the source record.
        cleaned = copy.deepcopy(record)
        if '_id' in cleaned:
            del cleaned['_id']
        return cleaned

    payload = [_without_mongo_id(records[idx]) for idx in indexes]

    with open(output_file, 'w') as handle:
        handle.write(json.dumps(payload))
    return output_file


def plot_confusion_matrix(cm, classes, fmt='.3f', title='Confusion matrix'):
    """Render a confusion matrix as an annotated heat map.

    Args:
        cm: 2-D array with rows as true labels and columns as predicted
            labels. It is presumably square, with one row/column per entry of
            ``classes``; confirm with the caller. Cells may be NaN.
        classes: tick labels, applied in order to both axes.
        fmt: format spec used for every cell annotation (default ``'.3f'``).
        title: figure title (default ``'Confusion matrix'``).

    Returns:
        The newly created matplotlib ``Figure``. It is not closed here; the
        caller owns it.
    """
    cm = np.asarray(cm)

    # Use the explicit Figure/Axes API instead of pyplot's implicit
    # "current figure" state.
    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues'))
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # nanmax ignores NaN cells, which appear when a class has no samples.
    # With plain max, one NaN makes the threshold NaN and every label black.
    thresh = np.nanmax(cm) / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    return fig


def _row_normalize(cm):
    """Return ``cm`` as float64 with each row divided by its row sum.

    Rows that sum to zero (a class with no true samples) stay all zeros
    instead of turning into NaN through a 0/0 division.
    """
    cm = np.asarray(cm, dtype=np.float64)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def evaluate_model(name, model, dataset_loader, classes, output_dir, writer,
                   low_probability_threshold=0.95):
    """Evaluate ``model`` on ``dataset_loader``; log metrics and dump problem samples.

    The model runs without gradient tracking and in eval mode. Its previous
    train/eval mode is restored afterwards. For every batch it records:
    predictions and labels, the sample indexes it misclassified, and the
    sample indexes whose highest softmax probability is below
    ``low_probability_threshold``.

    Args:
        name: split name, used in log messages, the figure tag and output
            file names.
        model: module exposing ``eval``/``train``/``training``; called on
            each image batch.
        dataset_loader: iterable with ``len()`` that yields
            ``(images, labels, indexes)`` tensors. ``indexes`` must be a CPU
            tensor (``.numpy()`` is called on it). It is also passed to
            ``errors_to_json``, which reads ``dataset_loader.dataset.images``.
            Labels are assumed to be class ids in ``[0, len(classes))``;
            confirm with the dataset.
        classes: class names, used as confusion-matrix tick labels.
        output_dir: directory receiving ``<name>.errors.json`` and
            ``<name>.low_probs.json``.
        writer: object with ``add_figure(tag, figure)``. The row-normalised
            confusion matrix is added under ``<name>/confusion_matrix``.
        low_probability_threshold: samples whose max softmax probability is
            below this value are written to the low-probability file.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``dataset_loader`` yields no batches.
        Exception: any error raised while iterating or running the model is
            logged and re-raised unchanged.
    """
    log.info("Evaluating model with {} set".format(name))

    errors = []
    low_probs = []

    all_preds = []
    all_labels = []

    use_cuda = torch.cuda.is_available()
    was_training = model.training
    model.eval()
    try:
        # Inference only: don't build an autograd graph for every batch.
        with torch.no_grad():
            for images, labels, indexes in tqdm.tqdm(dataset_loader):
                if use_cuda:
                    images = images.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)

                outputs = model(images)
                _, preds = torch.max(outputs, 1)

                indexes = indexes.numpy()
                # int64 rather than uint8: uint8 silently wraps class ids >= 256.
                predicted = preds.cpu().numpy().astype(np.int64)
                labels = labels.cpu().numpy().astype(np.int64)

                all_preds.append(predicted)
                all_labels.append(labels)

                errors.extend(indexes[predicted != labels])

                # np.float64 here: the np.float alias was removed in NumPy 1.24.
                probs = torch.nn.functional.softmax(
                    outputs, 1).cpu().numpy().astype(np.float64)
                low_probs.extend(
                    indexes[np.max(probs, axis=1) < low_probability_threshold])
    except Exception:
        log.exception("Error processing data")
        raise
    finally:
        # Restore the caller's mode so evaluating mid-training doesn't leave
        # the model in eval mode for subsequent training steps.
        model.train(was_training)

    if not all_preds:
        raise ValueError(
            "{} set produced no batches; nothing to evaluate".format(name))

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Number of evaluated samples (len(dataset_loader) would be the batch count).
    log.info("Total data: {}".format(len(all_labels)))

    # Fix the label set to 0..len(classes)-1 so the matrix always has one
    # row/column per class and lines up with the tick labels, even when a
    # class never occurs in this split.
    cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(classes)))
    cm = _row_normalize(cm)
    # Format locally instead of changing numpy's global print options.
    log.info("Confusion Matrix: {}".format(np.array2string(cm, precision=2)))

    fig = plot_confusion_matrix(cm, classes)
    writer.add_figure("{}/confusion_matrix".format(name), fig)

    accuracy_perc = 100 * accuracy_score(all_labels, all_preds)

    log.info('Accuracy: {:.2f}%'.format(accuracy_perc))
    log.info('Errors: {}'.format(errors))
    log.info('Errors Json: {}'.format(
        errors_to_json(dataset_loader, errors,
                       os.path.join(output_dir, '{}.{}.json'.format(name, 'errors')))))
    log.info('Low Probablities Json: {}'.format(
        errors_to_json(dataset_loader, low_probs,
                       os.path.join(output_dir, '{}.{}.json'.format(name, 'low_probs')))))

    return accuracy_perc
