import json
import re
import os

from collections import namedtuple, defaultdict, Counter
from itertools import chain, islice
from functools import reduce
import csv
import random
import numpy as np

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_curve
import seaborn as sns
import catboost

# Users whose sessions are skipped when results are parsed or filtered.
IGNORE_USERS_LIST = ['quoter', 'tmp1', 'tmpuser2']

# Identifies one marked-up object: the task (markup file name) plus the object id inside it.
MarkupTaskKey = namedtuple("MarkupTaskKey", "task, object_id")
# A single user's verdict (is_bad) on the object identified by task_key.
MarkupResult = namedtuple("MarkupResult", "task_key, user, is_bad")
# One record of the metric scores file: the score, the per-object measurements and the polygon.
MetricScore = namedtuple("MetricScore", "score, iou, residual, shift_x, shift_y, theta, scale, points_diff, polygon")
# A dataset row joining the key, the collected votes, the markup polygon and the metric score.
DatasetItem = namedtuple("DatasetItem", "key, markup_stat, polygon, metric_score")


class MarkupResultStat(object):
    """Votes collected for a single marked-up object.

    Every user is kept in at most one of the two sets: ``good_votes``
    (the user accepted the object) or ``bad_votes`` (the user rejected it).
    """

    # Instances are mutable and define __eq__, so they are deliberately unhashable.
    __hash__ = None

    def __init__(self, **kwargs):
        """Accept optional ``good_votes`` / ``bad_votes`` iterables of users."""
        self.good_votes = set(kwargs.get('good_votes', []))
        self.bad_votes = set(kwargs.get('bad_votes', []))

    def _asdict(self):
        """Return a dict with both vote sets converted to lists."""
        return dict(good_votes=list(self.good_votes),
                    bad_votes=list(self.bad_votes))

    def __eq__(self, other):
        try:
            return (self.good_votes, self.bad_votes) == \
                (other.good_votes, other.bad_votes)
        except AttributeError:
            # Unrelated object: let Python fall back to its default comparison
            # instead of failing with AttributeError.
            return NotImplemented

    def __str__(self):
        return "good_votes={0} bad_votes={1}".format(self.good_votes, self.bad_votes)

    def __repr__(self):
        return self.__str__()

    def set_user_vote_is_bad(self, user, is_bad):
        """Record the vote of ``user``, replacing any earlier vote of that user."""
        if is_bad:
            self.bad_votes.add(user)
            self.good_votes.discard(user)
        else:
            self.good_votes.add(user)
            self.bad_votes.discard(user)

    def calc_user_score(self):
        """Return the fraction of good votes, or None when there are no votes."""
        total = self.votes_number()
        if not total:
            return None
        return len(self.good_votes) / total

    def calc_label(self):
        """Return True when more than half of the votes are good.

        Raises TypeError when there are no votes, because the score is None.
        """
        return self.calc_user_score() > 0.5

    def votes_number(self):
        """Return the total number of votes."""
        return len(self.good_votes) + len(self.bad_votes)

    def is_controversal(self):
        """Return True when there is at least one good and one bad vote."""
        return self.votes_number() > 0 and \
            len(self.good_votes) * len(self.bad_votes) != 0


def dataset_item_to_dict(item):
    """Turn a dataset item into a plain dict.

    The key, the markup stat and the metric score are converted with their
    own ``_asdict()``; the polygon is stored as is.
    """
    return {
        "key": item.key._asdict(),
        "markup_stat": item.markup_stat._asdict(),
        "polygon": item.polygon,
        "metric_score": item.metric_score._asdict(),
    }


def dataset_item_from_dict(d):
    """Rebuild a DatasetItem from the dict made by dataset_item_to_dict."""
    key = MarkupTaskKey(**d["key"])
    markup_stat = MarkupResultStat(**d['markup_stat'])
    polygon = d['polygon']
    metric_score = MetricScore(**d['metric_score'])
    return DatasetItem(key, markup_stat, polygon, metric_score)


