#!/usr/bin/env python
# coding: utf-8


import sys
# NOTE(review): Python 2-only hack. reload(sys) is presumably here to make
# sys.setdefaultencoding callable again; the call then makes UTF-8 the codec
# used for implicit str <-> unicode conversions in the whole process. Neither
# call exists on Python 3.
reload(sys)
sys.setdefaultencoding("utf-8")

import random
import numpy as np
import itertools
import functools
import multiprocessing
from multiprocessing.dummy import Pool
from collections import defaultdict

from swiss import Url, Game, Tournament
from helpers import doc_key, query_key, aggregate_round_results


def calc_pscore(url):
    """Return the pre-computed ``pscore`` stored in ``url.data``.

    Raises KeyError when ``url.data`` has no 'pscore' entry.
    """
    url_data = url.data
    return url_data['pscore']


def calc_dbd_score(url):
    """Return the fixed-weight ranking score of *url*.

    Weighted sum, in priority order: points won (x5000), Buchholz
    coefficient (x200) and the metric (x50, truncated to int).
    """
    points_part = url.points_won * 5000
    buchholz_part = url.get_buchholz_coef() * 200
    metric_part = int(url.metric * 50)
    return points_part + buchholz_part + metric_part

def calc_dbd_score_parametrized(toloka_w, buch_w, metric_w, url):
    """Return the ranking score of *url* under caller-supplied weights.

    Weighted sum of points won (``toloka_w``), Buchholz coefficient
    (``buch_w``) and metric (``metric_w``, truncated to int after weighting).
    ``url`` is the last parameter so the weights can be bound up front with
    functools.partial.
    """
    weighted_points = url.points_won * toloka_w
    weighted_buchholz = url.get_buchholz_coef() * buch_w
    weighted_metric = int(url.metric * metric_w)
    return weighted_points + weighted_buchholz + weighted_metric


def iter_dbd_weights():
    """Return an iterator over every (toloka_w, buch_w, metric_w) triple.

    The triples form the Cartesian product of the candidate weights for
    points won, Buchholz coefficient and metric, in that order. The literal
    int/float types are significant: they determine how each weight is
    rendered when formatted into a score name.
    """
    toloka_weights = [0, 1, 5, 10, 20, 30, 40, 50, 75, 100,
                      200, 300, 350, 400, 450, 500, 550, 600, 700, 750]
    buchholz_weights = [0, 1, 2, 3, 5, 6, 7, 7.5, 8, 9, 10,
                        11, 12, 12.5, 13, 14, 15, 20, 25, 30, 40, 50]
    metric_weights = [0, 0.1, 0.2, 0.5, 0.75, 0.8, 0.9, 1, 1.1, 1.2,
                      1.25, 1.3, 1.4, 1.5, 1.75, 2, 2.5, 3, 5]
    return itertools.product(toloka_weights, buchholz_weights, metric_weights)


def _round_tasks(toloka_agg, swiss_round, qkey):
    """Return the Toloka tasks of query *qkey* in *swiss_round* ([] if none).

    Uses ``.get`` so that a missing round or query neither raises nor inserts
    an empty entry into ``toloka_agg``, which is shared between callers when
    run through a pool.
    """
    return toloka_agg.get(swiss_round, {}).get(qkey, [])


def _play_rounds(tournament, urls, toloka_agg, qkey, max_rounds):
    """Register the Toloka games of rounds 1..max_rounds in *tournament*."""
    games = []
    for swiss_round in range(1, max_rounds + 1):
        games_played, wins = aggregate_round_results(
            _round_tasks(toloka_agg, swiss_round, qkey))

        for i, (left_key, right_key) in enumerate(games_played.keys()):
            try:
                game = Game(urls[left_key], urls[right_key])
                game.add_score(wins[left_key], wins[right_key])
            except KeyError:
                # A document judged by Toloka is missing from this query's
                # group (or has no win count): skip only this game, not the
                # rest of the round.
                continue
            game.round = swiss_round
            game.id = i
            games.append(game)

        # The list is cumulative: each call hands the tournament all games so far.
        tournament.set_games(games)


def _decisive_pairs(toloka_agg, qkey, max_rounds, win_points):
    """Return (won_key, lost_key) for every game won with >= win_points votes.

    Votes are counted separately for each round, so every round contributes
    its own decisive games.
    """
    decisive = []
    for swiss_round in range(1, max_rounds + 1):
        votes = defaultdict(int)
        for task in _round_tasks(toloka_agg, swiss_round, qkey):
            inputs = task['inputValues']
            left_key = doc_key(inputs, 'left_')
            right_key = doc_key(inputs, 'right_')
            label = task['outputValues']['label']
            if label == 'left':
                votes[(left_key, right_key)] += 1
            elif label == 'right':
                votes[(right_key, left_key)] += 1
            # Any other label expresses no preference and is ignored.
        decisive.extend(pair for pair, count in votes.items() if count >= win_points)
    return decisive


