#!/usr/bin/env python
# coding: utf-8

import os
import math
import random
import numpy as np
from collections import defaultdict, OrderedDict

from helpers import play_tournament_games, get_boost_metrics_structure
from helpers import get_sorters_html, get_html_table_header, get_html_queries
from helpers import grayscale_scale, grayscale_avg_scale, square_scale, desktop_viewport_factor, desktop_viewport_factor2, visual_quality_scale

import sys
# Python 2 only: reload(sys) is done so that sys.setdefaultencoding can be
# called, which makes UTF-8 the codec used for implicit str <-> unicode
# conversions in the rest of this module.
reload(sys)
sys.setdefaultencoding("utf-8")

# Maximum number of query groups processed by main(); main() lowers it to 20
# when it detects a DEBUG environment.
QUERIES_LIMIT = 5000
# Threshold passed to the tournament's strong-pairs comparisons when the target
# is 'dbd' (see calc_query_metrics); it is also printed in the report tables.
WIN_SCORE_THRESHHOLD = 25000


# Base score for each relevance label. The scoring functions below look labels
# up with a default of 0, so a missing or unknown label scores like 'IRRELEVANT'.
REL_MAP = {
    'RELEVANT_PLUS': 1.0,
    'RELEVANT_MINUS': 0.5,
    'IRRELEVANT': 0,
}


def square_tanh_scale(iw, ih, growth):
    """Squash an image area into [0, 1) with a hyperbolic tangent.

    The area ``iw * ih`` is expressed in millions of pixels, multiplied by
    ``growth`` and passed through ``math.tanh``.

    Args:
        iw: image width; a falsy value (``None`` or 0) yields 0.0.
        ih: image height; a falsy value (``None`` or 0) yields 0.0.
        growth: multiplier applied to the area before the tanh.

    Returns:
        ``tanh(growth * iw * ih / 1e6)``, or 0.0 when either side is missing.
    """
    if iw and ih:
        area = iw * ih
        return math.tanh(growth * area / 1000000.0)
    return 0.0


def get_dt_score(url):
    """Return the relevance base score plus one tenth of the item's ``dt_score``.

    ``url.data`` is read as a dict; an unknown or missing relevance label and
    a missing ``dt_score`` both count as 0.
    """
    fields = url.data
    base = REL_MAP.get(fields.get('relevance'), 0)
    return base + 0.1 * fields.get('dt_score', 0)


def get_dt_score_dwell_01(url):
    """Return the relevance base score plus one tenth of ``dt_score_dwell_01``.

    ``url.data`` is read as a dict; an unknown or missing relevance label and
    a missing ``dt_score_dwell_01`` both count as 0.
    """
    fields = url.data
    base = REL_MAP.get(fields.get('relevance'), 0)
    return base + 0.1 * fields.get('dt_score_dwell_01', 0)


def get_dt_score_rel_minus(url):
    """Return the relevance base score plus one tenth of ``dt_score_rel_minus``.

    ``url.data`` is read as a dict; an unknown or missing relevance label and
    a missing ``dt_score_rel_minus`` both count as 0.
    """
    fields = url.data
    base = REL_MAP.get(fields.get('relevance'), 0)
    return base + 0.1 * fields.get('dt_score_rel_minus', 0)


def calc_dbd_score(url):
    """Return the item's stored ``dbd_score``, or 0 when it has none."""
    fields = url.data
    return fields.get('dbd_score', 0)


def calc_click_boost(url):
    """Return the item's stored ``dwelltime_boost``, or 0 when it has none."""
    fields = url.data
    return fields.get('dwelltime_boost', 0)


def boost_relevance(url):
    """Return only the relevance base score of the item (0 for unknown labels)."""
    label = url.data.get('relevance')
    return REL_MAP.get(label, 0)


