#!/usr/bin/env python
# coding: utf-8

import os
import sys
reload(sys)
sys.setdefaultencoding("utf-8")

import copy
import math
import random
import datetime
import numpy as np
import functools
from pprint import pprint
from collections import defaultdict, OrderedDict

import operator

from helpers import aggregate_toloka_data, get_boosts_metrics_structure, upload_file_to_hurge
from helpers import get_sorters_html, get_html_table_header, play_tournament_games
from helpers import grayscale_scale, grayscale_avg_scale, square_scale, desktop_viewport_factor, desktop_viewport_factor2, visual_quality_scale

# Maximum number of query groups main() processes; main() lowers it to 20
# when running on macOS, which it treats as a DEBUG run.
QUERIES_LIMIT = 5000
# Win-score threshold passed to compare_sorts_by_strong_pairs for the 'dbd'
# target in main(). (Spelling kept as-is: the name is referenced elsewhere.)
WIN_SCORE_THRESHHOLD = 25000


# Base score for each relevance label found in url.data['relevance'].
# Callers use REL_MAP.get(label, 0), so unknown or missing labels score 0.
REL_MAP = dict(
    RELEVANT_PLUS=1.0,
    RELEVANT_MINUS=0.5,
    IRRELEVANT=0,
)


def square_tanh_scale(iw, ih, growth):
    """Squash an image area into a bounded score with tanh.

    :param iw: image width (callers pass url.data['image_width']); falsy
        values such as 0 or None mean "unknown".
    :param ih: image height, same conventions as ``iw``.
    :param growth: steepness applied to the area (measured in millions of
        width*height units) before tanh.
    :return: ``tanh(growth * iw * ih / 1e6)``, or 0.0 when either side is
        missing. For positive ``growth`` the result lies in [0, 1).
    """
    if iw and ih:
        area = iw * ih
        # Keep the original evaluation order: (growth * area) / 1e6.
        return math.tanh(growth * area / 1000000.0)
    return 0.0


def get_dt_score(url):
    """Score used by the 'decision tree' boost.

    :param url: object exposing a ``data`` dict.
    :return: ``REL_MAP[relevance] + 0.1 * dt_score``; a missing relevance
        label or ``dt_score`` contributes 0.
    """
    fields = url.data
    base = REL_MAP.get(fields.get('relevance'), 0)
    return base + 0.1 * fields.get('dt_score', 0)


def get_dt_score_rel_minus(url):
    """Score used by the 'decision_tree_rel_minus' boost.

    Same shape as get_dt_score, but reads ``dt_score_rel_minus``.

    :param url: object exposing a ``data`` dict.
    :return: ``REL_MAP[relevance] + 0.1 * dt_score_rel_minus``; missing
        values contribute 0.
    """
    fields = url.data
    base = REL_MAP.get(fields.get('relevance'), 0)
    return base + 0.1 * fields.get('dt_score_rel_minus', 0)


def get_corsax_fixed(url):
    """Return ``url.data['corsax']``, or 0 when the key is absent.

    :param url: object exposing a ``data`` dict.
    """
    # A stray, unreachable expression referencing an undefined name
    # (``boost_utility_clickopt1_without_rel04``) used to follow this return.
    return url.data.get('corsax', 0)

# NOTE: a second, behaviourally identical copy of QUERIES_LIMIT,
# WIN_SCORE_THRESHHOLD, REL_MAP, square_tanh_scale, get_dt_score and
# get_corsax_fixed (a copy/paste leftover) used to sit here and silently
# re-bind those names. The definitions earlier in this module are now the
# only ones.


def calc_dbd_score(url):
    """Target function 'dbd': the ``dbd_score`` stored on the url (0 if absent).

    :param url: object exposing a ``data`` dict.
    """
    payload = url.data
    return payload.get('dbd_score', 0)


def calc_click_boost(url):
    """Target function 'click_boost': the url's ``dwelltime_boost`` (0 if absent).

    :param url: object exposing a ``data`` dict.
    """
    payload = url.data
    return payload.get('dwelltime_boost', 0)


def boost_relevance(url):
    """Baseline boost: only the REL_MAP score of the url's relevance label.

    :param url: object exposing a ``data`` dict.
    :return: REL_MAP value for ``data['relevance']``, or 0 if unknown/missing.
    """
    label = url.data.get('relevance')
    return REL_MAP.get(label, 0)


