#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils.config import config
from crypta.lookalike.lib.python.utils import utils


logger = logging.getLogger(__name__)

# YQL template that splits lookalike segments into a test sample and a
# (subsampled) train sample. It is filled with str.format, which is why the
# body of the $corrected_threshold lambda uses doubled braces {{ }}.
#
# Placeholders: segments_counts_table, segments_num_for_test,
# segments_num_for_validation, test_segments_with_counts_table,
# user_segment_table, test_output_table, user_dssm_features_table,
# positives_volume, min_users_per_segment, train_segments_with_counts_table,
# train_output_table.
#
# Steps:
#   * $segments_with_ranks: within each segment_type, segments from
#     {segments_counts_table} are ordered by RANDOM(GroupID). The first
#     {segments_num_for_test} are flagged is_test; the first
#     {segments_num_for_test} + {segments_num_for_validation} are flagged
#     is_validation (so test segments carry is_validation too).
#   * Test segments (with counts) -> {test_segments_with_counts_table}.
#   * All users of test segments (no segment_type filter) are joined with
#     {user_dssm_features_table} on yandexuid (IdType 'yandexuid') or on
#     cryptaId (IdType 'crypta_id'). Only yandexuid plus the segment columns
#     go to {test_output_table}.
#   * $training_ratio = {positives_volume} / row count of the first
#     $segments_users_train definition, i.e. users of non-test segments with
#     segment_type != 'rmp_goal'. ($segments_users_train is redefined later;
#     each redefinition reads the previous one.)
#   * Train users are randomly ranked per GroupID (ORDER BY RANDOM(IdValue)).
#     Each non-test segment keeps the users with
#     row_rank <= Round(max(ids_cnt * $training_ratio, {min_users_per_segment})).
#   * Non-test segments with ids_cnt replaced by that rounded target count
#     -> {train_segments_with_counts_table}.
#   * Sampled train users joined with all DSSM feature columns
#     -> {train_output_table}.
#
# NOTE(review): $remove_target_app drops, from the space-separated
# user_affinitive_apps, the string cast of Digest::Md5HalfMix applied to
# GroupID minus its last 6 characters. It is applied only in the crypta_id
# branch of the train output; the yandexuid branch keeps
# user_affinitive_apps unchanged. Confirm this asymmetry is intentional.
# NOTE(review): $train_segments_with_counts is not filtered by
# segment_type != 'rmp_goal', while the train users are. Non-test rmp_goal
# segments therefore get a target ids_cnt (at least
# {min_users_per_segment}) in the train counts table, but have no rows in
# {train_output_table}. Verify this is expected.
train_validation_split_query_template = """
$segments_with_ranks = (
    SELECT
        GroupID,
        ids_cnt,
        segment_type,
        ad_types,
        (ROW_NUMBER() OVER w) <= {segments_num_for_test} AS is_test,
        (ROW_NUMBER() OVER w) <= {segments_num_for_test} + {segments_num_for_validation} AS is_validation,
    FROM `{segments_counts_table}`
    WINDOW w AS (
        PARTITION BY segment_type
        ORDER BY RANDOM(GroupID)
    )
);

$test_segments_with_counts = (
    SELECT
        GroupID,
        ids_cnt,
        segment_type,
        ad_types,
    FROM $segments_with_ranks
    WHERE is_test
);

INSERT INTO `{test_segments_with_counts_table}`
WITH TRUNCATE

SELECT *
FROM $test_segments_with_counts
ORDER BY GroupID;

$segments_users_train = (
    SELECT
        CAST(user_segment.IdValue AS Uint64) AS IdValue,
        user_segment.IdType AS IdType,
        user_segment.GroupID AS GroupID,
        user_segment.segment_type AS segment_type,
        user_segment.ad_types AS ad_types,
    FROM `{user_segment_table}` AS user_segment
    LEFT ONLY JOIN $test_segments_with_counts AS test_group_ids
    USING(GroupID)
    WHERE user_segment.segment_type != 'rmp_goal'
);

$segments_users_test = (
    SELECT
        CAST(user_segment.IdValue AS Uint64) AS IdValue,
        user_segment.IdType AS IdType,
        user_segment.GroupID AS GroupID,
        user_segment.segment_type AS segment_type,
        user_segment.ad_types AS ad_types,
    FROM `{user_segment_table}` AS user_segment
    INNER JOIN $test_segments_with_counts AS test_group_ids
    USING(GroupID)
);

INSERT INTO `{test_output_table}`
WITH TRUNCATE

    SELECT
        user_dssm_features.yandexuid AS yandexuid,
        test_sample.GroupID AS GroupID,
        test_sample.segment_type AS segment_type,
        test_sample.ad_types AS ad_types,
    FROM `{user_dssm_features_table}` AS user_dssm_features
    INNER JOIN $segments_users_test AS test_sample
    ON user_dssm_features.yandexuid == test_sample.IdValue
    WHERE test_sample.IdType == 'yandexuid'
UNION ALL
    SELECT
        user_dssm_features.yandexuid AS yandexuid,
        test_sample.GroupID AS GroupID,
        test_sample.segment_type AS segment_type,
        test_sample.ad_types AS ad_types,
    FROM `{user_dssm_features_table}` AS user_dssm_features
    INNER JOIN $segments_users_test AS test_sample
    ON user_dssm_features.cryptaId == test_sample.IdValue
    WHERE test_sample.IdType == 'crypta_id';

$training_ratio = (
    SELECT 1.0 * {positives_volume} / COUNT(*) FROM $segments_users_train
);

$train_segments_with_counts = (
    SELECT
        GroupID,
        ids_cnt,
        segment_type,
        ad_types,
        is_validation,
    FROM $segments_with_ranks
    WHERE NOT is_test
);

$segments_users_train = (
    SELECT
        IdValue,
        IdType,
        GroupID,
        segment_type,
        ad_types,
        ROW_NUMBER() OVER w AS row_rank,
    FROM $segments_users_train
    WINDOW w AS (
        PARTITION BY GroupID
        ORDER BY RANDOM(IdValue)
    )
);

$corrected_threshold = ($threshold) -> {{
    return CASE
        WHEN $threshold >= {min_users_per_segment} THEN $threshold
        ELSE {min_users_per_segment}
    END;
}};

$segments_users_train = (
    SELECT
        train_sample.IdValue AS IdValue,
        train_sample.IdType AS IdType,
        train_sample.GroupID AS GroupID,
        train_sample.segment_type AS segment_type,
        train_sample.ad_types AS ad_types,
    FROM $segments_users_train AS train_sample
    INNER JOIN $train_segments_with_counts AS train_sample_counts
    USING(GroupID)
    WHERE train_sample.row_rank <= Math::Round($corrected_threshold(train_sample_counts.ids_cnt * $training_ratio))
);

$train_segments_with_counts = (
    SELECT
        GroupID,
        segment_type,
        ad_types,
        Math::Round($corrected_threshold(ids_cnt * $training_ratio)) AS ids_cnt,
        is_validation,
    FROM $train_segments_with_counts
);

INSERT INTO `{train_segments_with_counts_table}`
WITH TRUNCATE

SELECT *
FROM $train_segments_with_counts
ORDER BY GroupID;


$remove_target_app = ($user_apps, $GroupID) -> (
    String::JoinFromList(DictKeys(
        SetDifference(
            ToSet(String::SplitToList($user_apps, ' ')),
            ToSet([
                CAST(Digest::Md5HalfMix(
                    SUBSTRING($GroupID, NULL, CAST((LENGTH($GroupID) - 6) AS UINT32))
                ) AS string)
            ]),
        )
    ), ' ')
);

INSERT INTO `{train_output_table}`
WITH TRUNCATE

    SELECT
        train_sample.GroupID AS GroupID,
        train_sample.segment_type AS segment_type,
        train_sample.ad_types AS ad_types,
        user_dssm_features.*,
    FROM `{user_dssm_features_table}` AS user_dssm_features
    INNER JOIN $segments_users_train AS train_sample
    ON user_dssm_features.yandexuid == train_sample.IdValue
    WHERE train_sample.IdType == 'yandexuid'
UNION ALL
    SELECT
        train_sample.GroupID AS GroupID,
        train_sample.segment_type AS segment_type,
        train_sample.ad_types AS ad_types,
        $remove_target_app(user_affinitive_apps, GroupID) AS user_affinitive_apps,
        user_dssm_features.* WITHOUT user_dssm_features.user_affinitive_apps,
    FROM `{user_dssm_features_table}` AS user_dssm_features
    INNER JOIN $segments_users_train AS train_sample
    ON user_dssm_features.cryptaId == train_sample.IdValue
    WHERE train_sample.IdType == 'crypta_id';
"""