def boost_3cg(url):
    """Score an item as relevance plus three weighted image-quality terms.

    Items whose relevance base score is 0 get 0 outright. Otherwise the
    grayscale term is weighted 1/6 and the size and visual-quality terms
    1/12 each; the terms themselves come from the imported helper scales.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    gray = grayscale_scale(data.get('avatars_max_gray_deviation'))
    size = square_scale(data.get('image_width'), data.get('image_height'))
    quality = visual_quality_scale(data.get('images_visual_quality_with_queue'))

    # Terms are added left to right in the same order as before so that the
    # floating-point result is unchanged.
    return base + gray / 6.0 + size / 12.0 + quality / 12.0


def boost_corsa(url):
    """Score an item as relevance plus the weighted "corsa" quality terms.

    Items whose relevance base score is 0 get 0 outright. Otherwise the
    grayscale (0.1), size (0.1), desktop viewport (0.025) and visual quality
    (0.05) helper values are added, plus a flat 0.03 when the page itself is
    marked ``'RELEVANT'``.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')

    # Left-to-right accumulation, same order as the weights listed above.
    score = (
        base
        + grayscale_scale(data.get('avatars_max_gray_deviation')) * 0.1
        + square_scale(width, height) * 0.1
        + desktop_viewport_factor(width, height) * 0.025
        + visual_quality_scale(data.get('images_visual_quality_with_queue')) * 0.05
    )
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        score += 0.03

    return score


def boost_corsax_viewport2(url, rel_minus_weight=0.75, boost_weight=0.3):
    """Score an item as relevance plus ``boost_weight`` times a quality boost.

    Args:
        url: object whose ``data`` dict holds the item's features.
        rel_minus_weight: accepted for interface compatibility; this function
            does not read it.
        boost_weight: multiplier applied to the accumulated boost.

    Returns:
        0 when the relevance base score is 0, otherwise
        ``base + boost * boost_weight``.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')

    extra = 0
    extra += grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * 0.05
    extra += visual_quality_scale(data.get('images_visual_quality_with_queue')) * 0.1
    extra += desktop_viewport_factor2(width, height) * 0.1
    extra += square_scale(width, height) * 0.2

    # A missing kernel value falls back to the neutral 0.5.
    kernel = data.get('kernel_june')
    if kernel is None:
        kernel = 0.5
    extra += kernel * 0.10

    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        extra += 0.10

    # An unmarked utility falls back to 0.4; a marked one is parsed as float.
    utility = data.get('toloka_images_utility', 'NO_MARK')
    utility_value = 0.4 if utility == 'NO_MARK' else float(utility)
    extra += utility_value * 0.35

    return base + extra * boost_weight


def boost_max_overfit_dbd(url):
    """Score an item as relevance plus 0.49999 times a fixed-weight boost.

    Items whose relevance base score is 0 get 0 outright. The boost combines
    utility, viewport, aesthetics, page relevance, image size, colour,
    business-kernel quantile and visual quality with hard-coded weights.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')
    page_is_relevant = data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT'

    # The accumulation order is kept as-is so the float result is unchanged.
    extra = 0
    extra += data.get('utility_v2_avg', 0) * 0.67626
    extra += desktop_viewport_factor2(width, height) * 0.00200
    extra += data.get('aestetics_mean5', 0) * 0.05514
    if page_is_relevant:
        extra += 0.03428
    extra += square_tanh_scale(width, height, 2) * 0.06514
    extra += grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * 0.00068
    # The quantile may be stored as None, hence the ``or 0``.
    extra += (data.get('biz_kernel_quantile', 0) or 0) * 0.01678
    extra += data.get('images_vq3_v2_avg', 0) * 0.14971

    return base + extra * 0.49999


def custom_boost(url, w):
    """Score an item as relevance plus 0.1 times a caller-weighted boost.

    Args:
        url: object whose ``data`` dict holds the item's features.
        w: mapping with the weights ``'util'``, ``'vp'``, ``'aes'``,
            ``'page_rel'``, ``'size'``, ``'color'``, ``'kernel'`` and ``'vq'``.

    Returns:
        0 when the relevance base score is 0, otherwise ``base + 0.1 * boost``.
    """
    data = url.data
    base = REL_MAP.get(data.get('relevance'), 0)
    if base == 0:
        return 0

    width = data.get('image_width')
    height = data.get('image_height')

    # The accumulation order is kept as-is so the float result is unchanged.
    extra = 0
    extra += data.get('utility_v2_avg', 0) * w['util']
    extra += desktop_viewport_factor2(width, height) * w['vp']
    extra += data.get('aestetics_mean5', 0) * w['aes']
    if data.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        extra += w['page_rel']
    extra += square_tanh_scale(width, height, 2) * w['size']
    extra += grayscale_avg_scale(data.get('avatars_avg_gray_deviation', 0)) * w['color']
    # The quantile may be stored as None, hence the ``or 0``.
    extra += (data.get('biz_kernel_quantile', 0) or 0) * w['kernel']
    extra += data.get('images_vq3_v2_avg', 0) * w['vq']

    return base + 0.1 * extra


