#!/usr/bin/env python
# -*- coding: utf-8 -*-

from functools import partial
import logging
import os
import time

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    metrics,
    mobile_utils,
)
from crypta.lookalike.lib.python.utils.mobile_config import config as mobile_config
from crypta.lookalike.lib.python.utils.utils import get_feature_name

logger = logging.getLogger(__name__)

# YQL template that aggregates lookalike quality metrics from per-app PR statistics.
#
# Placeholders (filled with str.format):
#   train_apps, val_apps            - app tables; their rows are tagged 'training' / 'validation'
#   dssm_pr_stats, random_pr_stats  - tables with an `aupr` column, inner-joined on (app_id, id_type)
#   timestamp                       - value written to the `timestamp` column of the outputs
#   metrics_by_app_type,
#   metrics_by_app_and_device_type  - outputs written with INSERT INTO ... WITH TRUNCATE
#   metrics_by_app_source,
#   metrics_by_device_type          - outputs written with a plain INSERT INTO (no TRUNCATE)
#
# For every grouping the query emits the row count, the median of the per-app
# dssm/random AUPR ratio (`rel`), the ratio of the two medians and the 0.1
# percentile of `rel`.
# The output grouped by (app_type, id_type) has no `timestamp` column; its
# `app_type` column holds the two group keys joined with '_'.
calculate_lal_metrics_query_template = """
$apps = (
    SELECT
        app_id,
        id_type,
        source,
        'training' AS app_type
    FROM `{train_apps}`
UNION ALL
    SELECT
        app_id,
        id_type,
        source,
        'validation' AS app_type
    FROM `{val_apps}`
);

$stats = (
    SELECT
        dssm_pr_stats.app_id AS app_id,
        dssm_pr_stats.id_type AS id_type,
        dssm_pr_stats.aupr AS dssm_aupr,
        random_pr_stats.aupr AS random_aupr,
        dssm_pr_stats.aupr / random_pr_stats.aupr AS rel,
        source,
        app_type
    FROM `{dssm_pr_stats}` AS dssm_pr_stats
    INNER JOIN `{random_pr_stats}` AS random_pr_stats
    ON dssm_pr_stats.app_id == random_pr_stats.app_id AND dssm_pr_stats.id_type == random_pr_stats.id_type
    INNER JOIN $apps AS val_apps
    ON dssm_pr_stats.app_id == val_apps.app_id and dssm_pr_stats.id_type == val_apps.id_type
);

INSERT INTO `{metrics_by_app_type}`
WITH TRUNCATE

SELECT
    {timestamp} AS `timestamp`,
    app_type,
    COUNT(*) AS cnt,
    median(rel) AS median_aupr_relation,
    median(dssm_aupr) / median(random_aupr) AS relation_median_aupr,
    percentile(rel, 0.1) AS pct_10_aupr_relation
FROM $stats
GROUP BY app_type;

INSERT INTO `{metrics_by_app_source}`

SELECT
    {timestamp} AS `timestamp`,
    source,
    COUNT(*) AS cnt,
    median(rel) AS median_aupr_relation,
    median(dssm_aupr) / median(random_aupr) AS relation_median_aupr,
    percentile(rel, 0.1) AS pct_10_aupr_relation
FROM $stats
GROUP BY source
ORDER BY `timestamp`, source;

INSERT INTO `{metrics_by_device_type}`

SELECT
    {timestamp} AS `timestamp`,
    id_type,
    COUNT(*) AS cnt,
    median(rel) AS median_aupr_relation,
    median(dssm_aupr) / median(random_aupr) AS relation_median_aupr,
    percentile(rel, 0.1) AS pct_10_aupr_relation
FROM $stats
GROUP BY id_type
ORDER BY `timestamp`, id_type;

INSERT INTO `{metrics_by_app_and_device_type}`
WITH TRUNCATE

SELECT
    String::JoinFromList(AsList(app_type, id_type), '_') AS app_type,
    COUNT(*) AS cnt,
    median(rel) AS median_aupr_relation,
    median(dssm_aupr) / median(random_aupr) AS relation_median_aupr,
    percentile(rel, 0.1) AS pct_10_aupr_relation
FROM $stats
GROUP BY app_type, id_type;
"""


