#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
import logging
import os
import time

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    utils,
)
from crypta.lookalike.lib.python.utils.config import (
    config,
    environment,
)
from crypta.lookalike.proto import yt_node_names_pb2


logger = logging.getLogger(__name__)


# YQL template used by `calculate` to compute LaL quality metrics.
#
# Two placeholder syntaxes are mixed on purpose:
#   * `{% ... %}` and `{{ ... }}` are Jinja constructs (rendered first);
#   * single-brace `{name}` fields are str.format placeholders (filled afterwards).
#
# For every (type, flatten) pair in the Jinja variable `metrics_types` the query computes,
# grouped by `type`:
#   * relation_median_aupr  - MEDIAN(dssm aupr) divided by MEDIAN(random aupr);
#   * median_aupr_relation  - median of the per-GroupID dssm/random aupr ratio;
#   * pct_10_aupr_relation  - 10th percentile of the same ratio;
# and inserts one row of Yson dicts into `{current_metrics_table}`. The insert uses
# WITH TRUNCATE only when `flatten` is false, so iterations with `flatten` set append.
#
# The trailing section is emitted only when `calculate_lower_bounds` is truthy in the Jinja
# render context. It rewrites `{lower_bounds_table}` with AVG - 3 * STDDEV of the 'audience'
# values taken from the latest `{number_of_days}` rows of `{metrics_table}`.
calculate_lal_metrics_query_template = """
$val_segments_with_counts = '{val_segments_with_counts_table}';

$val_segments_with_counts_flattened = (
    SELECT *
    FROM $val_segments_with_counts
    FLATTEN LIST BY ad_types
);

$relation = (
    SELECT
        dssm.GroupID AS GroupID,
        dssm.aupr / `random`.aupr AS aupr_relation,
    FROM `{dssm_pr_stats_table}` AS dssm
    INNER JOIN `{random_pr_stats_table}` AS `random`
    USING(GroupID)
);
{% for type, flatten in metrics_types %}

$aupr_medians_{{type}} = (
    SELECT
        dssm.{{type}} AS {{type}},
        dssm_median / random_median AS relation_median_aupr,
    FROM (
            SELECT
                {{type}},
                MEDIAN(pr_stats.aupr) AS dssm_median,
            FROM `{dssm_pr_stats_table}` AS pr_stats
            INNER JOIN $val_segments_with_counts{{'_flattened' if flatten else ''}} AS counts
            USING(GroupID)
            GROUP BY counts.{{type}} AS {{type}}
    ) AS dssm
    INNER JOIN (
            SELECT
                {{type}},
                MEDIAN(pr_stats.aupr) AS random_median,
            FROM `{random_pr_stats_table}` AS pr_stats
            INNER JOIN $val_segments_with_counts{{'_flattened' if flatten else ''}} AS counts
            USING(GroupID)
            GROUP BY counts.{{type}} AS {{type}}
    ) AS `random`
    USING({{type}})
);

$metrics_{{type}} = (
    SELECT
        {{type}},
        MEDIAN(relation.aupr_relation) AS median_aupr_relation,
        PERCENTILE(relation.aupr_relation, 0.1) AS pct_10_aupr_relation,
    FROM $relation AS relation
    INNER JOIN $val_segments_with_counts{{'_flattened' if flatten else ''}} AS counts
    USING(GroupID)
    GROUP BY counts.{{type}} AS {{type}}
);

INSERT INTO `{current_metrics_table}`
{% if not flatten %}
WITH TRUNCATE
{% endif %}
SELECT
    Yson::FromDoubleDict(ToDict(AggregateList(AsTuple(metrics.{{type}}, relation_median_aupr))))
        AS relation_median_aupr,
    Yson::FromDoubleDict(ToDict(AggregateList(AsTuple(metrics.{{type}}, median_aupr_relation))))
        AS median_aupr_relation,
    Yson::FromDoubleDict(ToDict(AggregateList(AsTuple(metrics.{{type}}, pct_10_aupr_relation))))
        AS pct_10_aupr_relation,
FROM $aupr_medians_{{type}} AS aupr_medians
JOIN $metrics_{{type}} AS metrics
USING({{type}});
{% endfor %}

{% if calculate_lower_bounds %}
$data_for_lower_bound = (
    SELECT
        `timestamp`,
        Yson::ConvertToDoubleDict(relation_median_aupr)['audience'] AS relation_median_aupr,
        Yson::ConvertToDoubleDict(pct_10_aupr_relation)['audience'] AS pct_10_aupr_relation,
        Yson::ConvertToDoubleDict(median_aupr_relation)['audience'] AS median_aupr_relation,
    FROM `{metrics_table}`
    ORDER BY `timestamp` DESC
    LIMIT {number_of_days}
);

INSERT INTO `{lower_bounds_table}`
WITH TRUNCATE

SELECT
    AVG(relation_median_aupr) -
        3.0 * STDDEV(relation_median_aupr) AS relation_median_aupr,
    AVG(pct_10_aupr_relation) -
        3.0 * STDDEV(pct_10_aupr_relation) AS pct_10_aupr_relation,
    AVG(median_aupr_relation) -
        3.0 * STDDEV(median_aupr_relation) AS median_aupr_relation,
FROM $data_for_lower_bound;
{% endif %}
"""