def metric_boost_corsa_xr(url, boost_weight=0.3):
    """Score an item as relevance plus ``boost_weight`` times the "corsa xr" boost.

    Args:
        url: either an object with a ``data`` dict or the feature dict itself.
        boost_weight: multiplier applied to the accumulated boost.

    Returns:
        0 for items labelled ``'IRRELEVANT'``; otherwise the relevance base
        score (0 for unknown labels) plus ``boost_weight * boost``.
    """
    if hasattr(url, 'data'):
        item = url.data
    else:
        item = url

    if item.get('relevance') == 'IRRELEVANT':
        return 0

    weights = {
        'util': 0.5,
        'vp': 0.11421,
        'aes': 0.09114,
        'page_rel': 0.04307,
        'size': 0.08142,
        'color': 0.02186,
        'kernel': 0.03267,
        'vq': 0.11563,
    }

    width = item.get('image_width')
    height = item.get('image_height')

    # The accumulation order is kept as-is so the float result is unchanged.
    extra = 0
    extra += item.get('utility_v2_avg', 0) * weights['util']
    extra += desktop_viewport_factor2(width, height) * weights['vp']
    extra += item.get('aestetics_mean5', 0) * weights['aes']
    if item.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        extra += weights['page_rel']
    extra += square_tanh_scale(width, height, 2) * weights['size']
    extra += grayscale_avg_scale(item.get('avatars_avg_gray_deviation', 0)) * weights['color']
    # The quantile may be stored as None, hence the ``or 0``.
    extra += (item.get('biz_kernel_quantile', 0) or 0) * weights['kernel']
    extra += item.get('images_vq3_v2_avg', 0) * weights['vq']

    base = REL_MAP.get(item.get('relevance'), 0)
    return base + boost_weight * extra


# Candidate orderings evaluated by calc_query_metrics, keyed by the name shown
# in the report; insertion order is the row order of the tables.
# Note: 'dt_score_dwell_01' and 'dt_score_dwell_01_plain' map to the same
# function. calc_query_metrics special-cases the name 'dt_score_dwell_01' and
# evaluates it through compare_dt_score_by_strong_pairs, while the '_plain'
# entry goes through the generic compare_sorts_by_strong_pairs path.
boost_types = OrderedDict([
    ('relevance', boost_relevance),
    ('3cg', boost_3cg),
    ('corsa', boost_corsa),
    ('corsax_1_viewport2', boost_corsax_viewport2),
    ('click_boost', calc_click_boost),
    ('dt_score_dwell_01', get_dt_score_dwell_01),
    ('dt_score_dwell_01_plain', get_dt_score_dwell_01),
    ('max_dbd_overfit', boost_max_overfit_dbd),
    ('dbd', calc_dbd_score),
    ('corsa_xr_util_0.500_vp_0.114_aes_0.091_pagerel_0.043_size_0.081_color_0.022_kernel_0.033_vq_0.116', metric_boost_corsa_xr),
])

# Reference orderings every boost is compared against. The report code reads
# the 'dbd' and 'click_boost' keys by name, so both must stay present.
targets = OrderedDict([
    ('click_boost', calc_click_boost),
    ('dbd', calc_dbd_score),
])