def load_markup_polygons(markup_tasks_dir):
    """Read every markup task file in a directory.

    Each file is a JSON list of objects with "id" and "coords" entries.
    Returns a dict: MarkupTaskKey(file name, object id) -> flat list
    [x0, y0, x1, y1, ..., x0, y0], i.e. the polygon closed by repeating
    its first point.
    """
    polygons = {}
    for file_name in os.listdir(markup_tasks_dir):
        with open(os.path.join(markup_tasks_dir, file_name)) as task_file:
            objects = json.load(task_file)

        for markup_object in objects:
            object_id = markup_object["id"]
            flat = [value
                    for point in markup_object["coords"]
                    for value in (point[0], point[1])]
            # Close the contour by appending the first point again.
            flat.extend(flat[:2])
            polygons[MarkupTaskKey(file_name, object_id)] = flat

    return polygons


def parse_results(results_json_path):
    """Parse the results JSON file and return a list of MarkupResult.

    Sessions of users listed in IGNORE_USERS_LIST are skipped. When an
    object id occurs several times inside one session, the last verdict wins.
    """
    with open(results_json_path, 'r') as results_file:
        sessions = json.load(results_file)

    parsed = []
    for session in sessions:
        task = session["task"]
        user = session["user"]
        if user in IGNORE_USERS_LIST:
            continue

        # A dict collapses repeated ids, keeping the latest verdict.
        verdicts = {entry["id"]: entry["isBad"] for entry in session["results"]}
        parsed.extend(
            MarkupResult(MarkupTaskKey(task, object_id), user, is_bad)
            for object_id, is_bad in verdicts.items()
        )

    return parsed


def load_metric_scores(file_path):
    """Read the metric scores file.

    Every line holds whitespace-separated fields:
    task, object id, iou, score, residual, shift_x, shift_y, theta, scale,
    points_diff, followed by the polygon coordinates.
    Returns a dict: MarkupTaskKey -> MetricScore, where the polygon is a
    list of floats closed by repeating its first two values.
    """
    scores_by_key = {}
    with open(file_path) as scores_file:
        for line in scores_file:
            (task, obj_id, iou, score, residual, shift_x, shift_y,
             theta, scale, points_diff, *coords) = line.split()

            closed_polygon = [float(value) for value in coords]
            closed_polygon.extend(closed_polygon[:2])

            # Note the file order is "iou score", while MetricScore starts with score.
            scores_by_key[MarkupTaskKey(task, obj_id)] = MetricScore(
                score=float(score),
                iou=float(iou),
                residual=float(residual),
                shift_x=float(shift_x),
                shift_y=float(shift_y),
                theta=float(theta),
                scale=float(scale),
                points_diff=int(points_diff),
                polygon=closed_polygon,
            )
    return scores_by_key



def collect_markup_stats(results):
    """Fold an iterable of MarkupResult into a defaultdict: MarkupTaskKey -> MarkupResultStat."""
    stats_by_key = defaultdict(MarkupResultStat)
    for markup_result in results:
        stat = stats_by_key[markup_result.task_key]
        stat.set_user_vote_is_bad(markup_result.user, markup_result.is_bad)
    return stats_by_key


def task_to_dataset(task):
    """Map a task file name to its dataset name.

    'bld_1_edge20.json' -> '1_edge'. Names that do not follow the
    'bld_<number>_<edge|rectangle><number>.json' pattern give the
    string "None".
    """
    # Raw string keeps \d a regex escape; the dot is escaped so it matches
    # only a literal '.' before the extension.
    m = re.match(r'bld_(\d+)_(edge|rectangle)(\d+)\.json', task)
    if not m:
        return "None"
    number, kind = m.group(1), m.group(2)
    return "{0}_{1}".format(number, kind)


def split_to_datasets(markup_results):
    """Group markup results by dataset; returns a defaultdict: dataset name -> [MarkupResult]."""
    by_dataset = defaultdict(list)
    for markup_result in markup_results:
        dataset_name = task_to_dataset(markup_result.task_key.task)
        by_dataset[dataset_name].append(markup_result)
    return by_dataset