# YQL template for split statistics. It cannot run on its own: it reuses the
# $test_segments_with_counts and $train_segments_with_counts named
# expressions defined in train_validation_split_query_template, and it is
# only ever executed appended to that template. Because it comes after the
# split template, $train_segments_with_counts here is the final redefinition,
# whose ids_cnt is the rounded, rescaled target count rather than the
# original count.
#
# It writes two tables:
#   * {segments_counts_stats_table}: per segment_type, with columns:
#       - segments_counts: the number of segments in {segments_counts_table};
#       - validation_rows: SUM(ids_cnt) over the test segments;
#       - train_rows: SUM(ids_cnt) over the train segments.
#     Both joins are plain (inner) JOINs, so a segment_type that has no test
#     segments, or no train segments, is dropped from the output.
#   * {ads_types_stats_table}: per ad_type, after FLATTEN BY ad_types, with
#     columns:
#       - segments_counts: the number of segments;
#       - train_rows, validation_rows: summed train and test ids_cnt, where a
#         missing side counts as 0.
#
# NOTE(review): "validation_rows" is computed from the *test* segments; the
# name may be misleading for readers of the stats tables.
metrics_query_template = """
$segments_counts_stats = (
    SELECT
        segment_type,
        COUNT(*) AS segments_counts,
    FROM `{segments_counts_table}`
    GROUP BY segment_type
);

$segments_counts_stats = (
    SELECT
        stats.*,
        validation_rows,
    FROM $segments_counts_stats AS stats
    JOIN (
        SELECT
            segment_type,
            SUM(ids_cnt) AS validation_rows,
        FROM $test_segments_with_counts
        GROUP BY segment_type
    ) AS test
    USING(segment_type)
);

INSERT INTO `{segments_counts_stats_table}`
WITH TRUNCATE

SELECT
    stats.*,
    train_rows,
FROM $segments_counts_stats AS stats
JOIN (
    SELECT
        segment_type,
        SUM(ids_cnt) AS train_rows,
    FROM $train_segments_with_counts
    GROUP BY segment_type
) AS train
USING(segment_type);

$flattened_by_goal_type = (
    SELECT
        segments_counts.GroupID AS GroupID,
        segments_counts.ad_types AS ad_type,
        train.ids_cnt ?? 0 AS train_ids_cnt,
        test.ids_cnt ?? 0 AS val_ids_cnt,
    FROM `{segments_counts_table}` AS segments_counts
    FLATTEN BY segments_counts.ad_types
    LEFT JOIN $train_segments_with_counts AS train
    ON segments_counts.GroupID == train.GroupID
    LEFT JOIN $test_segments_with_counts AS test
    ON segments_counts.GroupID == test.GroupID
);

INSERT INTO `{ads_types_stats_table}`
WITH TRUNCATE

SELECT
    ad_type,
    COUNT(*) AS segments_counts,
    SUM(train_ids_cnt) AS train_rows,
    SUM(val_ids_cnt) AS validation_rows,
FROM $flattened_by_goal_type
GROUP BY ad_type;
"""