def get_source_metrics(source_metrics):
    """Flatten per-segment-type metrics into rows for the stats table.

    Args:
        source_metrics: mapping ``metric name -> {segment type -> value}``. Every inner
            mapping must contain the keys 'audience', 'goal', 'metrika', 'rmp_goal' and
            'app'; a missing key raises KeyError.

    Returns:
        A list of ``{'fielddate', 'metric', 'value'}`` dicts, one per (metric, segment
        type) pair, where 'fielddate' is today's date as a string and 'metric' is the
        metric name with a segment-type specific suffix appended.
    """
    fielddate = str(datetime.date.today())

    # (key inside the per-metric mapping, suffix appended to the metric name).
    # NOTE(review): the 'app' suffix has no leading underscore, unlike the other suffixes.
    # It is part of the emitted metric names, so confirm with consumers before changing it.
    suffix_by_segment_type = (
        ('audience', ''),
        ('goal', '_goals'),
        ('metrika', '_metrika'),
        ('rmp_goal', '_rmp_goals'),
        ('app', 'app'),
    )

    return [
        {
            'fielddate': fielddate,
            'metric': metric_name + suffix,
            'value': values_by_segment_type[segment_type],
        }
        for metric_name, values_by_segment_type in source_metrics.items()
        for segment_type, suffix in suffix_by_segment_type
    ]


def get_ads_type_metrics(ads_metrics):
    """Flatten per-ad-type metrics into rows for the stats table.

    Args:
        ads_metrics: mapping ``metric name -> {ad type -> value}``.

    Returns:
        A list of ``{'fielddate', 'metric', 'value'}`` dicts, one per (metric, ad type)
        pair, where 'fielddate' is today's date as a string and 'metric' is
        ``<metric name>_<ad type>``.
    """
    fielddate = str(datetime.date.today())

    return [
        {
            'fielddate': fielddate,
            'metric': metric_name + '_' + ad_type,
            'value': metric_value,
        }
        for metric_name, values_by_ad_type in ads_metrics.items()
        for ad_type, metric_value in values_by_ad_type.items()
    ]


def copy_tables_to_experiments_dir(yt_client):
    """Copy the current training/validation tables to their experiments counterparts.

    Each destination is overwritten if it already exists (``force=True``).

    Args:
        yt_client: YT client used to perform the copies.
    """
    # (source table, destination table in the experiments directory)
    table_pairs = (
        (config.TRAIN_SAMPLE_TABLE, config.EXPERIMENTS_TRAIN_SAMPLE_TABLE),
        (config.USER_DSSM_FEATURES_TABLE, config.EXPERIMENTS_USER_DSSM_FEATURES_TABLE),
        (config.SEGMENTS_FOR_LAL_TRAINING_TABLE, config.EXPERIMENTS_SEGMENTS_FOR_LAL_TRAINING_TABLE),
        (config.SEGMENTS_USER_DATA_STATS_TABLE, config.EXPERIMENTS_SEGMENTS_USER_DATA_STATS_TABLE),
        (config.TEST_SEGMENTS_DSSM_FEATURES_TABLE, config.EXPERIMENTS_SEGMENTS_DSSM_FEATURES_TABLE),
        (config.TEST_SEGMENTS_WITH_COUNTS_TABLE, config.EXPERIMENTS_TEST_SEGMENTS_WITH_COUNTS_TABLE),
        (config.TEST_RANDOM_PR_STATS, config.EXPERIMENTS_RANDOM_PR_STATS),
    )

    for source_path, destination_path in table_pairs:
        yt_client.copy(source_path, destination_path, force=True)