def dump_markup_results_to_easyview(result_stat_iter, task_polygons, fobj):
    """Write polygons with a colour style per verdict to ``fobj``.

    result_stat_iter: iterable of (MarkupTaskKey, MarkupResultStat)
    task_polygons: dict: MarkupTaskKey -> polygon (list)

    For every item two '|'-delimited rows are written: a style row
    (blue for controversial, green when there are good votes, red
    otherwise) and the polygon followed by a description.
    """
    GOOD_STYLE = "!linestyle=green:2"
    BAD_STYLE = "!linestyle=red:2"
    CONT_STYLE = "!linestyle=blue:2"

    out = csv.writer(fobj, delimiter='|')
    for task_key, result_stat in result_stat_iter:
        if result_stat.is_controversal():
            style = CONT_STYLE
        elif result_stat.good_votes:
            style = GOOD_STYLE
        else:
            style = BAD_STYLE
        out.writerow([style])

        desc = "{0}:{1} good={2} bad={3}".format(
            task_key.task,
            task_key.object_id,
            result_stat.good_votes,
            result_stat.bad_votes,
        )
        out.writerow(task_polygons[task_key] + [desc])


def dump_markup_results_with_ref_to_easyview(result_stat_iter, task_polygons, scores, fobj):
    """Write each markup polygon (red) next to the polygon stored with its score (blue).

    result_stat_iter: iterable of (MarkupTaskKey, MarkupResultStat)
    task_polygons: dict: MarkupTaskKey -> polygon (list)
    scores: dict: MarkupTaskKey -> object with ``score`` and ``polygon``
    Rows are '|'-delimited; both polygon rows end with the same description.
    """
    out = csv.writer(fobj, delimiter='|')
    for task_key, result_stat in result_stat_iter:
        out.writerow(["!linestyle=red:2"])
        desc = "{0}:{1} good={2} bad={3} score={4}".format(
            task_key.task,
            task_key.object_id,
            result_stat.good_votes,
            result_stat.bad_votes,
            scores[task_key].score,
        )
        markup_row = task_polygons[task_key] + [desc]
        out.writerow(markup_row)
        out.writerow(["!linestyle=blue:2"])
        reference_row = scores[task_key].polygon + [desc]
        out.writerow(reference_row)


def save_controversal_results_to_easyview(results, markup_tasks_dir, output_dir):
    """Write objects with both good and bad votes to one easyview file per dataset.

    Files are named "controversal_<dataset>.ev" inside ``output_dir``.
    """
    task_polygons = load_markup_polygons(markup_tasks_dir)
    all_stats = collect_markup_stats(results)

    per_dataset = defaultdict(list)
    for task_key, stat in all_stats.items():
        if stat.is_controversal():
            per_dataset[task_to_dataset(task_key.task)].append((task_key, stat))

    for dataset_name, dataset_stats in per_dataset.items():
        path = os.path.join(output_dir, "controversal_{0}.ev".format(dataset_name))
        with open(path, 'w') as out_file:
            dump_markup_results_to_easyview(dataset_stats, task_polygons, out_file)
        print("saved to file {0}".format(os.path.abspath(path)))


def save_results_to_easyview(results_path, markup_tasks_dir, output_dir):
    """Parse the results file and write every voted object to one easyview file per dataset.

    Files are named "markup_<dataset>.ev" inside ``output_dir``.
    """
    task_polygons = load_markup_polygons(markup_tasks_dir)
    all_stats = collect_markup_stats(parse_results(results_path))

    per_dataset = defaultdict(list)
    for task_key, stat in all_stats.items():
        dataset_name = task_to_dataset(task_key.task)
        per_dataset[dataset_name].append((task_key, stat))

    for dataset_name, dataset_stats in per_dataset.items():
        path = os.path.join(output_dir, "markup_{0}.ev".format(dataset_name))
        with open(path, 'w') as out_file:
            dump_markup_results_to_easyview(dataset_stats, task_polygons, out_file)
        print("saved file {0}".format(path))


def analyze_results(results):
    """Print the number of results, then the per-user and the per-task reports."""
    total = len(results)
    print("Total markup results: {0}".format(total))
    for report in (analyze_per_user_stats, analyze_per_task_stats):
        report(results)