def boost_3cg(url):
    """Score for the '3cg' boost: relevance plus small image bonuses.

    Returns 0 when the relevance label maps to 0. Otherwise it returns the
    REL_MAP score plus grayscale_scale/6, square_scale/12 and
    visual_quality_scale/12, added in that order.

    :param url: object exposing a ``data`` dict.
    """
    data = url.data
    score = REL_MAP.get(data.get('relevance'), 0)
    if score == 0:
        return 0

    bonuses = [
        grayscale_scale(data.get('avatars_max_gray_deviation')) * 1/6.0,
        square_scale(data.get('image_width'), data.get('image_height')) * 1/12.0,
        visual_quality_scale(data.get('images_visual_quality_with_queue')) * 1/12.0,
    ]
    # Sequential addition keeps the exact float rounding of the original.
    for bonus in bonuses:
        score += bonus
    return score


def boost_corsa(url):
    """Score for the 'corsa' boost: relevance plus weighted image bonuses.

    Returns 0 when the relevance label maps to 0. Otherwise it returns the
    REL_MAP score plus 0.1*grayscale + 0.1*size + 0.025*viewport +
    0.05*visual quality. It adds 0.03 more when ``page_binary_relevance``
    is 'RELEVANT'.

    :param url: object exposing a ``data`` dict.
    """
    data = url.data
    score = REL_MAP.get(data.get('relevance'), 0)
    if score == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')
    bonuses = [
        grayscale_scale(data.get('avatars_max_gray_deviation')) * 0.1,
        square_scale(width, height) * 0.1,
        desktop_viewport_factor(width, height) * 0.025,
        visual_quality_scale(data.get('images_visual_quality_with_queue')) * 0.05,
    ]
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        bonuses.append(0.03)

    # Sequential addition keeps the exact float rounding of the original.
    for bonus in bonuses:
        score += bonus
    return score


def boost_corsax_viewport2(url, rel_minus_weight=0.75, boost_weight=0.3):
    """Score for the 'corsax_1_viewport2' boost.

    Returns 0 when the relevance label maps to 0. Otherwise it returns the
    REL_MAP score plus ``boost_weight`` times a weighted sum of image features.
    Missing ``kernel_june`` counts as 0.5. A missing (or 'NO_MARK')
    ``toloka_images_utility`` counts as 0.4.

    :param url: object exposing a ``data`` dict.
    :param rel_minus_weight: not used by the body; kept for call compatibility.
    :param boost_weight: multiplier applied to the summed feature boost.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')
    kernel = data.get('kernel_june')

    terms = [
        grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * 0.05,
        visual_quality_scale(data.get('images_visual_quality_with_queue')) * 0.1,
        desktop_viewport_factor2(width, height) * 0.1,
        square_scale(width, height) * 0.2,
        (0.5 if kernel is None else kernel) * 0.10,
    ]
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(0.10)
    if data.get('toloka_images_utility', 'NO_MARK') == 'NO_MARK':
        terms.append(0.4 * 0.35)
    else:
        terms.append(float(data.get('toloka_images_utility')) * 0.35)

    # Sequential addition keeps the exact float rounding of the original.
    boost = 0
    for term in terms:
        boost += term
    return base + boost * boost_weight


def boost_max_overfit_dbd(url):
    """Score for the 'max_dbd_overfit' boost (fixed, hand-tuned weights).

    Returns 0 when the relevance label maps to 0. Otherwise it returns the
    REL_MAP score plus 0.49999 times a weighted sum of utility, viewport,
    aesthetics, page relevance, size, colour, kernel quantile and
    visual-quality features.

    :param url: object exposing a ``data`` dict.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')
    terms = [
        data.get('utility_v2_avg', 0) * 0.67626,
        desktop_viewport_factor2(width, height) * 0.00200,
        data.get('aestetics_mean5', 0) * 0.05514,
    ]
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(0.03428)
    terms += [
        square_tanh_scale(width, height, 2) * 0.06514,
        grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * 0.00068,
        (data.get('biz_kernel_quantile', 0) or 0) * 0.01678,
        data.get('images_vq3_v2_avg', 0) * 0.14971,
    ]

    # Sequential addition keeps the exact float rounding of the original.
    boost = 0
    for term in terms:
        boost += term
    return base + boost * 0.49999


