# -*- coding: utf-8 -*-

import sys
import math
import random
import traceback
import numpy as np
import itertools
from urlparse import urlparse
from mwmatching import maxWeightMatching
from scipy.stats import spearmanr
from scipy.stats import pearsonr

# Sentinel written into the pairing cost matrix for pairs that must not be
# matched: a URL with itself, or two URLs that have already played each other.
# Tournament.build_pairing drops every matrix cell equal to this value.
# NOTE(review): regular costs are -abs(score difference), so two URLs with
# exactly equal scores also yield 0 and are dropped too -- confirm whether that
# is intended; the name suggests a large value, the code relies on 0.
HIGH_COST = 0


class Url:
    """A single image result ("team") taking part in the tournament.

    Wraps one input ``row`` and accumulates the results of the games the URL
    has played.  ``row`` must provide the keys ``qid``, ``query``,
    ``image_url`` and ``metric``; ``page_url`` is additionally read when the
    instance is compared with a 4-tuple.
    """

    # Running totals.  The class-level zeros act as defaults; ``add_score``
    # rebinds them on the instance, so they are never shared between URLs.
    points_won = 0
    points_lost = 0
    games_won = 0
    games_lost = 0
    games_draw = 0
    points_opponents = 0

    def __init__(self, row, index):
        """Store the fields of ``row`` and the position ``index`` of the URL."""
        self.qid = row['qid']
        self.query = row['query']
        self.url = row['image_url']
        self.index = index
        self.metric = row['metric']
        self.data = row
        self.matches = []

    def __eq__(self, other):
        """Compare by value.

        * another ``Url``: equal when query, image url and metric match;
        * a 4-tuple: equal when element 0 is the query, element 2 the image
          url and element 3 the page url (element 1 is ignored);
        * anything else: never equal.
        """
        if isinstance(other, self.__class__):
            return (self.query == other.query
                    and self.url == other.url
                    and self.metric == other.metric)
        if isinstance(other, tuple) and len(other) == 4:
            return (self.query == other[0]
                    and self.url == other[2]
                    and self.data['page_url'] == other[3])
        return False

    def __ne__(self, other):
        """Inverse of ``__eq__``.

        Python 2 does not derive ``!=`` from ``__eq__``; without this method
        ``a != b`` could be True for two URLs that compare equal with ``==``.
        """
        return not self.__eq__(other)

    def score(self, index=0, length=0):
        """Return the ranking score used to order and pair URLs.

        The terms are weighted so that they act as successive tie-breakers:
        points won, then the Buchholz coefficient, then the metric value,
        then the position ``index`` within a ranking of ``length`` items.
        """
        return (self.points_won * 5000 +           # 1) points won
                self.get_buchholz_coef() * 200 +   # 2) Buchholz coefficient
                int(self.metric * 50) +            # 3) metric score
                (length - index) * 10)             # 4) position in the ranking

    def get_toloka_wins_histogram(self):
        """Return ``{score: number of games finished with that score}``.

        Only scores 0..6 are valid keys; any other recorded score raises
        ``KeyError``.
        """
        wins = {x: 0 for x in range(7)}
        for game in self.matches:
            own_score = game.team1_score if self == game.team1 else game.team2_score
            wins[own_score] += 1
        return wins

    def get_buchholz_coef(self):
        """Return the sum of the current ``points_won`` of all opponents."""
        coef = 0
        for game in self.matches:
            opponent = game.team2 if self == game.team1 else game.team1
            coef += opponent.points_won
        return coef

    def add_score(self, game, score, opp_score, opp_total_points_won):
        """Record the result of ``game`` from this URL's point of view.

        :param game: the game object, appended to ``self.matches``
        :param score: points this URL scored in the game
        :param opp_score: points the opponent scored in the game
        :param opp_total_points_won: value added to ``points_opponents``
        """
        self.points_won += score
        self.points_lost += opp_score
        self.points_opponents += opp_total_points_won
        self.matches.append(game)

        if score == opp_score:
            self.games_draw += 1
        elif score > opp_score:
            self.games_won += 1
        else:
            self.games_lost += 1

    def __repr__(self):
        return 'URL-{}'.format(self.index)

    def toJSON(self):
        """Return the original input row."""
        return self.data