def calculate_distances(nv_params):
    """Map user vectors against validation app segments and store the output.

    Segments are loaded from ``mobile_config.VALIDATION_APPS_DSSM_VECTORS``.
    Inside a ``NirvanaTransaction`` an empty
    ``mobile_config.VALIDATION_DSSM_LAL_DISTANCES`` table is created
    (``force=True``) and then filled by running
    ``mobile_utils.calculate_apps_lal_scores_mapper`` over
    ``mobile_config.USERS_DSSM_VECTORS``.

    :param nv_params: parameters forwarded to ``mobile_utils.get_yt_client``.
    """
    client = mobile_utils.get_yt_client(nv_params=nv_params)
    validation_segments = mobile_utils.get_segments(client, mobile_config.VALIDATION_APPS_DSSM_VECTORS)

    with NirvanaTransaction(client):
        distances_schema = {
            fields.device_id: 'string',
            fields.app_id: 'string',
            fields.cryptaId: 'string',
            fields.id_type: 'string',
            fields.distance: 'double',
            fields.label: 'uint64',
        }
        yt_helpers.create_empty_table(
            yt_client=client,
            path=mobile_config.VALIDATION_DSSM_LAL_DISTANCES,
            schema=distances_schema,
            additional_attributes=dict(optimize_for='scan'),
            force=True,
        )

        # The segments are bound to the mapper up front via functools.partial.
        scores_mapper = partial(mobile_utils.calculate_apps_lal_scores_mapper, segments=validation_segments)
        map_spec = {'data_size_per_job': 128 * 1024 * 1024}
        client.run_map(
            scores_mapper,
            mobile_config.USERS_DSSM_VECTORS,
            mobile_config.VALIDATION_DSSM_LAL_DISTANCES,
            spec=map_spec,
        )


def calculate_dssm_statistics(nv_params):
    """Run the tp/fp/tn/fn YQL query ordered by ``fields.distance``.

    The query is rendered from ``metrics.calculate_tp_fp_tn_fn_query_template``
    with ``mobile_config.VALIDATION_DSSM_LAL_DISTANCES`` as input and
    ``mobile_config.VALIDATION_DSSM_SEGMENTS_POINTS`` as output, grouped by
    ``fields.app_id`` and ``fields.id_type``. It is executed under the id of a
    ``NirvanaTransaction``.

    :param nv_params: parameters forwarded to the YT and YQL client factories.
    """
    yt = mobile_utils.get_yt_client(nv_params=nv_params)
    yql = mobile_utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt) as tx:
        template_args = dict(
            order_by=fields.distance,
            group_columns='{}, {}'.format(fields.app_id, fields.id_type),
            input_table=mobile_config.VALIDATION_DSSM_LAL_DISTANCES,
            output_table=mobile_config.VALIDATION_DSSM_SEGMENTS_POINTS,
        )
        rendered_query = metrics.calculate_tp_fp_tn_fn_query_template.format(**template_args)
        yql.execute(
            query=rendered_query,
            transaction=str(tx.transaction_id),
            title='YQL LaL calculate tp fp tn fn for DSSM',
        )


def calculate_random_statistics(nv_params):
    """Run the tp/fp/tn/fn YQL query with ``Random(id)`` ordering.

    Uses the same template and input table as the DSSM variant
    (``mobile_config.VALIDATION_DSSM_LAL_DISTANCES``) but orders rows by
    ``Random(id)`` and writes to
    ``mobile_config.VALIDATION_RANDOM_SEGMENTS_POINTS``. The query is executed
    under the id of a ``NirvanaTransaction``.

    :param nv_params: parameters forwarded to the YT and YQL client factories.
    """
    yt = mobile_utils.get_yt_client(nv_params=nv_params)
    yql = mobile_utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt) as tx:
        template_args = dict(
            order_by='Random(id)',
            group_columns='{}, {}'.format(fields.app_id, fields.id_type),
            input_table=mobile_config.VALIDATION_DSSM_LAL_DISTANCES,
            output_table=mobile_config.VALIDATION_RANDOM_SEGMENTS_POINTS,
        )
        rendered_query = metrics.calculate_tp_fp_tn_fn_query_template.format(**template_args)
        yql.execute(
            query=rendered_query,
            transaction=str(tx.transaction_id),
            title='YQL LaL calculate tp fp tn fn for RANDOM',
        )