def calc_query_metrics(i, group, sampled_queries_list):
    """Run one query's tournament and compare every boost with every target.

    Args:
        i: index of the query group; used only in the progress line.
        group: non-empty sequence of items belonging to one query/region.
            ``group[0]`` supplies the UTF-8 encoded ``'query'`` and the
            ``'region_id'``.
        sampled_queries_list: queries whose HTML fragment should be returned.

    Returns:
        A tuple ``(query, region_id, query_metrics, html)``. ``query`` is the
        decoded query text, ``query_metrics`` is the structure built by
        ``get_boost_metrics_structure`` filled with the pair statistics, and
        ``html`` is the rendered fragment, or ``''`` when the query is not in
        ``sampled_queries_list``.
    """
    # Local import keeps this block self-contained; cgi.escape is the
    # Python 2 standard-library HTML escaper.
    from cgi import escape

    first = group[0]
    # Decode exactly once. Calling .decode() again on the decoded text only
    # worked because of the process-wide default-encoding override.
    query = first['query'].decode('utf-8')
    region_id = first['region_id']
    # The query is free text: escape it before placing it into an attribute
    # or element body so that quotes and angle brackets cannot break the page.
    query_markup = escape(query, quote=True)

    t = play_tournament_games(group)

    # Single-argument parenthesised print behaves the same as the statement form.
    print('{}) Query: {:<40} region {:<5} items {}'.format(i, query, region_id, len(group)))

    html = ''
    html += '<div class="results__item" data-query="{q}" data-region="{region}">\n'.format(
        q=query_markup,
        region=region_id,
    )
    html += '<h3>Query: "{q}", Region: {region}</h3>\n'.format(
        q=query_markup,
        region=region_id,
    )

    html += get_html_table_header()

    # Warm the tournament's sort cache for every ordering that is used below.
    for target_fn in targets.values():
        t.cache_sort(target_fn)

    for boost_fn in boost_types.values():
        t.cache_sort(boost_fn)

    query_metrics = get_boost_metrics_structure(boost_types, targets)
    for boost_name, boost_fn in boost_types.items():
        for target, target_fn in targets.items():
            # Only the 'dbd' target uses the real threshold; other targets get
            # a very large negative value instead.
            threshold = WIN_SCORE_THRESHHOLD if target == 'dbd' else -25000000
            if boost_name == 'dt_score_dwell_01':
                pairs_correct, pairs_count, pairs_missed, pairs_unsure = t.compare_dt_score_by_strong_pairs(
                    target_fn, 'dt_score_dwell_01', threshold)
            else:
                pairs_correct, pairs_count, pairs_missed, pairs_unsure = t.compare_sorts_by_strong_pairs(
                    target_fn, boost_fn, threshold)

            stats = query_metrics[boost_name][target]
            stats['pairs_correct'] = pairs_correct
            stats['pairs_total'] = pairs_count
            stats['pairs_missed'] = pairs_missed
            stats['pairs_unsure'] = pairs_unsure
            # ``or 1`` avoids a division by zero when no pairs were compared.
            stats['pairs_predicted'] = 100 * pairs_correct / float(pairs_count or 1)

        dbd_stats = query_metrics[boost_name]['dbd']
        click_stats = query_metrics[boost_name]['click_boost']
        html += """
        <tr>
            <td>{boost_name}</td>
            <td>{pairs_correct_dbd:6.0f} / {pairs_total_dbd:6.0f}</td>
            <td>M: {pairs_missed:6.0f}, U: {pairs_unsure:6.0f} ({WIN_POINTS_THRESHHOLD})</td>
            <td>{pairs_dbd:0.2f}</td>
            <td>{pairs_correct_click_boost:6.0f}</td>
            <td>{pairs_total_click_boost:6.0f}</td>
            <td>{pairs_click_boost:0.2f}</td>
        </tr>
        """.format(
            boost_name=boost_name,

            pairs_correct_dbd=dbd_stats['pairs_correct'],
            pairs_total_dbd=dbd_stats['pairs_total'],
            pairs_dbd=dbd_stats['pairs_predicted'],
            pairs_missed=dbd_stats['pairs_missed'],
            pairs_unsure=dbd_stats['pairs_unsure'],
            WIN_POINTS_THRESHHOLD=WIN_SCORE_THRESHHOLD,

            pairs_correct_click_boost=click_stats['pairs_correct'],
            pairs_total_click_boost=click_stats['pairs_total'],
            pairs_click_boost=click_stats['pairs_predicted'],
        )

    html += '</tbody></table>'

    html += get_sorters_html(boost_types)

    html += '<div class="gallery">\n'
    html += t.print_gallery(sorts=boost_types, saliency=False)
    html += '</div>\n\n'
    html += '</div>\n\n'

    # Metrics are returned for every query; markup only for the sampled ones.
    return query, region_id, query_metrics, (html if query in sampled_queries_list else '')