def save_model_version(yt_client, yt_node_names, dest_dir):
    """Copy the current DSSM model file and segments dictionary into `dest_dir`.

    The directory is created through `yt_helpers.create_folder` first; existing copies
    are overwritten (``force=True``).

    Args:
        yt_client: YT client used for folder creation and copying.
        yt_node_names: object providing the `DssmModelFile` and `SegmentsDictFile` node names.
        dest_dir: YT directory that receives the two files.
    """
    yt_helpers.create_folder(yt_client, dest_dir)

    destinations = [
        os.path.join(dest_dir, node_name)
        for node_name in (yt_node_names.DssmModelFile, yt_node_names.SegmentsDictFile)
    ]
    sources = (config.DSSM_MODEL_FILE, config.SEGMENTS_DICT_FILE)

    for source_path, destination_path in zip(sources, destinations):
        yt_client.copy(source_path, destination_path, force=True)


# Jinja template of the YQL query executed by `check_correlation`.
#
# For every (key, table) pair in `input_tables` it defines `$<key>` holding, per GroupID,
# the RANK of each Yandexuid ordered by Score. The final join references `$previous` and
# `$current`, so both keys must be present in `input_tables`.
#
# Outputs:
#   * `correlations_table`  - per-GroupID CORRELATION between the previous and current ranks
#     of users present in both tables;
#   * `lower_bounds_table`  - per-GroupID AVG - 3 * STDDEV of the historical `correlation`
#     values in `metrics_table`, restricted to rows whose fielddate rank (newest first)
#     is at most `days_for_metric_comparison`.
prism_lal_correlation_query = '''
{% for key, table in input_tables %}
${{key}} = (
    SELECT
        RANK() OVER w AS idx,
        GroupID,
        Yandexuid,
    FROM `{{table}}`
    WINDOW w AS (
        PARTITION BY GroupID
        ORDER BY Score
    )
);
{% endfor %}

INSERT INTO `{{correlations_table}}`
WITH TRUNCATE

SELECT
    CORRELATION(previous.idx, current.idx) AS `correlation`,
    current.GroupID AS GroupID,
FROM $previous AS previous
INNER JOIN $current AS current
USING (Yandexuid, GroupID)
GROUP BY current.GroupID;


INSERT INTO `{{lower_bounds_table}}`
WITH TRUNCATE

SELECT
    GroupID,
    AVG(`correlation`) - 3.0 * STDDEV(`correlation`) AS correlation_lower_bound,
FROM (
    SELECT
        GroupID,
        `correlation`,
        RANK() OVER w AS days_ago,
    FROM `{{metrics_table}}`
    WINDOW w AS (
        PARTITION BY GroupID ORDER BY fielddate DESC
    )
)
WHERE days_ago <= {{days_for_metric_comparison}}
GROUP BY GroupID;
'''