def custom_boost(url, w):
    """Relevance score plus 0.1 times a ``w``-weighted sum of image features.

    :param url: object exposing a ``data`` dict.
    :param w: weights keyed by 'util', 'vp', 'aes', 'page_rel', 'size',
        'color', 'kernel' and 'vq'. 'page_rel' is only read when
        ``page_binary_relevance`` is 'RELEVANT'.
    :return: 0 when the relevance label maps to 0, else
        ``REL_MAP[relevance] + 0.1 * boost``.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')
    terms = [
        data.get('utility_v2_avg', 0) * w['util'],
        desktop_viewport_factor2(width, height) * w['vp'],
        data.get('aestetics_mean5', 0) * w['aes'],
    ]
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(w['page_rel'])
    terms += [
        square_tanh_scale(width, height, 2) * w['size'],
        grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * w['color'],
        (data.get('biz_kernel_quantile', 0) or 0) * w['kernel'],
        data.get('images_vq3_v2_avg', 0) * w['vq'],
    ]

    # Sequential addition keeps the exact float rounding of the original.
    boost = 0
    for term in terms:
        boost += term
    return base + 0.1 * boost


def metric_boost_corsa_xr(url, boost_weight=0.3):
    """Score for the 'corsa_xr_...' boost.

    Accepts either an object with a ``data`` dict or the dict itself. It
    returns 0 only for an explicit 'IRRELEVANT' label. Any other label,
    including a missing one, gets ``REL_MAP.get(label, 0) + boost_weight *
    boost``, where ``boost`` is a fixed-weight sum of image features.

    :param url: url object with ``data``, or the feature dict directly.
    :param boost_weight: multiplier applied to the summed feature boost.
    """
    if hasattr(url, 'data'):
        item = url.data
    else:
        item = url

    if item.get('relevance') == 'IRRELEVANT':
        return 0

    # Weights mirrored in this boost's registry name in boost_types.
    util_w, vp_w, aes_w, page_rel_w = 0.5, 0.11421, 0.09114, 0.04307
    size_w, color_w, kernel_w, vq_w = 0.08142, 0.02186, 0.03267, 0.11563

    width = item.get('image_width')
    height = item.get('image_height')
    terms = [
        item.get('utility_v2_avg', 0) * util_w,
        desktop_viewport_factor2(width, height) * vp_w,
        item.get('aestetics_mean5', 0) * aes_w,
    ]
    if item.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(page_rel_w)
    terms += [
        square_tanh_scale(width, height, 2) * size_w,
        grayscale_avg_scale(item.get('avatars_avg_gray_deviation', 0)) * color_w,
        (item.get('biz_kernel_quantile', 0) or 0) * kernel_w,
        item.get('images_vq3_v2_avg', 0) * vq_w,
    ]

    # Sequential addition keeps the exact float rounding of the original.
    boost = 0
    for term in terms:
        boost += term
    return REL_MAP.get(item.get('relevance'), 0) + boost_weight * boost


# Registry name of metric_boost_corsa_xr. It spells out the weights hard-coded
# in that function, so update both together.
_CORSA_XR_BOOST_NAME = 'corsa_xr_util_0.500_vp_0.114_aes_0.091_pagerel_0.043_size_0.081_color_0.022_kernel_0.033_vq_0.116'

# Every candidate scoring function, keyed by its report label. main() compares
# each one against every target and emits table rows in this order.
boost_types = OrderedDict([
    ('relevance', boost_relevance),
    ('3cg', boost_3cg),
    ('corsa', boost_corsa),
    ('corsax_1_viewport2', boost_corsax_viewport2),
    ('click_boost', calc_click_boost),
    ('decision tree', get_dt_score),
    ('decision_tree_rel_minus', get_dt_score_rel_minus),
    ('max_dbd_overfit', boost_max_overfit_dbd),
    ('dbd', calc_dbd_score),
    (_CORSA_XR_BOOST_NAME, metric_boost_corsa_xr),
])

# Reference orderings that main() scores every boost against (pair agreement).
targets = OrderedDict([
    ('click_boost', calc_click_boost),
    ('dbd', calc_dbd_score),
])

# Boosts whose best/worst queries main() showcases. Entries are looked up in
# boost_types by name so the two registries cannot drift apart.
best_boosts = OrderedDict(
    (name, boost_types[name])
    for name in (_CORSA_XR_BOOST_NAME, 'max_dbd_overfit', 'corsax_1_viewport2')
)


# Threshold passed to compare_sorts_by_strong_pairs for every target other
# than 'dbd'. Presumably low enough that every pair qualifies; TODO confirm
# against the tournament implementation in helpers.
_NON_DBD_WIN_SCORE_THRESHOLD = -25000000

# How many best and worst queries to list per (target, boost) in the sidebar.
_SHOW_QUERIES = 20

# A query needs at least this many compared pairs to be listed in the sidebar.
_MIN_PAIRS_FOR_SHOWCASE = 45

_STATS_ROW_TEMPLATE = """
        <tr>
            <td>{boost_name}</td>

            <td>{pairs_correct_dbd:6.0f}</td>
            <td>{pairs_total_dbd:6.0f}</td>
            <td>{pairs_dbd:0.2f}</td>

            <td>{pairs_correct_click_boost:6.0f}</td>
            <td>{pairs_total_click_boost:6.0f}</td>
            <td>{pairs_click_boost:0.2f}</td>
        </tr>
        """


def _stats_row_html(boost_name, dbd, click_boost):
    """Render one <tr> of a pairs-statistics table.

    :param boost_name: label shown in the first column.
    :param dbd: ``(pairs_correct, pairs_total, percent)`` for the 'dbd' target.
    :param click_boost: the same triple for the 'click_boost' target.
    """
    return _STATS_ROW_TEMPLATE.format(
        boost_name=boost_name,
        pairs_correct_dbd=dbd[0],
        pairs_total_dbd=dbd[1],
        pairs_dbd=dbd[2],
        pairs_correct_click_boost=click_boost[0],
        pairs_total_click_boost=click_boost[1],
        pairs_click_boost=click_boost[2],
    )


def _record_query_metrics(queries, toloka_agg, boosts):
    """Play each query's tournament and store pair-agreement stats in ``boosts``.

    For every (boost, target) pair this writes per-query ``pairs_correct``,
    ``pairs_total`` and ``pairs_predicted`` under the ``(query, region_id)``
    key. ``pairs_predicted`` is a percentage, or -1 when the query has fewer
    than 10 urls. It also accumulates the ``'sum'`` entries of
    ``pairs_correct`` and ``pairs_total``.
    """
    for idx, group in enumerate(queries):
        query = group[0]['query'].decode('utf-8')
        region_id = group[0]['region_id']
        key = (query, region_id)

        t = play_tournament_games(group, toloka_agg)

        with_click_boost = float(len(t.sort_urls(t.urls, calc_click_boost)))
        print('{:<4}) Query: {:<40} region {:<5} items {}, with click_boost: {:2.0f} - {:2.0f}%'.format(
            idx, query, region_id, len(group),
            with_click_boost, with_click_boost / len(group) * 100))

        # Pre-compute (presumably cache) every scoring function's sort.
        for target_fn in targets.values():
            t.cache_sort(target_fn)
        for boost_fn in boost_types.values():
            t.cache_sort(boost_fn)

        for boost_name, boost_fn in boost_types.items():
            for target, target_fn in targets.items():
                threshold = WIN_SCORE_THRESHHOLD if target == 'dbd' else _NON_DBD_WIN_SCORE_THRESHOLD
                result = t.compare_sorts_by_strong_pairs(target_fn, boost_fn, threshold)
                correct, total = result[0], result[1]

                stats = boosts[boost_name][target]
                stats['pairs_correct'][key] = correct
                stats['pairs_total'][key] = total
                stats['pairs_predicted'][key] = (100 * correct / float(total or 1)) if len(t.urls) >= 10 else -1
                stats['pairs_correct']['sum'] += correct
                stats['pairs_total']['sum'] += total


def _showcase_queries_html(boosts):
    """Build the sidebar listing the best and worst queries per target and boost.

    Only boosts in ``best_boosts`` are listed, and only queries with at least
    ``_MIN_PAIRS_FOR_SHOWCASE`` compared pairs.

    :return: ``(html, selected)``, where ``selected`` is the set of
        ``(query, region_id)`` keys that appear in the list.
    """
    # NOTE(review): query text is interpolated into HTML without escaping.
    item_tpl = '<li class="queries__item" data-query="{q}" data-region="{r}">{q} {p:0.2f}%</li>'
    selected = set()
    parts = ['<ul class="queries">\n']
    for target in reversed(list(targets.keys())):
        parts.append('<span class="category">=== Predict {cat} ===</span>'.format(cat=target))
        for boost_name in best_boosts:
            print('Top queries by {} predicting {}'.format(boost_name, target))
            parts.append('<span class="category">{cat}</span>'.format(cat=boost_name))

            stats = boosts[boost_name][target]
            items = [(k, v) for k, v in stats['pairs_predicted'].items()
                     if stats['pairs_total'][k] >= _MIN_PAIRS_FOR_SHOWCASE]
            ranked = sorted(items, key=lambda kv: kv[1], reverse=True)

            for label, chunk in (('top', ranked[:_SHOW_QUERIES]), ('bottom', ranked[-_SHOW_QUERIES:])):
                parts.append('<span class="category category_small">{}</span>'.format(label))
                for (q, r), pct in chunk:
                    parts.append(item_tpl.format(q=q, r=r, p=pct))
                    selected.add((q, r))
    parts.append('</ul>\n\n')
    return ''.join(parts), selected


def _query_results_html(group, toloka_agg, boosts, query, region_id):
    """Render the stats table, sorters and gallery for one showcased query."""
    key = (query, region_id)
    t = play_tournament_games(group, toloka_agg)

    parts = [
        '<div class="results__item" data-query="{q}" data-region="{region}">\n'.format(q=query, region=region_id),
        '<h3>Query: "{q}", Region: {region}</h3>\n'.format(q=query, region=region_id),
        get_html_table_header(),
    ]
    for boost_name in boost_types:
        per_target = []
        for target in ('dbd', 'click_boost'):
            stats = boosts[boost_name][target]
            per_target.append((stats['pairs_correct'][key],
                               stats['pairs_total'][key],
                               stats['pairs_predicted'][key]))
        parts.append(_stats_row_html(boost_name, *per_target))
    parts.extend([
        '</tbody></table>',
        get_sorters_html(boost_types),
        '<div class="gallery">\n',
        t.print_gallery(sorts=boost_types),
        '</div>\n\n',
        '</div>\n\n',
    ])
    return ''.join(parts)


def _total_stats_html(boosts):
    """Render the always-visible table of pair stats summed over all queries."""
    parts = [
        '<div class="results__item results__item_visible" data-query="{q}">\n'.format(q='Total stats table'),
        '<h3>Total stats table</h3>',
        get_html_table_header(),
    ]
    for boost_name in boost_types:
        per_target = []
        for target in ('dbd', 'click_boost'):
            correct = boosts[boost_name][target]['pairs_correct']['sum']
            total = boosts[boost_name][target]['pairs_total']['sum']
            per_target.append((correct, total, float(correct) / (total or 1) * 100))
        parts.append(_stats_row_html(boost_name, *per_target))
    parts.append('</tbody></table>')
    parts.append('</div>')
    return ''.join(parts)


def main(*args):
    """Evaluate every boost against every target and write an HTML report.

    ``args`` must be the 6-tuple ``(queries_grouped, toloka_data,
    sampled_queries, token, any_param, html_file)``. Only ``queries_grouped``,
    ``toloka_data`` and ``html_file`` are used. ``html_file`` receives the
    contents of ``static_header.html``, the showcased-query list, and the
    per-query and total result tables.

    :return: always an empty list.
    """
    global QUERIES_LIMIT
    queries_grouped, toloka_data, sampled_queries, token, any_param, html_file = args

    DEBUG = os.uname()[0] == 'Darwin'
    SEED = 30
    random.seed(SEED)
    np.random.seed(SEED)
    np.set_printoptions(linewidth=240)
    np.set_printoptions(threshold=np.inf)

    if DEBUG:
        # macOS hosts are treated as local debug runs: process only a few queries.
        QUERIES_LIMIT = 20

    with open('static_header.html') as f:
        html_header = f.read()

    toloka_agg, _max_rounds = aggregate_toloka_data(toloka_data)
    boosts = get_boosts_metrics_structure(boost_types, targets)
    queries = queries_grouped[:QUERIES_LIMIT]

    # 1. Separate tournament per query -> pair-agreement stats per (boost, target).
    _record_query_metrics(queries, toloka_agg, boosts)

    # 2. Pick the best and worst queries by metric to showcase.
    html_queries, showcased = _showcase_queries_html(boosts)

    # 3. Detailed results for showcased queries, then the totals table.
    parts = ['<div class="results">\n']
    for group in queries:
        query = group[0]['query'].decode('utf-8')
        region_id = group[0]['region_id']
        if (query, region_id) in showcased:
            parts.append(_query_results_html(group, toloka_agg, boosts, query, region_id))
    parts.append(_total_stats_html(boosts))
    parts.append('</div>\n')  # close <div class="results">
    html = ''.join(parts)

    html_file.write(html_header)
    html_file.write(html_queries)
    html_file.write(html)

    return []