def calculate_pr_auc(nv_params):
    """Build PR statistics tables, aggregate LaL metrics and publish them.

    All steps run inside one ``NirvanaTransaction``:

    1. For the random and the DSSM segment-point tables, compute PR stats with
       ``metrics.calculate_pr_stats(mobile=True)``, write them to the matching
       PR-stats table and sort it by ``fields.app_id``, ``fields.id_type``.
    2. Execute ``calculate_lal_metrics_query_template``; the per-app-type and
       per-app-and-device-type outputs go to temporary tables.
    3. Turn every row of those two temporary tables into ``metric``/``value``
       records (one per name in ``mobile_utils.validation_metrics``) and pass
       them to ``yt_helpers.write_stats_to_yt``.
    4. Merge ``mobile_config.METRICS_BY_APP_TYPE`` with the fresh per-app-type
       rows and sort the merge result back into it by ``timestamp``.

    :param nv_params: parameters forwarded to the client factories and to
        ``mobile_utils.get_date_from_nv_parameters``.
    """
    yt_client = mobile_utils.get_yt_client(nv_params=nv_params)
    yql_client = mobile_utils.get_yql_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client) as transaction:
        with yt_client.TempTable() as by_app_type_tmp, \
                yt_client.TempTable() as by_app_and_device_tmp, \
                yt_client.TempTable() as by_app_type_merged_tmp:
            # (segment points table, PR stats table) pairs; random goes first.
            points_and_stats = (
                (mobile_config.VALIDATION_RANDOM_SEGMENTS_POINTS, mobile_config.VALIDATION_RANDOM_PR_STATS),
                (mobile_config.VALIDATION_DSSM_SEGMENTS_POINTS, mobile_config.VALIDATION_DSSM_PR_STATS),
            )
            for points_table, stats_table in points_and_stats:
                pr_stats_rows = metrics.calculate_pr_stats(
                    yt_client=yt_client,
                    table_name=points_table,
                    mobile=True
                )

                stats_schema = {
                    fields.app_id: 'string',
                    fields.id_type: 'string',
                    fields.pr_curve: 'any',
                    fields.aupr: 'double',
                }
                yt_helpers.create_empty_table(
                    yt_client=yt_client,
                    path=stats_table,
                    schema=stats_schema,
                    additional_attributes=dict(optimize_for='scan'),
                    force=True,
                )

                yt_client.write_table(stats_table, pr_stats_rows)
                yt_client.run_sort(stats_table, sort_by=[fields.app_id, fields.id_type])

            metrics_query = calculate_lal_metrics_query_template.format(
                timestamp=str(int(time.time())),
                train_apps=mobile_config.TRAIN_APPS_TABLE,
                val_apps=mobile_config.VALIDATION_APPS_TABLE,
                dssm_pr_stats=mobile_config.VALIDATION_DSSM_PR_STATS,
                random_pr_stats=mobile_config.VALIDATION_RANDOM_PR_STATS,
                metrics_by_app_type=by_app_type_tmp,
                metrics_by_app_source=mobile_config.METRICS_BY_APP_SOURCE,
                metrics_by_device_type=mobile_config.METRICS_BY_DEVICE_TYPE,
                metrics_by_app_and_device_type=by_app_and_device_tmp,
            )
            yql_client.execute(
                query=metrics_query,
                transaction=str(transaction.transaction_id),
                title='YQL LaL calculate metrics',
            )

            # One record per (row, metric name); the row's app_type value is
            # passed to get_feature_name together with the metric name.
            stats_records = []
            for tmp_table in (by_app_type_tmp, by_app_and_device_tmp):
                for row in yt_client.read_table(tmp_table):
                    stats_records.extend(
                        {
                            'metric': get_feature_name(name, row[fields.app_type]),
                            'value': row[name],
                        }
                        for name in mobile_utils.validation_metrics
                    )

            stats_schema_for_datalens = {
                'fielddate': 'string',
                'metric': 'string',
                'value': 'double',
            }
            yt_helpers.write_stats_to_yt(
                yt_client=yt_client,
                table_path=mobile_config.DATALENS_MOBILE_LAL_QUALITY_TABLE,
                data_to_write=stats_records,
                schema=stats_schema_for_datalens,
                date=mobile_utils.get_date_from_nv_parameters(nv_params=nv_params),
            )

            # Add the new per-app-type rows to the stored table: merge into a
            # temporary table, then sort that back into the stored path.
            yt_client.run_merge(
                source_table=[mobile_config.METRICS_BY_APP_TYPE, by_app_type_tmp],
                destination_table=by_app_type_merged_tmp,
            )
            yt_client.run_sort(
                source_table=by_app_type_merged_tmp,
                destination_table=mobile_config.METRICS_BY_APP_TYPE,
                sort_by=['timestamp'],
            )