class Game:
    """One pairing of two teams and, once played, its result.

    The teams are expected to expose ``index``, ``points_won``,
    ``add_score(game, score, opp_score, opp_total_points_won)`` and
    ``toJSON()``.
    """

    id = None
    round = None
    team1 = None
    team2 = None
    team1_score = None
    team2_score = None

    # state:
    # -1 - undefined (not played yet)
    #  0 - draw
    #  1 - team1 won
    #  2 - team2 won
    state = -1

    def __init__(self, url1, url2):
        self.team1 = url1
        self.team2 = url2

    def __repr__(self):
        # 'vs' is shown for a game that has not been played yet; previously an
        # unplayed game was rendered as '<', i.e. as a win of team2.
        separators = {0: '=', 1: '>', 2: '<'}
        sep = separators.get(self.state, 'vs')

        return 'Game {}/{}: {} {} {}'.format(
            self.id, self.round, self.team1.index, sep, self.team2.index
        )

    def add_score(self, team1_score, team2_score):
        """Store the result and propagate it to both teams.

        Sets ``state`` to 0 for a draw, 1 when team1 scored more and 2 when
        team2 scored more.  Each team receives the opponent's ``points_won``
        as it was *before* this game was recorded.
        """
        self.team1_score = team1_score
        self.team2_score = team2_score
        if team1_score == team2_score:
            # Was ``self.state == 0`` (a no-op comparison), so draws were
            # never recorded.
            self.state = 0
        else:
            self.state = 1 if team1_score > team2_score else 2

        # Snapshot both totals first: the first add_score call below must not
        # influence the value handed to the second one.
        won1 = self.team1.points_won
        won2 = self.team2.points_won
        self.team1.add_score(self, team1_score, team2_score, won2)
        self.team2.add_score(self, team2_score, team1_score, won1)

    def toJSON(self):
        """Return a dict with the JSON form of both teams."""
        return {
            'team1': self.team1.toJSON(),
            'team2': self.team2.toJSON()
        }