def analyze_per_user_stats(markup_results):
    """Print a per-user table: good votes, bad votes and agreement share.

    A vote counts as agreeing when "is bad" equals "the object's share of
    good votes is below 0.5". The last column is printed under the
    "precision" header.
    """
    stats = collect_markup_stats(markup_results)

    good_count = Counter()
    bad_count = Counter()
    agreed_count = Counter()
    vote_count = Counter()
    for vote in markup_results:
        vote_count[vote.user] += 1
        if vote.is_bad:
            bad_count[vote.user] += 1
        else:
            good_count[vote.user] += 1

        majority_says_bad = stats[vote.task_key].calc_user_score() < 0.5
        if majority_says_bad == vote.is_bad:
            agreed_count[vote.user] += 1

    print("user\t\t\t\tgood\tbad\tprecision")
    print("-" * 80)
    for user in set(chain(good_count, bad_count)):
        agreement = agreed_count[user] / vote_count[user]
        print("{0:25s}\t{1}\t{2}\t{3:.2f}".format(
            user, good_count[user], bad_count[user], agreement))


def analyze_per_task_stats(markup_results):
    """Print per-object statistics for the given markup results.

    Reports the number of voted objects, the share labelled good (share of
    good votes above 0.5), the number of objects with at least MIN_VOTES
    votes, and how many of those have no disagreement between voters.
    Empty input prints zero shares instead of failing on a division by zero.
    """
    stats = collect_markup_stats(markup_results)
    print("Total tasks {0}".format(len(stats)))

    labels = [stat.calc_user_score() > 0.5 for stat in stats.values()]
    counter = Counter(labels)
    # Same guard as for the agreement share below: no tasks means a share of 0.
    good_share = counter[True] / len(labels) if labels else 0.
    print("Labels stats: {0:.2f} are good {1}".format(good_share, counter))

    MIN_VOTES = 2
    tasks_with_min_votes = 0
    total_quorum_tasks = 0

    for stat in stats.values():
        if stat.votes_number() >= MIN_VOTES:
            tasks_with_min_votes += 1

            if not stat.is_controversal():
                total_quorum_tasks += 1

    print("Total tasks with at least {0} votes: {1}".format(MIN_VOTES, tasks_with_min_votes))
    print("Total tasks with agreement: {0:.2f} {1}"
          .format(
              total_quorum_tasks / tasks_with_min_votes if tasks_with_min_votes else 0.,
              total_quorum_tasks))



def load_labels(results):
    """Return a dict: (task, object_id) -> True when the result is not bad.

    ``results`` is an iterable of MarkupResult. The task and the object id
    are read from ``task_key``, because MarkupResult has no ``task`` or
    ``object_id`` fields of its own. A later result for the same object
    overrides an earlier one.
    """
    return {(r.task_key.task, r.task_key.object_id): not r.is_bad
            for r in results}


def find_intersections(file_path):
    """Print up to 10 groups of results that share the same object id.

    Results are parsed from ``file_path`` and grouped by the object id
    taken from ``task_key`` (MarkupResult has no ``object_id`` field of
    its own); only groups with more than one result are printed.
    """
    results = parse_results(file_path)
    by_object_id = defaultdict(list)
    for r in results:
        by_object_id[r.task_key.object_id].append(r)

    repeated = (group for group in by_object_id.values() if len(group) > 1)
    for group in islice(repeated, 10):
        print(group)


def process(labels_path, scores_path):
    """Evaluate metric scores against user labels, overall and per dataset.

    Results are parsed from ``labels_path`` and scores from ``scores_path``.
    process_results expects a mapping MarkupTaskKey -> MarkupResultStat, so
    the raw results are aggregated with collect_markup_stats before each call.
    """
    results = parse_results(labels_path)
    dataset_results = split_to_datasets(results)
    scores = load_metric_scores(scores_path)
    print("Loaded {0} scores".format(len(scores)))

    print("##Process all")
    process_results(collect_markup_stats(results), scores)

    for dataset, ds_results in dataset_results.items():
        print("##Process {0}".format(dataset))
        process_results(collect_markup_stats(ds_results), scores)