def save_tables(nv_params):
    """Copy training artifacts to their saved locations and store a new model version.

    Inside one ``NirvanaTransaction``:

    * every source table is copied (``force=True``) to its ``*_SAVED``
      counterpart, and ``mobile_utils.set_generate_date`` is called on the copy;
    * entries of ``mobile_config.MOBILE_LOOKALIKE_VERSIONS_DIRECTORY`` are
      removed, smallest name first, while their count is at least
      ``mobile_config.MODEL_VERSIONS_TO_KEEP``;
    * a folder named after ``nv_params['timestamp']`` is created there and the
      model files plus user/app embedding tables are copied into it.

    :param nv_params: parameters forwarded to ``mobile_utils``; must contain
        the ``'timestamp'`` key.
    """
    yt_client = mobile_utils.get_yt_client(nv_params=nv_params)

    with NirvanaTransaction(yt_client):
        # (source table, saved copy) pairs.
        tables_to_save = (
            (mobile_config.MERGED_STORES, mobile_config.MERGED_STORES_SAVED),
            (mobile_config.TOP_COMMON_APPS, mobile_config.TOP_COMMON_APPS_SAVED),
            (mobile_config.APP_DSSM_FEATURES, mobile_config.APP_DSSM_FEATURES_SAVED),
            (mobile_config.APPS_FEATURES_FROM_STORES, mobile_config.APPS_FEATURES_FROM_STORES_SAVED),
            (mobile_config.APPS_VECTORS_BY_PUBLISHER, mobile_config.APPS_VECTORS_BY_PUBLISHER_SAVED),
            (
                mobile_config.CATEGORY_SEGMENTS_DSSM_WEB_FEATURES_TABLE,
                mobile_config.CATEGORY_SEGMENTS_DSSM_WEB_FEATURES_TABLE_SAVED,
            ),
            (mobile_config.CATEGORY2VEC_TABLE, mobile_config.CATEGORY2VEC_TABLE_SAVED),
            (mobile_config.CATEGORY_VECTORS_BY_PUBLISHER, mobile_config.CATEGORY_VECTORS_BY_PUBLISHER_SAVED),
            (mobile_config.DEFAULT_USER_DSSM_FEATURES_WEB, mobile_config.DEFAULT_USER_DSSM_FEATURES_WEB_SAVED),
        )
        for source_path, saved_path in tables_to_save:
            yt_client.copy(source_path, saved_path, force=True)
            mobile_utils.set_generate_date(yt_client, saved_path, nv_params)

        versions_dir = mobile_config.MOBILE_LOOKALIKE_VERSIONS_DIRECTORY
        # Sorted in descending order, so the last element is the smallest name.
        known_versions = sorted(yt_client.list(versions_dir), reverse=True)
        while len(known_versions) >= mobile_config.MODEL_VERSIONS_TO_KEEP:
            dropped = known_versions.pop()
            yt_client.remove(os.path.join(versions_dir, dropped), recursive=True)
            logger.info('Version: {} is removed.'.format(dropped))

        new_version_dir = os.path.join(versions_dir, nv_params['timestamp'])
        yt_helpers.create_folder(yt_client, new_version_dir)

        # model
        model_files = (
            mobile_config.MODEL_APPLIER_FILE,
            mobile_config.SEGMENTS_FEATURES_DICT_FILE,
            mobile_config.APPS_FEATURES_DICT_FILE,
        )
        for model_file in model_files:
            yt_client.copy(
                os.path.join(mobile_config.MODEL_DATA_DIRECTORY, model_file),
                os.path.join(new_version_dir, model_file),
            )

        # vectors
        embedding_tables = (
            (mobile_config.USERS_DSSM_VECTORS, 'user_embeddings'),
            (mobile_config.APP_DSSM_VECTORS, 'apps_embeddings'),
        )
        for vectors_table, target_name in embedding_tables:
            yt_client.copy(vectors_table, os.path.join(new_version_dir, target_name))