def write_stats_on_yt(yt_client, input_table_path, column_name):
    """Publish the rows of a per-key stats table as DataLens counters.

    Every row of ``input_table_path`` is treated as one key (the value of
    ``column_name``, e.g. a segment type or an ad type) plus a set of counter
    columns. Each remaining column becomes one entry in
    ``config.DATALENS_LOOKALIKE_COUNTS_TABLE``, named
    ``<column>_<key value>``, with its value converted via ``int``.

    :param yt_client: YT client used both to read the stats table and to
        write the counters.
    :param input_table_path: path of the stats table to read.
    :param column_name: name of the key column; it is not written as a
        counter itself.
    """
    # Read the whole table before writing anything, so reading and writing
    # through the same client are never interleaved.
    rows = list(yt_client.read_table(input_table_path))
    for row in rows:
        key_value = row.pop(column_name)
        for metric, value in row.items():
            yt_helpers.write_stats_to_yt(
                yt_client=yt_client,
                table_path=config.DATALENS_LOOKALIKE_COUNTS_TABLE,
                data_to_write={
                    'counter_name': '_'.join((metric, key_value)),
                    'count': int(value),
                },
                schema={
                    'counter_name': 'string',
                    'count': 'uint64',
                },
            )


def split(nv_params):
    """Run the LaL train/validation split on YT and publish its statistics.

    Executes ``train_validation_split_query_template`` inside a Nirvana
    transaction. Unless this is an experiment run, the metrics query is
    appended to it and, after execution, row-count statistics are written
    to ``config.DATALENS_LOOKALIKE_COUNTS_TABLE``.

    :param nv_params: Nirvana parameters. They are passed to the YT/YQL
        client factories, and the optional ``'is_experiment'`` key (default
        ``False``) is read from them. Presumably dict-like — confirm against
        callers.
    """
    yt_client = utils.get_yt_client(nv_params=nv_params)
    yql_client = utils.get_yql_client(nv_params=nv_params)

    publish_metrics = not nv_params.get('is_experiment', False)

    with NirvanaTransaction(yt_client) as transaction:
        # The metrics query reads named expressions defined by the split
        # query, so it is only ever run appended to it.
        if publish_metrics:
            query_template = train_validation_split_query_template + metrics_query_template
        else:
            query_template = train_validation_split_query_template

        template_params = {
            'segments_counts_table': config.SEGMENTS_WITH_COUNTS_TABLE,
            'segments_num_for_test': config.SEGMENTS_NUM_FOR_TEST,
            'segments_num_for_validation': config.SEGMENTS_NUM_FOR_VALIDATION,
            'min_users_per_segment': config.MIN_USERS_PER_SEGMENT,
            'positives_volume': int(config.POSITIVES_VOLUME),
            'train_segments_with_counts_table': config.TRAIN_SEGMENTS_WITH_COUNTS_TABLE,
            'test_segments_with_counts_table': config.TEST_SEGMENTS_WITH_COUNTS_TABLE,
            'user_segment_table': config.SEGMENTS_FOR_LAL_TRAINING_TABLE,
            'user_dssm_features_table': config.USER_DSSM_FEATURES_TABLE,
            'segments_counts_stats_table': config.SEGMENTS_STATS,
            'ads_types_stats_table': config.ADS_TYPES_STATS,
            'test_output_table': config.TEST_USERS_TABLE,
            'train_output_table': config.POSITIVES_WITH_DSSM_FEATURES_TABLE,
        }
        rendered_query = query_template.format(**template_params)

        yql_client.execute(
            query=rendered_query,
            transaction=str(transaction.transaction_id),
            title='YQL LaL train validation split',
        )

        if not publish_metrics:
            return

        # Total number of test users, published as a single counter.
        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=config.DATALENS_LOOKALIKE_COUNTS_TABLE,
            data_to_write={
                'counter_name': 'val_rows',
                'count': yt_client.row_count(config.TEST_USERS_TABLE),
            },
            schema={
                'counter_name': 'string',
                'count': 'uint64',
            },
        )
        write_stats_on_yt(yt_client, config.SEGMENTS_STATS, 'segment_type')
        write_stats_on_yt(yt_client, config.ADS_TYPES_STATS, 'ad_type')
