#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import logging
import os
import six

import pandas as pd

from crypta.lib.python.custom_ml.tools.metrics import pandas_to_startrek
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lookalike.lib.python.utils import (
    fields,
    metrics as lal_metrics,
    utils,
)
from crypta.lookalike.lib.python.utils.config import config
from crypta.lib.python.prism_quality import check_quality


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# YQL template that overwrites `{relations_table}` (INSERT ... WITH TRUNCATE)
# with one row per GroupID containing:
#   * GroupID       - taken from the new model's PR-stats table;
#   * aupr_relation - new model `aupr` divided by base model `aupr`;
#   * segment_type  - taken from the segments-with-counts table.
# Both joins are INNER, so only GroupIDs present in all three input tables
# end up in the result. Rows are ordered by segment_type.
# Placeholders to fill via str.format: relations_table, new_pr_stats_table,
# base_pr_stats_table, val_segments_with_counts_table.
# NOTE(review): the query has no guard for a zero base `aupr` - confirm that
# the base table can never contain zeros.
get_aupr_relations_to_base_query = """
INSERT INTO `{relations_table}`
WITH TRUNCATE
SELECT
    new_model.GroupID AS GroupID,
    new_model.aupr / base_model.aupr AS aupr_relation,
    counts.segment_type AS segment_type,
FROM `{new_pr_stats_table}` AS new_model
INNER JOIN `{base_pr_stats_table}` AS base_model
ON base_model.GroupID == new_model.GroupID
INNER JOIN `{val_segments_with_counts_table}` AS counts
ON counts.GroupID == new_model.GroupID
ORDER BY segment_type;
"""