def process_kural(labels_path, scores_path):
    """Plot score distributions for good and bad objects, then evaluate the scores.

    Results marked bad whose task name starts with 'task' are dropped
    before the votes are aggregated.
    """
    results = [
        r for r in parse_results(labels_path)
        if not r.task_key.task.startswith('task') or not (r.is_bad == True)
    ]
    stats = collect_markup_stats(results)
    scores = load_metric_scores(scores_path)
    print("Loaded {0} scores".format(len(scores)))

    good_labels_scores = []
    bad_labels_scores = []
    for key, stat in stats.items():
        is_good = stat.calc_label()
        if key not in scores:
            continue
        target = good_labels_scores if is_good else bad_labels_scores
        target.append(scores[key].score)

    ax = sns.distplot(good_labels_scores, color="green", label="good objects scores")
    sns.distplot(bad_labels_scores, ax=ax, color="red", label="bad objects scores")
    plt.legend()
    plt.show()

    process_results(stats, scores)


def find_best_threshold(scores, labels, fixed_threshold=0.4):
    """Print the thresholds giving the best F1 and the best accuracy.

    scores: predicted scores, one per object
    labels: true boolean labels, same length as ``scores``
    fixed_threshold: extra threshold whose accuracy is also reported

    The F1 search goes over the thresholds returned by
    precision_recall_curve; the accuracy search goes over 0, 0.05, ..., 0.95.
    A score strictly greater than the threshold is predicted as positive.
    Nothing is returned.
    """
    precision, recall, pr_thresholds = precision_recall_curve(labels, scores)
    max_f1 = 0
    max_f1_threshold = 0
    # pr_thresholds is one element shorter than precision/recall; zip drops the tail.
    for r, p, t in zip(recall, precision, pr_thresholds):
        if p + r == 0:
            continue
        f1 = (2 * p * r) / (p + r)
        if f1 > max_f1:
            max_f1 = f1
            max_f1_threshold = t

    thresholds = np.arange(0, 1, 0.05)
    accuracies = np.array([
        accuracy_score(labels, np.greater(scores, threshold).astype(int))
        for threshold in thresholds
    ])
    max_accuracy = accuracies.max()
    max_accuracy_threshold = thresholds[accuracies.argmax()]
    print("Max F1 = {0:.2f}, threshold = {1:.2f}".format(max_f1, max_f1_threshold))
    print("Max accuracy = {0:.2f}, threshold = {1:.2f}".format(max_accuracy, max_accuracy_threshold))

    print("For throshold = {0:.2f} accuracy = {1:.2f}".format(
        fixed_threshold,
        accuracy_score(labels, np.greater(scores, fixed_threshold).astype(int))))


def process_results(stats, scores):
    """Compare metric scores with user labels and plot the ROC curve.

    stats: dict: MarkupTaskKey -> MarkupResultStat
    scores: dict: MarkupTaskKey -> MetricScore

    Only keys present in both mappings are used. An object is labelled
    good when its share of good votes is above 0.5. When the two mappings
    have no keys in common, only the intersection size is printed.
    """
    common = list(set(stats.keys()).intersection(scores.keys()))
    print("There are {0} objects in intersection".format(len(common)))
    if not common:
        # Nothing to evaluate; the ratio and the ROC curve below need data.
        return

    common_labels = [stats[k].calc_user_score() > 0.5 for k in common]
    labels_counter = Counter(common_labels)
    print("Labels stats: {0:.2f} are good {1}".format(
            labels_counter[True] / len(common_labels), labels_counter))
    common_scores = [scores[k].score for k in common]

    fpr, tpr, _ = roc_curve(common_labels, common_scores)
    roc_auc = auc(fpr, tpr)
    print("roc_auc = {0:.2f}".format(roc_auc))

    find_best_threshold(common_scores, common_labels)

    plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()


def visualize_metric_errors(results, scores, markup_tasks_dir, output_dir):
    """Write the metric's false positives and false negatives to easyview files.

    For every dataset found in ``results`` the errors returned by
    get_top_errors are written to "false_positive_<dataset>.ev" and
    "false_negative_<dataset>.ev" inside ``output_dir``; a file is written
    only when its error list is not empty. Markup polygons are read from
    disk once.
    """
    task_polygons = load_markup_polygons(markup_tasks_dir)
    ds_results_dict = split_to_datasets(results)

    for ds, ds_results in ds_results_dict.items():
        false_pos, false_neg = get_top_errors(ds_results, scores)
        for prefix, errors in (("false_positive", false_pos),
                               ("false_negative", false_neg)):
            if not errors:
                continue
            path = os.path.join(output_dir, "{0}_{1}.ev".format(prefix, ds))
            with open(path, 'w') as f:
                dump_markup_results_with_ref_to_easyview(errors, task_polygons, scores, f)
                print("wrote {0}".format(path))