def calc_query_metrics(group, toloka_agg, max_rounds, win_points=4):
    """Score every dbd_score weight triple on one query's Toloka judgements.

    Builds a Swiss tournament from the Toloka games of rounds 1..max_rounds.
    For each triple from iter_dbd_weights(), it ranks the urls with
    calc_dbd_score_parametrized and checks whether each decisive pair (the
    winner received at least ``win_points`` votes within one round) is
    ranked winner-first.

    Args:
        group: list of document dicts for one query; group[0] supplies
            'query' (UTF-8 encoded bytes) and 'region_id'. An empty group
            yields empty counters.
        toloka_agg: {round: {query_key: [task, ...]}}, as built by main().
        max_rounds: highest round number to read.
        win_points: minimum votes one side needs in a round for the pair
            to be evaluated.

    Returns:
        (pairs_correct, pairs_total): two defaultdict(int) keyed by
        'dbd_score_{toloka_w}_{buch_w}_{metric_w}'.
    """
    pairs_correct = defaultdict(int)
    pairs_total = defaultdict(int)
    if not group:
        return pairs_correct, pairs_total

    first = group[0]
    query = first['query'].decode('utf-8')
    region_id = first['region_id']
    qkey = query_key(first)

    urls = {doc_key(item): Url(item, i) for i, item in enumerate(group)}
    print('Query: {}, region {}, items {}'.format(query, region_id, len(group)))

    # Build the tournament so every Url accumulates its points/Buchholz data.
    t = Tournament()
    t.set_urls(urls.values())
    _play_rounds(t, urls, toloka_agg, qkey, max_rounds)

    # Resolve keys to the tournament's Url objects once; this does not depend
    # on the weights, so it stays out of the weight loop below.
    # NOTE(review): t.urls.index(key) relies on swiss.Url comparing equal to
    # its doc key — confirm against swiss.Url.
    resolved = []
    for won_key, lost_key in _decisive_pairs(toloka_agg, qkey, max_rounds, win_points):
        try:
            won = t.urls[t.urls.index(won_key)]
            lost = t.urls[t.urls.index(lost_key)]
        except ValueError:
            # Key not among the tournament's urls; the pair cannot be scored.
            continue
        resolved.append((won, lost))

    if not resolved:
        # Nothing to evaluate: skip sorting the urls for every weight triple.
        return pairs_correct, pairs_total

    for toloka_w, buch_w, metric_w in iter_dbd_weights():
        sort_name = 'dbd_score_{}_{}_{}'.format(toloka_w, buch_w, metric_w)
        sort_fn = functools.partial(calc_dbd_score_parametrized, toloka_w, buch_w, metric_w)
        # won/lost come from t.urls, so (assuming sort_urls returns all of
        # them) both are present in the sorted list.
        sorted_by_dbd = Tournament.sort_urls(t.urls, sort_fn)
        for won, lost in resolved:
            pairs_total[sort_name] += 1
            if sorted_by_dbd.index(won) < sorted_by_dbd.index(lost):
                pairs_correct[sort_name] += 1

    return pairs_correct, pairs_total


def main(*args):
    """Grid-search dbd_score weights against decisive Toloka judgements.

    ``args`` is unpacked as (queries_grouped, toloka_data, in3, token,
    any_param, html_file); only the first two are used here. Toloka tasks
    are indexed by round and query key, and each query group is scored by
    calc_query_metrics in a pool.

    For every weight triple from iter_dbd_weights(), it prints the
    percentage of decisive pairs ordered correctly, then the best triple.
    If no decisive pairs exist, it prints a notice and stops.

    Returns:
        An empty list.
    """
    # Unpacked in full so a wrong number of arguments still fails loudly.
    queries_grouped, toloka_data, in3, token, any_param, html_file = args

    # Index tasks as {round: {query_key: [task, ...]}}.
    max_rounds = 0
    toloka_agg = {}
    for task in toloka_data:
        swiss_round = task['round']
        max_rounds = max(max_rounds, swiss_round)
        if swiss_round not in toloka_agg:
            toloka_agg[swiss_round] = defaultdict(list)
        toloka_agg[swiss_round][query_key(task['inputValues'])].append(task)

    pairs_correct = defaultdict(int)
    pairs_total = defaultdict(int)

    processes = multiprocessing.cpu_count() - 1 or 1
    print('processes {}'.format(processes))
    # NOTE(review): Pool comes from multiprocessing.dummy, i.e. it is a thread
    # pool; CPU-bound scoring is presumably serialized by the GIL.
    pool = Pool(processes=processes)
    try:
        fn = functools.partial(calc_query_metrics, toloka_agg=toloka_agg, max_rounds=max_rounds)
        results = pool.map(fn, queries_grouped)
    finally:
        # Release the workers explicitly (the Python 2 Pool is not a context manager).
        pool.close()
        pool.join()

    for correct, total in results:
        for sort_name, value in correct.items():
            pairs_correct[sort_name] += value
        for sort_name, value in total.items():
            pairs_total[sort_name] += value

    if not any(pairs_total.values()):
        # Avoid dividing by zero below when no query produced a decisive pair.
        print('No toloka pairs with 4-2 and more points found, nothing to score')
        return []

    max_total_predict = 0
    max_total_predict_name = ''
    for toloka_w, buch_w, metric_w in iter_dbd_weights():
        print('')
        sort_name = 'dbd_score_{}_{}_{}'.format(toloka_w, buch_w, metric_w)
        total = pairs_total.get(sort_name, 0)
        if not total:
            continue
        predict = float(pairs_correct.get(sort_name, 0)) / total * 100
        print('Score {} predicts {:0.2f}% toloka pairs with 4-2 and more points'.format(sort_name, predict))
        if max_total_predict < predict:
            max_total_predict = predict
            max_total_predict_name = sort_name

    print('\nmax_total_predict {} by {}'.format(max_total_predict, max_total_predict_name))

    return []
