#!/usr/bin/env python
# coding: utf-8


import os
import sys
# Python 2 idiom: reload ``sys`` so that ``setdefaultencoding`` can be called,
# then make UTF-8 the default codec for implicit str <-> unicode conversions.
# Code below that mixes byte strings and unicode relies on this setting.
reload(sys)
sys.setdefaultencoding("utf-8")

import copy
import math
import random
import datetime
import numpy as np
from pprint import pprint
from collections import defaultdict, OrderedDict
from functools import partial

from helpers import aggregate_toloka_data, play_tournament_games, get_boosts_metrics_structure, upload_file_to_hurge
from helpers import get_sorters_html, get_html_table_header, get_html_queries
from helpers import boost_3cg, grayscale_scale, square_scale, desktop_viewport_factor, visual_quality_scale


def calc_borders_score(pix, thr_percent=20, thr_value=0.1):
    """Return the share of border pixels that are brighter than ``thr_value``.

    The border is a frame around the 2-D array ``pix``: a band of
    ``thr_percent`` percent of the first axis at both of its ends, plus a band
    of ``thr_percent`` percent of the second axis at both of its ends.  A
    border pixel is counted when ``pixel / 255.0 > thr_value``.

    Args:
        pix: 2-D numpy array of pixel intensities.  Values are divided by
            255.0 before the comparison -- presumably 0..255 (TODO confirm).
        thr_percent: border thickness as a percentage of each axis length.
        thr_value: brightness threshold applied to ``pixel / 255.0``.

    Returns:
        float: counted pixels divided by the number of border pixels, or
        ``0.0`` when the border is empty (e.g. ``thr_percent=0``).

    Side effects:
        Prints a one-line summary of the computation.
    """
    # "w" is the length of the first axis and "h" of the second one; this
    # matches the pix[j, i] indexing (j < w, i < h) the score is defined with.
    w, h = pix.shape

    # Explicit float division: under Python 2 integer division the rounding
    # was a no-op and the thickness was silently floored.  Half-up rounding is
    # spelled out so the result is the same on every interpreter version.
    thr_w = max(0, int(w * thr_percent / 100.0 + 0.5))
    thr_h = max(0, int(h * thr_percent / 100.0 + 0.5))

    first_axis = np.arange(w)
    second_axis = np.arange(h)
    in_band_w = (first_axis < thr_w) | (first_axis > w - thr_w - 1)
    in_band_h = (second_axis < thr_h) | (second_axis > h - thr_h - 1)
    border_mask = in_band_w[:, np.newaxis] | in_band_h[np.newaxis, :]

    # Counting the mask instead of using the closed-form area stays correct
    # when the bands overlap (thr_percent > 50), so the ratio never exceeds 1.
    borders_square = int(border_mask.sum())
    bright = (np.asarray(pix) / 255.0) > thr_value
    score = int((bright & border_mask).sum())

    # Guard against an empty border instead of raising ZeroDivisionError.
    ratio = float(score) / borders_square if borders_square else 0.0

    # Single-argument call form of print works on both Python 2 and Python 3.
    print('w*h {}*{}, px > {}, thr: {:2.0f} = ({}*{}), score: {} / {} = {:0.4f}'.format(
        w, h,
        thr_value,
        thr_percent,
        thr_w, thr_h,
        score,
        borders_square,
        ratio
    ))
    return ratio


def calc_dbd_score(url):
    """Fold the tournament results of ``url`` into a single sortable number.

    The score is the sum of three weighted parts, listed by decreasing weight:
    points won (x5000), the Buchholz coefficient (x200) and the metric value
    multiplied by 50 and truncated to an integer.
    """
    points_part = url.points_won * 5000
    buchholz_part = url.get_buchholz_coef() * 200
    metric_part = int(url.metric * 50)
    return points_part + buchholz_part + metric_part

def calc_pscore(url):
    """Return the ``'pscore'`` entry of ``url.data`` (KeyError if absent)."""
    url_data = url.data
    return url_data['pscore']