class Tournament:
    """Builds game pairings for a set of URLs and compares URL rankings.

    ``urls`` holds the participating ``Url`` objects, ``games`` the games
    known to the tournament and ``cached_sorts`` maps a sort function's
    ``__name__`` to the URL list ordered by that function (see ``cache_sort``).
    """

    def __init__(self, html=None):
        # ``html`` is accepted for backward compatibility and is not used.
        self.urls = []
        self.games = []
        self.cached_sorts = {}

    def set_urls(self, urls):
        self.urls = urls

    def set_games(self, games):
        self.games = games

    # ------------------------------------------------------------------
    # Pairing
    # ------------------------------------------------------------------

    def build_round(self, is_swiss=True):
        """Return the games of the next round (Swiss pairing or fixed split)."""
        if is_swiss:
            return self.build_swiss_round()
        return self.build_random_round()

    def build_swiss_round(self):
        """Pair URLs with close scores that have not met yet."""
        ranking = sorted(self.urls, key=lambda x: x.score(), reverse=True)
        cost_matrix = self.build_cost(ranking)
        return self.build_pairing(ranking, cost_matrix)

    def build_pairing(self, urls, cost_matrix):
        """Turn ``cost_matrix`` into games via maximum weight matching.

        Cells on the diagonal and cells equal to ``HIGH_COST`` are not offered
        to the matcher.  URLs the matcher leaves unmatched get no game.
        """
        cost_list = []
        for i, row in enumerate(cost_matrix):
            for j, cost in enumerate(row):
                if i != j and cost != HIGH_COST:
                    cost_list.append((i, j, cost))

        pairs = maxWeightMatching(cost_list, maxcardinality=True)
        return self._games_from_matching(urls, pairs)

    @staticmethod
    def _games_from_matching(urls, pairs):
        """Create one game per matched pair; ``pairs[i]`` is i's partner or -1."""
        used = set()
        new_games = []
        for i1, i2 in enumerate(pairs):
            # A negative partner marks an unmatched vertex.  It must not be
            # used as an index: urls[-1] is the *last* URL, so the unmatched
            # URL used to be paired with it (possibly with itself).
            if i2 < 0 or i1 in used or i2 in used:
                continue
            used.add(i1)
            used.add(i2)

            game = Game(urls[i1], urls[i2])
            # NOTE(review): id and round are hard-coded placeholders here.
            game.id = 77
            game.round = 2
            new_games.append(game)
        return new_games

    def build_cost(self, urls):
        """Return the square matrix of pairing weights for ``urls``.

        A cell holds ``HIGH_COST`` for a URL paired with itself or with an
        opponent it already played, otherwise the negated absolute score
        difference (closer scores give a larger weight).
        """
        length = len(urls)
        cost_matrix = np.full((length, length), -1, np.array(-1).dtype)

        for i, url1 in enumerate(urls):
            for j, url2 in enumerate(urls):
                if i == j or self.already_played(url1, url2):
                    cost = HIGH_COST
                else:
                    cost = -abs(url1.score(i, length) - url2.score(j, length))
                cost_matrix[i][j] = cost_matrix[j][i] = cost
        return cost_matrix

    def already_played(self, url1, url2):
        """Return True when one of ``url1``'s recorded games involves ``url2``."""
        return any(game.team1 == url2 or game.team2 == url2
                   for game in url1.matches)

    def build_random_round(self):
        """Pair URL ``i`` of the first half with URL ``i`` of the second half.

        The list is used in its current order (no shuffling); with an odd
        number of URLs the last one stays without a game.
        """
        mid = len(self.urls) // 2

        new_games = []
        for i in range(mid):
            game = Game(self.urls[i], self.urls[mid + i])
            game.id = i
            game.round = 1
            new_games.append(game)

        return new_games

    def play_random_games(self, games):
        """Assign random 0..4 scores to ``games`` and store them as one round.

        The two scores of a game always sum to 4 and the randomly drawn value
        goes to the team with the smaller ``index`` (on equal indexes to
        team2).  The whole list is appended to ``self.games`` as one entry.
        """
        print('Play {} games'.format(len(games)))
        for game in games:
            win1 = np.random.choice([0, 1, 2, 3, 4], p=[10. / 100, 15. / 100, 20. / 100, 30. / 100, 25. / 100])

            if game.team1.index > game.team2.index:
                win2 = 4 - win1
            else:
                win2 = win1
                win1 = 4 - win2

            print('play {} {} {} {}'.format(game.team1.index, game.team2.index, win1, win2))
            game.add_score(win1, win2)

        self.games.append(games)

    # ------------------------------------------------------------------
    # Sorting and ranking comparison
    # ------------------------------------------------------------------

    @staticmethod
    def sort_urls(urls, sort, return_as_is=False):
        """Order ``urls`` by ``sort(url)``, largest value first.

        URLs for which ``sort`` raises (the error is reported on stderr) or
        returns ``None`` are left out.  Ties are broken randomly.

        :param return_as_is: return ``(url, value)`` tuples instead of urls
        """
        # Callables such as functools.partial objects have no __name__; the
        # error report below must not fail because of that.
        sort_name = getattr(sort, '__name__', repr(sort))
        urls_with_boosts = []
        for url in urls:
            try:
                value = sort(url)
            except Exception as e:
                sys.stderr.write('    Query: {}, For url {}, calc boost {} exception: {}\n'.format(url.data['query'], url, sort_name, str(e)))
                traceback.print_exc()
                continue
            if value is not None:
                urls_with_boosts.append((url, value))

        # sorted() is stable, so shuffling first randomises the order of ties.
        shuffled = random.sample(urls_with_boosts, len(urls_with_boosts))
        urls_sorted = sorted(shuffled, key=lambda x: x[1], reverse=True)
        if return_as_is:
            return urls_sorted
        return [x[0] for x in urls_sorted]

    def cache_sort(self, sort_fn):
        """Sort ``self.urls`` by ``sort_fn`` and cache it under its ``__name__``."""
        self.cached_sorts[sort_fn.__name__] = Tournament.sort_urls(self.urls, sort_fn)

    @staticmethod
    def _positions_in(reference, items):
        """Return ``{id(item): index of item in reference, or None if absent}``.

        Uses ``list.index`` so the lookup rules (identity, then ``==``) are the
        same as for ``in``; computing all positions once avoids repeating the
        linear scan for every pair of items.
        """
        positions = {}
        for item in items:
            try:
                positions[id(item)] = reference.index(item)
            except ValueError:
                positions[id(item)] = None
        return positions

    def get_sort_pairs_count(self, sort_fn, thr=-100):
        """Count pairs of the cached ``sort_fn`` ordering that pass ``thr``.

        :return: ``(pairs_ok, pairs_count)`` where ``pairs_ok`` counts pairs
            whose ``points_won`` difference (earlier minus later) is >= thr
        """
        sorted_items = self.cached_sorts[sort_fn.__name__]
        pairs_count = 0
        pairs_ok = 0
        for first, second in itertools.combinations(sorted_items, r=2):
            pairs_count += 1
            if first.points_won - second.points_won < thr:
                # drop the uncertain pair
                continue
            pairs_ok += 1
        return pairs_ok, pairs_count

    def compare_sorts_by_strong_pairs(self, sort1, sort2, thr=-1000):
        """Check how many confident pairs of cached ``sort1`` keep their order in ``sort2``.

        A pair is confident when its ``dbd_score`` difference (earlier minus
        later, missing scores count as 0) is at least ``thr``.

        :return: ``(pairs_correct, pairs_count, pairs_missed, pairs_unsure)``;
            ``pairs_missed`` counts pairs with an item absent from ``sort2``
        """
        sorted_by_sort1 = self.cached_sorts[sort1.__name__]
        sorted_by_sort2 = self.cached_sorts[sort2.__name__]
        positions = self._positions_in(sorted_by_sort2, sorted_by_sort1)

        pairs_count = 0
        pairs_correct = 0
        pairs_missed = 0
        pairs_unsure = 0
        for first, second in itertools.combinations(sorted_by_sort1, r=2):
            pos1 = positions[id(first)]
            pos2 = positions[id(second)]
            if pos1 is None or pos2 is None:
                pairs_missed += 1
                continue
            if first.data.get('dbd_score', 0) - second.data.get('dbd_score', 0) < thr:
                # drop the uncertain pair
                pairs_unsure += 1
                continue

            if pos1 < pos2:
                pairs_correct += 1
            pairs_count += 1
        return pairs_correct, pairs_count, pairs_missed, pairs_unsure

    def compare_dt_score_by_strong_pairs(self, sort1, dt_score_key, thr=-1000):
        """Check confident pairs of cached ``sort1`` against ``data[dt_score_key]``.

        A pair counts as correct when the earlier item has the larger value,
        and as half correct (and unsure) when it is not larger but the two
        values differ by less than 0.0001.

        :return: ``(pairs_correct, pairs_count, pairs_missed, pairs_unsure)``;
            ``pairs_missed`` is always 0 here
        """
        sorted_by_sort1 = self.cached_sorts[sort1.__name__]

        pairs_count = 0
        pairs_correct = 0
        pairs_missed = 0
        pairs_unsure = 0
        for first, second in itertools.combinations(sorted_by_sort1, r=2):
            if first.data.get('dbd_score', 0) - second.data.get('dbd_score', 0) < thr:
                # unsure
                continue

            value1 = first.data.get(dt_score_key)
            value2 = second.data.get(dt_score_key)
            if value1 > value2:
                pairs_correct += 1
            elif abs(value2 - value1) < 0.0001:
                pairs_correct += 0.5
                # count it as an almost confident pair
                pairs_unsure += 1
            pairs_count += 1
        return pairs_correct, pairs_count, pairs_missed, pairs_unsure

    # p-score
    # bradley terry,
    # rank correlation
    def compare_sorts(self, sort1, sort2, by='pairs', param=None):
        """Compare the orderings produced by ``sort1`` and ``sort2``.

        :param by: metric to use

            * ``'pairs'`` -- ``(correct, total)`` over all pairs of the
              ``sort1`` ordering present in both orderings; a pair is correct
              when ``sort2`` keeps its order;
            * ``'strong_pairs'`` -- same, but pairs whose ``points_won``
              difference is below ``param`` are skipped (``param=None`` skips
              nothing);
            * ``'mae'`` / ``'mse'`` -- absolute / squared rank displacement,
              each term divided by the 1-based rank in ``sort1``, averaged
              over the shorter ordering;
            * ``'spearmanr'`` -- result of ``scipy.stats.spearmanr`` on the
              URL indexes of both orderings, truncated to the shorter one.
        :return: the metric value, or ``None`` for an unknown ``by``
        """
        if by not in ('pairs', 'strong_pairs', 'mae', 'mse', 'spearmanr'):
            return None

        sorted_by_sort1 = Tournament.sort_urls(self.urls, sort1)
        sorted_by_sort2 = Tournament.sort_urls(self.urls, sort2)

        if by == 'spearmanr':
            urls_count = min(len(sorted_by_sort1), len(sorted_by_sort2))
            return spearmanr([x.index for x in sorted_by_sort1][:urls_count],
                             [x.index for x in sorted_by_sort2][:urls_count])

        positions = self._positions_in(sorted_by_sort2, sorted_by_sort1)

        if by in ('pairs', 'strong_pairs'):
            threshold = param if by == 'strong_pairs' else None
            pairs_count = 0
            correct_pairs = 0
            for first, second in itertools.combinations(sorted_by_sort1, r=2):
                pos1 = positions[id(first)]
                pos2 = positions[id(second)]
                if pos1 is None or pos2 is None:
                    continue
                if threshold is not None and first.points_won - second.points_won < threshold:
                    # drop the uncertain pair
                    continue
                if pos1 < pos2:
                    correct_pairs += 1
                pairs_count += 1
            return correct_pairs, pairs_count

        # 'mae' or 'mse'
        urls_count = float(min(len(sorted_by_sort1), len(sorted_by_sort2)))
        error = 0
        for idx, item in enumerate(sorted_by_sort1):
            pos = positions[id(item)]
            if pos is None:
                continue
            if by == 'mae':
                error += abs(idx - pos) / float(idx + 1)
            else:
                error += (idx - pos) ** 2 / float(idx + 1)
        return error / urls_count

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    @staticmethod
    def _flatten_games(games):
        """Return a flat list of games from a list of games and/or rounds."""
        flat = []
        for entry in games:
            if isinstance(entry, (list, tuple)):
                flat.extend(entry)
            else:
                flat.append(entry)
        return flat

    def print_round(self, games=None):
        """Print ``games`` (default: all stored games) to stdout.

        Rounds stored as nested lists by ``play_random_games`` are flattened
        first.
        """
        if not games:
            games = self.games

        for game in self._flatten_games(games):
            if game.state > 0:
                print('Game {}/{}, WIN {}\nTEAM1: {} — {}\nTEAM2: {} — {}\n'.format(
                    game.id, game.round, game.state,
                    game.team1.index, game.team1_score,
                    game.team2.index, game.team2_score,
                ))
            else:
                print('Game {}/{}\nTEAM1: {}\nTEAM2: {}\n'.format(
                    game.id, game.round, game.team1.index, game.team2.index
                ))

    @staticmethod
    def _format_matches(url):
        """Return the HTML list of ``url``'s games: W/L/= plus the opponent index.

        Every entry ends with ', ' (callers strip the final character).
        """
        entry = '{mark} <span class="clicker" data-index="{index}">{index:<2}</span>, '
        parts = []
        for game in url.matches:
            if url == game.team1:
                opponent, win_state, loss_state = game.team2, 1, 2
            else:
                opponent, win_state, loss_state = game.team1, 2, 1

            if game.state == win_state:
                mark = 'W'
            elif game.state == loss_state:
                mark = 'L'
            else:
                mark = '='
            parts.append(entry.format(mark=mark, index=opponent.index))
        return ''.join(parts)

    def print_net(self, to_stdout=True):
        """Return (and optionally print) one stats line per URL, best score first."""
        lines = []
        ranking = sorted(self.urls, key=lambda x: x.score(), reverse=True)
        for idx, url in enumerate(ranking):
            # NOTE(review): [:-1] removes only the trailing space, so the
            # list keeps a trailing comma; kept to leave the output unchanged.
            matches = self._format_matches(url)[:-1]
            stats_by_url = '{:<2}) <span class="clicker" data-index="{}">Url {:<2}</span>,   (W {:<2}, L {:<2}, = {:<2})   total_score {:<8} toloka_points_won {:<4} toloka_opp_scores {:<4} Corsa {:.2f}   GAMES: [{}]     {}'.format(
                idx, url.index, url.index, url.games_won, url.games_lost, url.games_draw, url.score(), url.points_won, url.get_buchholz_coef(), url.metric, matches, url.url
            )
            lines.append(stats_by_url)
            if to_stdout:
                print(stats_by_url)
        return ''.join(lines)

    @staticmethod
    def get_html_gallery_item(url, sorts=None, idx=None, saliency=False, extra_info=None):
        """
        Return the html describing one image and its factors.
        :param url: - Url instance holding the image and its game results
        :param sorts: - dict {name: fn} of extra boosts to compute for the image;
                        a boost whose fn raises is shown as 0.0
        :param idx: - ordinal number, used for display only
        :param saliency: - when true, emit the saliency map attribute for images that have one
        :param extra_info: - optional iterable of texts, each rendered as an extra table row
        :return: html code
        """
        # NOTE(review): values from url.data are inserted without HTML escaping.
        matches = Tournament._format_matches(url)

        attr_parts = []
        metric_parts = []
        for boost, fn in (sorts or {}).items():
            try:
                val = fn(url)
            except Exception:
                val = 0.0
            attr_parts.append(' data-{}="{}"'.format(boost, val))
            metric_parts.append('{}={:0.4f}&#10;'.format(boost, val))
        attr_sorts = ''.join(attr_parts)
        metrics_values = ''.join(metric_parts)

        wins = url.get_toloka_wins_histogram()
        parsed_uri = urlparse(url.data['page_url'])
        parsed_extra_rows = ''.join(
            '<tr><td>{text}</td></tr>'.format(text=text) for text in (extra_info or ()))

        saliency_map = url.data.get('saliency_map_base64') if saliency else None
        saliency_attr = ' data-saliency="{}" '.format(saliency_map) if saliency_map else ''

        return """<div class="gallery__item" data-index="{url}" {attr_sorts}>
                <table class="gallery__table">
                    <tr><td> <b>{idx:<2})</b> <span class="clicker" data-index="{url}">Url {url:<2}</span> (Stats: Wins {win:<2}, Loses {loss:<2}, Draws {eq:<2}) </td></tr>
                    <tr><td> Games: {games} </td></tr>
                    <tr><td> Toloka points_won: {pwon:<2}, wins: {toloka_wins}, <span class="boosts_cut" title="{metrics_values}">Boosts</span> </td></tr>
                    <tr><td style="img_container"> <a href="{img}" target="_blank"><div style='background-image: url({img});' data-src="{img}" {saliency} class="gallery__image" /></a> </td></tr>
                    <tr><td> {rel}, {w}x{h}, VQ {vq}, Grayscale {grayscale} </td></tr>
                    <tr><td> Page {lr} <a href="{page_url}" class="greenurl" target="_blank">{page_domain}</a> <a href="{page_screenshot_url}" class="greenurl" target="_blank">скриншот</a> </td></tr>
                    {extra_rows}
                </table>
            </div>""".format(
            # str() keeps the default idx=None printable on every Python version.
            idx=str(idx), url=url.index,
            win=url.games_won, loss=url.games_lost, eq=url.games_draw,
            toloka_wins='-'.join(
                ['<span class="wins_count wins_{w}" title="Количество раундов в толоке с {w}-{rw} очками">{cnt}</span>'
                 .format(w=w, rw=(6 - w), cnt=cnt) for w, cnt in sorted(wins.items())]),
            pwon=url.points_won,
            metrics_values=metrics_values,
            attr_sorts=attr_sorts,
            games=matches[:-1],

            img=url.data.get('image_avatars_url'),
            saliency=saliency_attr,
            rel=url.data.get('relevance'),
            w=url.data.get('image_width'),
            h=url.data.get('image_height'),
            vq=url.data.get('images_visual_quality_with_queue'),
            grayscale=url.data.get('avatars_avg_gray_deviation', -1) or -1,

            page_url=url.data.get('page_url'),
            page_domain=parsed_uri.netloc,
            lr=url.data.get('page_binary_relevance'),
            page_screenshot_url=url.data.get('page_screenshot_url_desktop'),
            extra_rows=parsed_extra_rows
        )

    def print_gallery(self, sorts=None, saliency=False):
        """Return the gallery html for all URLs, best score first.

        With ``saliency`` set, only URLs that have a saliency map are included
        and the map is passed on to each rendered item.
        """
        items = []
        idx = 0
        for url in sorted(self.urls, key=lambda x: x.score(), reverse=True):
            if saliency and not url.data.get('saliency_map_base64'):
                continue
            # The flag is forwarded so the filtered items actually carry
            # their saliency attribute (it used to be dropped here).
            items.append(self.get_html_gallery_item(url, sorts, idx, saliency=saliency))
            idx += 1
        return ''.join(items)