def get_top_errors(results, scores, metric_threshold=0.4):
    """Find objects where the metric score disagrees with the user votes.

    results: iterable of MarkupResult; results without a score are ignored
    scores: dict: MarkupTaskKey -> MetricScore
    metric_threshold: score above it counts as "metric says good",
        below it as "metric says bad"

    Returns (false_pos, false_neg), two lists of (key, stat):
    false_pos - share of good votes below 0.5 but score above the threshold;
    false_neg - share of good votes above 0.5 but score below the threshold.
    """
    scored_results = (r for r in results if r.task_key in scores)
    stats = collect_markup_stats(scored_results)
    false_pos = []
    false_neg = []

    for key, stat in stats.items():
        user_score = stat.calc_user_score()
        metric_score = scores[key].score

        if user_score < 0.5 and metric_score > metric_threshold:
            false_pos.append((key, stat))
        elif user_score > 0.5 and metric_score < metric_threshold:
            false_neg.append((key, stat))

    return false_pos, false_neg


def format_easyview_polygon(polygon):
    """Return the polygon values joined by spaces, with the first two values repeated at the end."""
    parts = [str(value) for value in polygon]
    parts.append(str(polygon[0]))
    parts.append(str(polygon[1]))
    return " ".join(parts)


def format_easyview(rs):
    """Return easyview text for one record: its polygon in red and its reference polygon in blue.

    ``rs`` must provide task, object_id, good_votes, bad_votes, auto_score,
    polygon and ref_polygon. Both polygon lines end with the same label.
    """
    text = "{0}_{1}_good_votes={2}_bad_votes={3}_score={4}".format(
        rs.task,
        rs.object_id,
        ','.join(rs.good_votes),
        ','.join(rs.bad_votes),
        rs.auto_score,
    )

    lines = [
        "!linestyle=red:2\n",
        format_easyview_polygon(rs.polygon) + " " + text + "\n",
        "!linestyle=blue:2\n",
        format_easyview_polygon(rs.ref_polygon) + " " + text + "\n",
    ]
    return "".join(lines)


def print_votes_on_task(results_path, task, object_id):
    """Print every user's vote on one object of one task."""
    wanted_key = (task, object_id)
    matching = (r for r in parse_results(results_path) if r.task_key == wanted_key)
    for vote in matching:
        print("{0}:{1} {2}\tis_bad={3}".format(task, object_id, vote.user, vote.is_bad))



def save_dataset(dataset, path):
    """Write the dataset items to ``path`` as a JSON list of dicts."""
    with open(path, 'w') as out_file:
        serialized = [dataset_item_to_dict(item) for item in dataset]
        json.dump(serialized, out_file)


def save_datasets(full, train, val, test, output_dir):
    """Save the full dataset and its train/val/test parts as JSON files in ``output_dir``."""
    print("Saving datasets to {0}".format(output_dir))
    targets = (
        (full, "dataset.json"),
        (train, "train_dataset.json"),
        (val, "val_dataset.json"),
        (test, "test_dataset.json"),
    )
    for dataset, file_name in targets:
        save_dataset(dataset, os.path.join(output_dir, file_name))


def make_dataset(stats, scores, markup_tasks_dir):
    """Join votes, scores and markup polygons into a shuffled dataset and split it.

    Only keys present in all three sources are used. After shuffling, the
    first 80% of items form the train part, the next 10% the validation
    part and the remainder the test part.
    Returns (dataset, train_dataset, val_dataset, test_dataset).
    """
    markup_polygons = load_markup_polygons(markup_tasks_dir)

    shared = set(stats.keys()) & set(scores.keys()) & set(markup_polygons.keys())
    common_keys = list(shared)
    ds_size = len(common_keys)
    print("There are {0} total items".format(ds_size))

    dataset = []
    for key in common_keys:
        dataset.append(DatasetItem(key, stats[key], markup_polygons[key], scores[key]))

    print("Shuffling dataset")
    random.shuffle(dataset)

    TRAIN_RATIO, VAL_RATIO = 0.8, 0.1
    train_end = int(ds_size * TRAIN_RATIO)
    val_end = train_end + int(ds_size * VAL_RATIO)

    return (dataset,
            dataset[:train_end],
            dataset[train_end:val_end],
            dataset[val_end:])