def calc_click_boost(url):
    """Return the ``'click_boost'`` entry of ``url.data`` (KeyError if absent)."""
    url_data = url.data
    return url_data['click_boost']

def boost_corsa(url):
    """Return the "corsa" ranking score for ``url``.

    Starts from ``url.data['relevance']`` and adds, in this order, weighted
    image terms: square scale (x0.1), grayscale scale (x0.1), visual quality
    scale (x0.05) and desktop viewport factor (x0.025).  A flat 0.03 bonus is
    added when ``page_binary_relevance`` equals ``'RELEVANT'``.
    """
    data = url.data
    target = data['relevance']

    width = data['image_width']
    height = data['image_height']

    target += square_scale(width, height) * 0.1
    target += grayscale_scale(data['max_gray_deviation']) * 0.1
    target += visual_quality_scale(data['image_visual_quality']) * 0.05
    target += desktop_viewport_factor(width, height) * 0.025

    # A missing mark is treated as 'NO_MARK', i.e. no bonus.
    page_mark = data.get('page_binary_relevance', 'NO_MARK')
    if page_mark == 'RELEVANT':
        target += 0.03

    return target

def boost_stove3_dbd(url):
    """Return the "stove3_dbd" ranking score for ``url``.

    Starts from ``url.data['relevance']`` and adds, in this order, the
    grayscale scale, the visual quality scale and the desktop viewport factor,
    each multiplied by its coefficient and then doubled.  A doubled 0.04227
    bonus is added when ``page_binary_relevance`` equals ``'RELEVANT'``.
    """
    data = url.data
    target = data['relevance']

    # The square_scale term (coefficient 0.0056387408) is disabled here.
    gray_term = grayscale_scale(data['max_gray_deviation'])
    target += gray_term * 0.03767*2

    quality_term = visual_quality_scale(data['image_visual_quality'])
    target += quality_term * 0.15965*2

    viewport_term = desktop_viewport_factor(data['image_width'], data['image_height'])
    target += viewport_term * 0.26041*2

    # A missing mark is treated as 'NO_MARK', i.e. no bonus.
    page_mark = data.get('page_binary_relevance', 'NO_MARK')
    if page_mark == 'RELEVANT':
        target += 0.04227*2

    return target

def boost_corsa_with_image_borders(url, thr_percent, thr_value, borders_weight):
    """Return the ``boost_corsa`` score for ``url``.

    The image-borders term is not wired in yet, so ``thr_percent``,
    ``thr_value`` and ``borders_weight`` are accepted but currently unused.
    """
    # TODO: decode url.data.get('saliency_map_base64', '').replace('\n', '')
    # into a pixel array ``pix`` and add
    # borders_weight * calc_borders_score(pix, thr_percent, thr_value).
    return boost_corsa(url)


# Ranking functions under evaluation, keyed by the name shown in the report.
# Insertion order is the order of rows in the stats tables.
boost_types = OrderedDict()
boost_types['corsa'] = boost_corsa
boost_types['stove3_dbd'] = boost_stove3_dbd
boost_types['click_boost'] = calc_click_boost
boost_types['dbd'] = calc_dbd_score

# Register one boost_corsa_with_image_borders variant per parameter
# combination.  Values that are commented out of the grid are listed in the
# trailing comments.
for borders_weight in (0.01,):  # also: 0.025, 0.05, 0.1
    for thr_percent in (5, 10):  # also: 15, 20
        for thr_value in (0.0, 0.1, 0.2):  # also: 0.25, 0.3, 0.4, 0.5, 0.6
            boost_name = 'corsa_borders_w{}_{}p_{}px'.format(
                borders_weight, thr_percent, thr_value
            )
            boost_fn = partial(
                boost_corsa_with_image_borders,
                thr_percent=thr_percent,
                thr_value=thr_value,
                borders_weight=borders_weight,
            )
            boost_types[boost_name] = boost_fn