def calculate_ci(nv_params):
    """Build a per-segment-type confidence interval report for the AUPR relation.

    Runs the AUPR-relation YQL query into a temporary YT table, loads the
    result into pandas and, for every segment type from
    ``config.SEGMENT_TYPES``, passes ``aupr_relation - 1`` of that segment
    type to ``lal_metrics.get_ci_with_bootstrap``. The resulting table
    (one row per segment type, columns ``estimation`` / ``ci_left`` /
    ``ci_right``) is rendered with ``pandas_to_startrek`` and written to the
    YT file ``<working-dir>/relation_median_aupr_ci``.

    Args:
        nv_params: Nirvana parameters mapping; it is passed to the YT/YQL
            client factories and must contain the ``'working-dir'`` key.

    Raises:
        KeyError: if ``nv_params`` has no ``'working-dir'`` key (raised only
            after the query and the bootstrap have already run).
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client) as transaction, \
         yt_client.TempTable() as relations_table:
        relations_query = get_aupr_relations_to_base_query.format(
            new_pr_stats_table=config.TEST_DSSM_PR_STATS,
            base_pr_stats_table=config.TEST_BASELINE_PR_STATS,
            relations_table=relations_table,
            val_segments_with_counts_table=config.TEST_SEGMENTS_WITH_COUNTS_TABLE,
        )
        # The query is executed inside the Nirvana transaction.
        yql_client.execute(
            query=relations_query,
            transaction=str(transaction.transaction_id),
            title='YQL Get aupr relations to base model',
        )

        relations_df = pd.DataFrame(yt_client.read_table(relations_table))

        ci_by_segment_type = {}
        for segment_type in config.SEGMENT_TYPES:
            is_current_type = relations_df[fields.segment_type] == segment_type
            # Shift the ratio by one so that 0 means "equal to the base model".
            relative_change = relations_df[is_current_type][fields.aupr_relation] - 1
            estimate, ci_left, ci_right = lal_metrics.get_ci_with_bootstrap(relative_change)
            ci_by_segment_type[segment_type] = {
                'estimation': estimate,
                'ci_left': ci_left,
                'ci_right': ci_right,
            }

        # Segment types become rows after the transpose.
        ci_table = pd.DataFrame.from_records(ci_by_segment_type).transpose()
        report = pandas_to_startrek(ci_table, add_index=True)

        report_path = os.path.join(nv_params['working-dir'], 'relation_median_aupr_ci')
        yt_client.write_file(report_path, six.ensure_binary(report), force_create=True)


def coloring_formatter(col, value, colored_cols, threshold=1):
    """Format a single table cell, coloring significant values of selected columns.

    Args:
        col: Name of the column the cell belongs to.
        value: Cell value. Numbers are rounded; anything that cannot be
            rounded (e.g. a string) is returned as its plain string form.
        colored_cols: Container of column names whose values may be colored.
        threshold: Minimal absolute value at which a cell of a colored
            column is highlighted.

    Returns:
        str: ``'!!(green)<v>%!!'`` or ``'!!(red)<v>%!!'`` (value rounded to
        2 digits) when ``col`` is in ``colored_cols`` and
        ``abs(value) >= threshold``; otherwise the value rounded to 3 digits.
        NaN is never colored, because it is neither a gain nor a loss.
    """
    try:
        magnitude = abs(value)
        plain = '{}'.format(round(value, 3))
    except TypeError:
        # Non-numeric cell (for instance a string key column): there is
        # nothing to round or to color, so pass it through unchanged.
        return '{}'.format(value)

    # `not (magnitude >= threshold)` instead of `magnitude < threshold`:
    # every comparison with NaN is False, so this form keeps NaN uncolored
    # rather than reporting it as a red "regression".
    if col not in colored_cols or not magnitude >= threshold:
        return plain

    color = 'green' if value > 0 else 'red'
    # NOTE(review): the '%' suffix is appended for every colored column,
    # including non-percentage ones - confirm that this is intended.
    return '!!({}){}%!!'.format(color, round(value, 2))


def calculate_prism_scores(nv_params):
    """Compare prism quality metrics of the prod model and the new model.

    For each of the ``prod`` and ``new_model`` directories under
    ``<working-dir>/prism`` the maximal entry name of its ``clusters``
    subdirectory is selected and handed to ``check_quality`` together with
    the model directories. The two resulting metric tables are merged
    one-to-one on ``metric_type`` / ``os`` / ``fielddate``; the absolute and
    the percentage differences against prod are added, and the table is
    rendered with ``pandas_to_startrek`` and written to the YT file
    ``<working-dir>/prism/prism_metrics_comparison``.

    Args:
        nv_params: Nirvana parameters mapping; must provide ``'working-dir'``.

    Raises:
        AssertionError: if ``'working-dir'`` is missing from ``nv_params``.
    """
    working_dir = nv_params.get('working-dir')
    assert working_dir is not None, 'This function should only be used for experiments'

    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    prism_dir = os.path.join(working_dir, 'prism')
    model_names = ['prod', 'new_model']
    model_dirs = [os.path.join(prism_dir, name) for name in model_names]

    dates = []
    prism_weights_tables = []
    for model_dir in model_dirs:
        clusters_dir = os.path.join(model_dir, 'clusters')
        # Entry with the maximal name; presumably the names are dates, so
        # this is the most recent table - TODO confirm.
        latest_date = max(yt_client.list(clusters_dir))
        dates.append(latest_date)
        prism_weights_tables.append(os.path.join(clusters_dir, latest_date))

    metrics = check_quality(yt_client, yql_client, prism_weights_tables, dates, model_dirs)

    # Assumes check_quality returns one record set per model, in the same
    # order as model_names - TODO confirm.
    per_model_frames = [pd.DataFrame.from_records(records) for records in metrics]
    column_suffixes = ['_' + name for name in model_names]
    merged_df = pd.merge(
        *per_model_frames,
        on=['metric_type', 'os', 'fielddate'],
        suffixes=column_suffixes,
        validate='one_to_one'
    )

    prod_metric = merged_df['metric_prod']
    merged_df['abs_diff'] = merged_df['metric_new_model'] - prod_metric
    merged_df['diff, %'] = (merged_df['abs_diff']) * 100 / prod_metric

    cell_formatter = partial(coloring_formatter, colored_cols={'diff, %', 'abs_diff'})
    report = pandas_to_startrek(
        df=merged_df,
        formatter=cell_formatter,
    )

    report_path = os.path.join(prism_dir, 'prism_metrics_comparison')
    yt_client.write_file(report_path, six.ensure_binary(report), force_create=True)