def main(*args):
    """Build the HTML report and return the boosts ranked by 'dbd' accuracy.

    Args:
        *args: exactly six positional values, in order: ``queries_grouped``,
            ``toloka_data``, ``sampled_queries``, ``token``, ``any_param`` and
            ``html_file``. Only ``queries_grouped`` (a sliceable sequence of
            query groups) and ``html_file`` (an object with ``write``) are
            used here.

    Returns:
        A list of ``(boost_name, percent_of_correct_dbd_pairs)`` tuples sorted
        by the percentage in descending order.

    Side effects:
        Seeds ``random`` and ``numpy.random``, changes numpy print options,
        may lower the module-level ``QUERIES_LIMIT`` to 20, reads
        ``static_header.html`` from the working directory and writes the
        report to ``html_file``.
    """
    global QUERIES_LIMIT

    queries_grouped, toloka_data, sampled_queries, token, any_param, html_file = args

    # A macOS host is treated as a developer machine.
    DEBUG = os.uname()[0] == 'Darwin'
    SEED = 30
    random.seed(SEED)
    np.random.seed(SEED)
    np.set_printoptions(linewidth=240, threshold=np.inf)

    if DEBUG:
        QUERIES_LIMIT = 20

    with open('static_header.html') as header_file:
        html_header = header_file.read()

    divider = 1 if DEBUG else 30
    html_queries, sampled_queries_list = get_html_queries(
        [], queries_grouped, queries_divider=divider, add_random=True)

    html = ''
    html += html_queries

    # =============== separate tournament for every query ===============
    html += '<div class="results">\n'

    results = [
        calc_query_metrics(idx, group, sampled_queries_list=sampled_queries_list)
        for idx, group in enumerate(queries_grouped[:QUERIES_LIMIT])
    ]

    metrics = defaultdict(dict)
    for query, region_id, query_metrics, query_html in results:
        metrics[(query, region_id)] = query_metrics
        html += query_html

    html += '<div class="results__item results__item_visible" data-query="{q}">\n'.format(
        q='Total stats table',
    )

    html += '<h3>Total stats table</h3>'
    html += get_html_table_header()

    row_template = """
        <tr>
            <td>{boost_name}</td>

            <td>{pairs_correct_dbd:6.0f} / {pairs_total_dbd:6.0f}</td>
            <td>M: {pairs_missed:6.0f}, U: {pairs_unsure:6.0f} ({WIN_POINTS_THRESHHOLD})</td>
            <td>{pairs_dbd:0.2f}</td>

            <td>{pairs_correct_click_boost:6.0f}</td>
            <td>{pairs_total_click_boost:6.0f}</td>
            <td>{pairs_click_boost:0.2f}</td>
        </tr>
        """

    ranking = []
    for boost_name in boost_types.keys():
        dbd_correct = 0
        dbd_total = 0
        click_correct = 0
        click_total = 0
        missed = 0
        unsure = 0

        # Sum the per-query statistics over all processed queries.
        for per_query in metrics.values():
            dbd_stats = per_query[boost_name]['dbd']
            click_stats = per_query[boost_name]['click_boost']
            dbd_correct += dbd_stats['pairs_correct']
            dbd_total += dbd_stats['pairs_total']
            missed += dbd_stats['pairs_missed']
            unsure += dbd_stats['pairs_unsure']
            click_correct += click_stats['pairs_correct']
            click_total += click_stats['pairs_total']

        # ``or 1`` avoids a division by zero when no pairs were compared.
        dbd_percent = 100 * dbd_correct / float(dbd_total or 1)
        click_percent = 100 * click_correct / float(click_total or 1)
        ranking.append((boost_name, dbd_percent))

        html += row_template.format(
            boost_name=boost_name,

            pairs_correct_dbd=dbd_correct,
            pairs_total_dbd=dbd_total,
            pairs_dbd=dbd_percent,
            pairs_missed=missed,
            pairs_unsure=unsure,
            WIN_POINTS_THRESHHOLD=WIN_SCORE_THRESHHOLD,

            pairs_correct_click_boost=click_correct,
            pairs_total_click_boost=click_total,
            pairs_click_boost=click_percent,
        )

    html += '</tbody></table>'
    html += '</div>'

    html_file.write(html_header)
    html_file.write(html)

    # Best 'dbd' accuracy first; the sort is stable, so ties keep table order.
    ranking.sort(key=lambda pair: pair[1], reverse=True)
    return ranking