def check_correlation(nv_params):
    """Verify that prism LaL scores are rank-correlated with the previous run.

    Runs `prism_lal_correlation_query` inside a Nirvana transaction, which produces a
    per-GroupID correlation between `config.PREVIOUS_PRISM_LAL` and
    `config.CURRENT_PRISM_LAL` and a per-GroupID lower bound derived from
    `config.DATALENS_LOOKALIKE_PRISM_CORRELATION`. When every correlation reaches its
    lower bound the correlations are appended to that stats table.

    Args:
        nv_params: Nirvana parameters, forwarded to the YT / YQL client factories.

    Raises:
        AssertionError: if at least one group's correlation is below its lower bound.
        KeyError: if a group in the correlations table has no lower bound row.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client) as transaction, \
            yt_client.TempTable() as correlations_table, \
            yt_client.TempTable() as lower_bounds_table:
        yql_client.execute(
            query=templater.render_template(
                prism_lal_correlation_query,
                vars={
                    'input_tables': [
                        ('previous', config.PREVIOUS_PRISM_LAL),
                        ('current', config.CURRENT_PRISM_LAL),
                    ],
                    'correlations_table': correlations_table,
                    'lower_bounds_table': lower_bounds_table,
                    'days_for_metric_comparison': config.DAYS_TO_CALCULATE_LOWER_BOUND,
                    'metrics_table': config.DATALENS_LOOKALIKE_PRISM_CORRELATION,
                },
            ),
            transaction=str(transaction.transaction_id),
            title='YQL Get prism lals correlation',
        )

        lower_bounds = {}
        for lower_bound in yt_client.read_table(lower_bounds_table):
            logger.info('Lower bound for correlation for {} is: {}'.format(
                lower_bound[fields.group_id],
                lower_bound['correlation_lower_bound'],
            ))
            lower_bounds[lower_bound[fields.group_id]] = lower_bound['correlation_lower_bound']

        stats = []
        stable = True
        # NOTE(review): a GroupID that is absent from the lower bounds table (no history in
        # the stats table) raises KeyError below - confirm this is the desired behaviour for
        # newly added groups.
        for row in yt_client.read_table(correlations_table):
            logger.info('Correlation for {} is: {}'.format(row['GroupID'], row['correlation']))
            if row['correlation'] < lower_bounds[row['GroupID']]:
                stable = False
            else:
                stats.append(row)

        # Raised explicitly instead of using `assert`: the check gates writing the stats and
        # must not disappear when Python runs with -O.
        if not stable:
            raise AssertionError('Correlation for prism input LaL segments is too low to save the model')

        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=config.DATALENS_LOOKALIKE_PRISM_CORRELATION,
            data_to_write=stats,
            schema={'GroupID': 'string', 'correlation': 'double'},
        )


def _save_monthly_version_if_due(yt_client, yt_node_names, today):
    """Store a permanent monthly copy of the model if this month has none yet.

    Nothing is saved until the day of month exceeds `config.DATE_TO_SAVE_MODEL_PERMANENTLY`.
    An empty monthly directory counts as "no copy for this month".
    """
    if today.day <= config.DATE_TO_SAVE_MODEL_PERMANENTLY:
        return

    monthly_versions = yt_client.list(config.LOOKALIKE_MONTHLY_VERSIONS_DIRECTORY)
    if monthly_versions:
        last_version_date = datetime.datetime.strptime(max(monthly_versions), config.DATE_FORMAT)
        # Compare the year as well, so a copy made in the same month of an earlier year
        # does not suppress this month's copy.
        if (last_version_date.year, last_version_date.month) == (today.year, today.month):
            return

    perm_version_dir = os.path.join(
        config.LOOKALIKE_MONTHLY_VERSIONS_DIRECTORY,
        today.strftime(config.DATE_FORMAT),
    )
    save_model_version(yt_client, yt_node_names, perm_version_dir)


def _remove_outdated_versions(yt_client):
    """Remove the smallest-named version nodes until `config.VERSIONS_TO_KEEP_NUM` remain."""
    version_paths = sorted(yt_client.list(config.LOOKALIKE_VERSIONS_DIRECTORY, absolute=True))
    excess = len(version_paths) - config.VERSIONS_TO_KEEP_NUM
    for version_path in version_paths[:max(excess, 0)]:
        yt_client.remove(version_path, recursive=True)


def calculate(nv_params, inputs):
    """Compute LaL quality metrics and, if they are good enough, publish the model.

    The metrics query is executed inside a Nirvana transaction and writes two rows into a
    temporary table: metrics by segment type and metrics by ad type.

    * Experiment run (``nv_params['is_experiment']`` is truthy): the metrics table is copied
      to ``<nv_params['working-dir']>/model_metrics`` and nothing else happens.
    * Regular run: lower bounds are recalculated by the query and read from
      `config.TEST_LOWER_BOUNDS`. If the 'audience' value of at least one metric exceeds its
      lower bound, the metrics are written to the stats tables, the model is versioned
      (timestamped copy, optional monthly copy, pruning of old versions) and, in
      production, the current prism LaL table becomes the previous one. In production the
      experiments tables are refreshed afterwards.

    Args:
        nv_params: Nirvana parameters; 'is_experiment' and 'working-dir' are read here, the
            whole mapping is forwarded to the client factories and `utils.get_production_path`.
        inputs: Nirvana inputs, forwarded to `utils.get_lal_model_source_link`.

    Raises:
        AssertionError: in production, when no metric exceeds its lower bound.
        ValueError: if the metrics table does not contain exactly two rows.
        IndexError: if the lower bounds table is empty.
    """
    now_timestamp = str(int(time.time()))
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)
    yt_node_names = yt_node_names_pb2.TYtNodeNames()

    is_experiment = nv_params.get('is_experiment', False)

    with NirvanaTransaction(yt_client) as transaction, yt_client.TempTable() as current_metrics_table:
        calculate_lal_metrics_query = templater.render_template(
            template_text=calculate_lal_metrics_query_template,
            vars={
                'metrics_types': [('segment_type', False), ('ad_types', True)],
                # The lower-bounds section of the template is guarded by a Jinja `{% if %}`,
                # so the flag has to be a Jinja variable; passed to str.format it never
                # reached the template logic.
                'calculate_lower_bounds': not is_experiment,
            },
        )

        yql_client.execute(
            query=calculate_lal_metrics_query.format(
                timestamp=now_timestamp,
                dssm_pr_stats_table=config.TEST_DSSM_PR_STATS,
                random_pr_stats_table=config.TEST_RANDOM_PR_STATS,
                val_segments_with_counts_table=config.TEST_SEGMENTS_WITH_COUNTS_TABLE,
                metrics_table=utils.get_production_path(config.TEST_LAL_METRICS, nv_params),
                current_metrics_table=current_metrics_table,
                lower_bounds_table=config.TEST_LOWER_BOUNDS,
                number_of_days=config.DAYS_TO_CALCULATE_LOWER_BOUND,
            ),
            transaction=str(transaction.transaction_id),
            title='YQL LaL calculate metrics',
        )

        if is_experiment:
            new_model_metrics_table_path = os.path.join(nv_params['working-dir'], 'model_metrics')
            yt_client.copy(current_metrics_table, new_model_metrics_table_path, force=True)
            return

        # The query inserts the segment-type row first and the ad-type row second.
        metrics_by_segment_source, metrics_by_ads_type = list(yt_client.read_table(current_metrics_table))
        lower_bounds = list(yt_client.read_table(config.TEST_LOWER_BOUNDS))[0]

        # NOTE(review): the model is accepted when ANY metric exceeds its lower bound, whereas
        # `check_correlation` requires ALL groups to pass - confirm "any" is intended.
        to_stable = False
        for metric, value in lower_bounds.items():
            model_value = metrics_by_segment_source[metric]['audience']
            logger.info('Lower bound for {metric} is {value}, current value is {model_value}'.
                        format(metric=metric, value=value, model_value=model_value))
            if model_value > value:
                to_stable = True

        if to_stable:
            for metrics_to_write_on_yt in [
                    get_source_metrics(metrics_by_segment_source),
                    get_ads_type_metrics(metrics_by_ads_type),
            ]:
                yt_helpers.write_stats_to_yt(
                    yt_client=yt_client,
                    table_path=config.DATALENS_LOOKALIKE_QUALITY_TABLE,
                    data_to_write=metrics_to_write_on_yt,
                    schema={
                        'fielddate': 'string',
                        'metric': 'string',
                        'value': 'double',
                    },
                )

            # Added only after the stats rows are built, so 'timestamp' is not treated as a metric.
            metrics_by_segment_source.update({'timestamp': now_timestamp})
            yt_client.write_table(yt_client.TablePath(config.TEST_LAL_METRICS, append=True),
                                  [metrics_by_segment_source])

            sandbox_model_link = utils.get_lal_model_source_link(
                inputs=inputs,
                file_name=yt_node_names.DssmModelSandboxName,
            )
            yt_client.set_attribute(config.DSSM_MODEL_FILE, yt_node_names.DssmSandboxLinkAttr, sandbox_model_link)
            now_version_dir = os.path.join(config.LOOKALIKE_VERSIONS_DIRECTORY, now_timestamp)
            save_model_version(yt_client, yt_node_names, now_version_dir)

            today = datetime.datetime.fromtimestamp(float(now_timestamp))
            _save_monthly_version_if_due(yt_client, yt_node_names, today)

            if environment.environment == 'production':
                yt_client.move(config.CURRENT_PRISM_LAL, config.PREVIOUS_PRISM_LAL, force=True)

            _remove_outdated_versions(yt_client)

        if environment.environment == 'production':
            # Raised explicitly instead of using `assert`, so the gate survives `python -O`.
            if not to_stable:
                raise AssertionError('Metrics are not good enough to save the model.')

            copy_tables_to_experiments_dir(yt_client)
            logger.info('Tables for experiments have been successfully updated')