def score_to_features(s):
    """Return the feature tuple of a metric score: every field except ``score`` and ``polygon``."""
    features = (
        s.iou,
        s.residual,
        s.shift_x,
        s.shift_y,
        s.theta,
        s.scale,
        s.points_diff,
    )
    return features



def train_catboost(dataset):
    """Fit a CatBoostClassifier on the metric features of the dataset items and return it.

    The target of every item is the label derived from its user votes.
    """
    features = [score_to_features(item.metric_score) for item in dataset]
    labels = [item.markup_stat.calc_label() for item in dataset]

    params = dict(iterations=3000, depth=6, learning_rate=1.)
    model = catboost.CatBoostClassifier(**params)

    print("Start training")
    model.fit(features, labels)
    print("Done")

    return model


def calc_catboost_feature_importance(model, dataset):
    """Return what ``model.get_feature_importance`` gives for the dataset's features and labels."""
    features = []
    labels = []
    for item in dataset:
        features.append(score_to_features(item.metric_score))
    for item in dataset:
        labels.append(item.markup_stat.calc_label())
    # NOTE(review): both lists are passed positionally; confirm the argument
    # order against the installed catboost version.
    return model.get_feature_importance(features, labels)


def test_catboost(model, dataset):
    """Evaluate the model on the dataset and compare it with the original metric score.

    Prints the best thresholds and both ROC AUC values, then shows the
    distribution of model confidences for good and bad objects and the
    two ROC curves.
    """
    features = [score_to_features(item.metric_score) for item in dataset]
    labels = [item.markup_stat.calc_label() for item in dataset]

    # Second column of predict_proba is taken as the confidence of the positive class.
    conf_true = [row[1] for row in model.predict_proba(features)]

    find_best_threshold(conf_true, labels)

    orig_score = [item.metric_score.score for item in dataset]

    fpr, tpr, _ = roc_curve(labels, conf_true)
    roc_auc = auc(fpr, tpr)
    print("catboost roc_auc = {0}".format(roc_auc))

    ofpr, otpr, _ = roc_curve(labels, orig_score)
    oroc_auc = auc(ofpr, otpr)
    print("orig roc_auc = {0}".format(oroc_auc))

    good_labels_scores = []
    bad_labels_scores = []
    for label, confidence in zip(labels, conf_true):
        bucket = good_labels_scores if label else bad_labels_scores
        bucket.append(confidence)

    ax = sns.distplot(good_labels_scores, color="green", label="good objects scores")
    sns.distplot(bad_labels_scores, ax=ax, color="red", label="bad objects scores")
    plt.legend()
    plt.show()

    plt.figure()
    line_width = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=line_width, label='ROC curve (area catboost=%0.2f)' % roc_auc)
    plt.plot(ofpr, otpr, color='red',
             lw=line_width, label='ROC curve (area orig=%0.2f)' % oroc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=line_width, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()


def filter_results(results_path, keys, output_path):
    """Write a copy of the results file restricted to the given keys.

    results_path: path to the results JSON (a list of sessions)
    keys: iterable of MarkupTaskKey to keep; may be a one-shot iterable
    output_path: where the filtered JSON is written

    Sessions of ignored users and sessions whose task has no requested key
    are dropped; in the remaining sessions only results whose
    (task, id) is among ``keys`` are kept.
    """
    # Materialize first: deriving the task names from ``keys`` before turning
    # it into a set would exhaust a generator and leave the key set empty.
    keys = set(keys)
    tasks = {key.task for key in keys}

    with open(results_path) as f:
        results_json = json.load(f)

    filtered_results = []
    for session in results_json:
        if session["task"] not in tasks or session["user"] in IGNORE_USERS_LIST:
            continue
        session["results"] = [
            r for r in session["results"]
            if MarkupTaskKey(session["task"], r["id"]) in keys
        ]
        filtered_results.append(session)

    with open(output_path, 'w') as f:
        print("writing filtered results to {0}".format(output_path))
        json.dump(filtered_results, f)