# Reference orderings every boost is compared against.
targets = OrderedDict()
targets['click_boost'] = calc_click_boost
targets['dbd'] = calc_dbd_score


def main(*args):
    """Play a tournament per query group, write an HTML report, return scored urls.

    Args:
        *args: exactly six positional values, unpacked in this order:
            ``queries_grouped`` (iterable of groups; ``group[0]`` must provide
            ``'query'`` and ``'region_id'``), ``toloka_data``,
            ``sampled_queries``, ``token`` (unused), ``any_param`` (unused) and
            ``html_file`` (object with a ``write`` method).

    Returns:
        list: for every url of every processed group, a deep copy of
        ``url.data`` extended with a ``'dbd_score'`` key.

    Side effects:
        Seeds ``random`` and ``numpy.random``, changes numpy print options,
        reads ``static_header.html`` from the working directory, prints
        progress lines and writes the header plus the report to ``html_file``.

    Raises:
        ValueError: if the number of positional arguments is not six.
    """
    queries_grouped, toloka_data, sampled_queries, token, any_param, html_file = args

    SHOW_SALIENCY = True
    SEED = 30
    random.seed(SEED)
    np.random.seed(SEED)
    np.set_printoptions(linewidth=240)
    np.set_printoptions(threshold=np.inf)

    with open('static_header.html') as f:
        html_header = f.read()

    toloka_agg, max_rounds = aggregate_toloka_data(toloka_data)

    boosts = get_boosts_metrics_structure(boost_types, targets)

    # Selected (sampled) queries.
    html_queries, sampled_queries_list = get_html_queries(sampled_queries, toloka_agg, queries_divider=30)

    # The report is collected as a list of fragments and joined once at the
    # end instead of growing one string with "+=".
    html_parts = [html_queries]

    # =============== separate tournament for every query ===============
    html_parts.append('<div class="results">\n')
    # Never incremented; kept so that the final summary line stays unchanged.
    errors_count = 0
    out = []
    for idx, group in enumerate(queries_grouped):
        # Decode once; ``query`` is text from here on and is not decoded again.
        query = group[0]['query'].decode('utf-8')
        region_id = group[0]['region_id']
        query_key = (query, region_id)
        if not toloka_agg[1][query_key]:
            # We are testing on a sample of queries: skip the ones without data.
            continue

        t = play_tournament_games(group, toloka_agg)

        urls_with_click_boost_count = float(len(t.sort_urls(t.urls, calc_click_boost)))
        print('{:<4}) Query: {:<40} region {:<5} items {}, with click_boost: {:2.0f} - {:2.0f}%'.format(
            idx, query, region_id, len(group),
            urls_with_click_boost_count,
            urls_with_click_boost_count / len(group) * 100
        ))

        if query in sampled_queries_list:
            # NOTE(review): the query text is inserted into HTML unescaped --
            # confirm that queries cannot contain markup or quotes.
            html_parts.append('<div class="results__item" data-query="{q}" data-region="{region}">\n'.format(
                q=query,
                region=region_id,
            ))
            html_parts.append('<h3>Query: "{q}", Region: {region}</h3>\n'.format(
                q=query,
                region=region_id,
            ))

            html_parts.append(get_html_table_header())

            for boost_name, boost_fn in boost_types.items():
                for target, target_fn in targets.items():
                    pairs_result = t.compare_sorts(target_fn, boost_fn, by='pairs')
                    pairs_correct = pairs_result[0]
                    pairs_total = pairs_result[1]

                    stats = boosts[boost_name][target]
                    stats['pairs_correct'][query_key] = pairs_correct
                    stats['pairs_total'][query_key] = pairs_total
                    # "or 1" avoids a division by zero when there are no pairs.
                    stats['pairs_predicted'][query_key] = 100 * pairs_correct / float(pairs_total or 1)
                    stats['pairs_correct']['sum'] += pairs_correct
                    stats['pairs_total']['sum'] += pairs_total

                dbd_stats = boosts[boost_name]['dbd']
                click_stats = boosts[boost_name]['click_boost']
                html_parts.append("""
                <tr>
                    <td>{boost_name}</td>
                    <td>{pairs_correct_dbd:6.0f}</td>
                    <td>{pairs_total_dbd:6.0f}</td>
                    <td>{pairs_dbd:0.2f}</td>
                    <td>{pairs_correct_click_boost:6.0f}</td>
                    <td>{pairs_total_click_boost:6.0f}</td>
                    <td>{pairs_click_boost:0.2f}</td>
                </tr>
                """.format(
                    boost_name=boost_name,

                    pairs_correct_dbd=dbd_stats['pairs_correct'][query_key],
                    pairs_total_dbd=dbd_stats['pairs_total'][query_key],
                    pairs_dbd=dbd_stats['pairs_predicted'][query_key],

                    pairs_correct_click_boost=click_stats['pairs_correct'][query_key],
                    pairs_total_click_boost=click_stats['pairs_total'][query_key],
                    pairs_click_boost=click_stats['pairs_predicted'][query_key],
                ))

            html_parts.append('</tbody></table>')

            html_parts.append(get_sorters_html(boost_types))

            if SHOW_SALIENCY:
                html_parts.append('<span class="saliency_toggler">toggle saliency</span>\n')
            html_parts.append('<div class="gallery">\n')
            html_parts.append(t.print_gallery(sorts=boost_types, saliency=SHOW_SALIENCY))
            html_parts.append('</div>\n\n')
            html_parts.append('</div>\n\n')

        # Every processed group contributes its urls to the returned list,
        # whether or not the query is part of the rendered sample.
        for url, dbd_score in t.sort_urls(t.urls, calc_dbd_score, return_as_is=True):
            url_data = copy.deepcopy(url.data)
            url_data['dbd_score'] = dbd_score
            out.append(url_data)

    html_parts.append('<div class="results__item results__item_visible" data-query="{q}">\n'.format(
        q='Total stats table',
    ))

    html_parts.append('<h3>Total stats table</h3>')
    html_parts.append(get_html_table_header())

    # Totals use the 'sum' counters accumulated in the per-query loop above.
    for boost_name in boost_types:
        dbd_stats = boosts[boost_name]['dbd']
        click_stats = boosts[boost_name]['click_boost']
        pairs_dbd = float(dbd_stats['pairs_correct']['sum']) / \
                    (dbd_stats['pairs_total']['sum'] or 1) * 100
        pairs_click_boost = float(click_stats['pairs_correct']['sum']) / \
                            (click_stats['pairs_total']['sum'] or 1) * 100
        html_parts.append("""
        <tr>
            <td>{boost_name}</td>

            <td>{pairs_correct_dbd:6.0f}</td>
            <td>{pairs_total_dbd:6.0f}</td>
            <td>{pairs_dbd:0.2f}</td>

            <td>{pairs_correct_click_boost:6.0f}</td>
            <td>{pairs_total_click_boost:6.0f}</td>
            <td>{pairs_click_boost:0.2f}</td>
        </tr>
        """.format(
            boost_name=boost_name,

            pairs_correct_dbd=dbd_stats['pairs_correct']['sum'],
            pairs_total_dbd=dbd_stats['pairs_total']['sum'],
            pairs_dbd=pairs_dbd,

            pairs_correct_click_boost=click_stats['pairs_correct']['sum'],
            pairs_total_click_boost=click_stats['pairs_total']['sum'],
            pairs_click_boost=pairs_click_boost,
        ))

    html_parts.append('</tbody></table>')
    html_parts.append('</div>')

    html = ''.join(html_parts)

    html_file.write(html_header)
    html_file.write(html)

    print('Errors count: {}, key not found in queries_list'.format(errors_count))

    return out
