from datetime import timedelta
import os

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.prism.lib.config import (
    config,
    environment,
)
from crypta.prism.services.offline_weighting.lib import table_paths


# YQL query template (Jinja-style `{{ ... }}` / `{% ... %}` placeholders) that is
# rendered with templater.render_template() in calculate() below.
#
# Template variables:
#   yandexuid_profile_export_table - source of yandexuid, icookie, income_5_segments, update_time
#   lookalike_table                - source of Yandexuid/Score rows, filtered by GroupID
#                                    ('ultima_taxi_and_lavka_users' and 'elite_cards')
#   antifraud_daily_table          - source of uid / uid_type
#   yuid_with_all_info_table       - joined by id; `dates` length gives days_active
#   yandexuid_cryptaid_table       - joined by id; `target_id` gives crypta_id
#   frauds_output, clusters_output - tables overwritten (WITH TRUNCATE) by the query
#   sizes_table                    - optional; when set, receives row counts labelled
#                                    'input', 'frauds_output' and 'clusters_output'
#   is_production                  - when false, the profile input is first restricted to
#                                    yandexuids present in the 'elite_cards' lookalike rows
#
# Stages, in order of appearance:
#   1. Keep only profile rows whose update_time equals the maximum update_time and
#      left-join antifraud by 'y' || icookie.
#   2. Rows whose antifraud uid_type is non-null and is neither 'clean' nor
#      'yandex_staff' are written to frauds_output; the rest continue.
#   3. For the remaining rows an "integral" income score and the C2 score are derived
#      from income_5_segments, and the rows are inner-joined with the card / ultima
#      lookalike scores.
#   4. Rows are ranked by c2_score; rows whose rank is below $split_quantile * $cnt are
#      re-ranked by integral_score instead. Card and ultima scores are ranked as well.
#      The per-yandexuid weight is the product of the three ranks, each divided by $cnt.
#   5. Weights are averaged per crypta_id (weighted by days_active); yandexuid is used
#      as the crypta_id when no match exists, and such crypta_id is emitted as Null.
#   6. A cluster number is derived from the rank of the aggregated weight, mapped to a
#      prism_segment ('p1'..'p5') and complemented with rank_in_cluster, then written
#      to clusters_output.
#
# NOTE: everything between the triple quotes is sent to YQL verbatim after rendering,
# so any edit inside the literal changes runtime behaviour.
clusters_query = '''
PRAGMA DisableAnsiInForEmptyOrNullableItemsCollections;
PRAGMA yt.DataSizePerJob('100m');

$c2_input = (
    SELECT
        CAST(yandexuid AS String) AS yandexuid,
        icookie,
        income_5_segments,
        update_time,
    FROM `{{ yandexuid_profile_export_table }}`
);

$ultima_input = (
    SELECT
        Yandexuid AS yandexuid,
        Score AS score,
    FROM `{{ lookalike_table }}`
    WHERE GroupID == 'ultima_taxi_and_lavka_users'
);

$card_input = (
    SELECT
        Yandexuid AS yandexuid,
        Score AS score,
    FROM `{{ lookalike_table }}`
    WHERE GroupID == 'elite_cards'
);

{% if not is_production %}
$c2_input = (
    SELECT
        c2.*
    FROM $c2_input AS c2
    INNER JOIN $card_input AS card
    USING(yandexuid)
);
{% endif %}

$antifraud = (
    SELECT
        uid,
        uid_type,
    FROM `{{ antifraud_daily_table }}`
);

$max_profiles_update_time = (
    SELECT MAX(update_time) AS max_update_time
    FROM $c2_input
);

$c2_input_with_antifraud = (
    SELECT
        input.*,
        antifraud.uid_type AS antifraud_uid_type,
    FROM $c2_input AS input
    LEFT JOIN ANY $antifraud AS antifraud
    ON 'y' || CAST(input.icookie AS String) == antifraud.uid
    WHERE update_time == $max_profiles_update_time
);

$is_fraud = ($uid_type) -> {
    RETURN
        $uid_type IS NOT NULL
        AND $uid_type != 'clean'
        AND $uid_type != 'yandex_staff';
};

$frauds_output = (
    SELECT
        CAST(yandexuid AS Uint64) AS yandexuid,
        icookie,
    FROM $c2_input_with_antifraud
    WHERE $is_fraud(antifraud_uid_type)
);

INSERT INTO `{{ frauds_output }}` WITH TRUNCATE
SELECT *
FROM $frauds_output;

$calculate_income_integral_score = ($income) -> {
    RETURN 1 - (1 * Coalesce(Yson::ConvertToDoubleDict($income)['A'], 0) +
        0.34 * Coalesce(Yson::ConvertToDoubleDict($income)['B1'], 0) +
        0.26 * Coalesce(Yson::ConvertToDoubleDict($income)['B2'], 0) +
        0.22 * Coalesce(Yson::ConvertToDoubleDict($income)['C1'], 0))
};

$c2_filtered_input = (
    SELECT
        CAST(yandexuid AS String) AS yandexuid,
        CAST(icookie AS String) AS icookie,
        Coalesce($calculate_income_integral_score(income_5_segments), 0) AS integral_score,
        Coalesce(Yson::ConvertToDoubleDict(income_5_segments)['C2'], 0) AS c2_score,
    FROM $c2_input_with_antifraud
    WHERE NOT $is_fraud(antifraud_uid_type)
);

$c2_prepared_input = (
    SELECT
        c2.yandexuid AS yandexuid,
        c2.icookie AS icookie,
        c2.integral_score AS integral_score,
        c2.c2_score AS c2_score,
    FROM $c2_filtered_input AS c2
    INNER JOIN $card_input AS card
    ON c2.yandexuid == card.yandexuid
);

$card_prepared_input = (
    SELECT
        card.yandexuid AS yandexuid,
        card.score AS score,
    FROM $c2_filtered_input AS c2
    INNER JOIN $card_input AS card
    ON c2.yandexuid == card.yandexuid
);

$ultima_prepared_input = (
    SELECT
        ultima.yandexuid AS yandexuid,
        ultima.score AS score,
    FROM $c2_filtered_input AS c2
    INNER JOIN $ultima_input AS ultima
    ON c2.yandexuid == ultima.yandexuid
);

$cnt = (
    SELECT CAST(COUNT(*) AS Double)
    FROM $c2_prepared_input
);

$c2 = (
    SELECT
        yandexuid,
        icookie,
        integral_score,
        RANK() OVER w AS row_rank,
    FROM $c2_prepared_input
    WINDOW w AS (
        ORDER BY c2_score
    )
);

$split_quantile = 0.5;

$c2 = (
    SELECT * WITHOUT integral_score
    FROM $c2
    WHERE row_rank >= $split_quantile * $cnt
UNION ALL
    SELECT
        yandexuid,
        icookie,
        RANK() OVER w AS row_rank,
    FROM $c2
    WHERE row_rank < $split_quantile * $cnt
    WINDOW w AS (
        ORDER BY integral_score
    )
);

$card = (
    SELECT
        yandexuid,
        RANK() OVER w AS row_rank,
    FROM $card_prepared_input
    WINDOW w AS (
        ORDER BY score
    )
);

$ultima = (
    SELECT
        yandexuid,
        RANK() OVER w AS row_rank,
    FROM $ultima_prepared_input
    WINDOW w AS (
        ORDER BY score
    )
);

$weights_by_yandexuid = (
    SELECT
        c2.yandexuid AS yandexuid,
        c2.icookie AS icookie,
        (c2.row_rank / $cnt) * (card.row_rank / $cnt) * (ultima.row_rank / $cnt) AS weight,
    FROM $c2 AS c2
    INNER JOIN $card AS card
    ON c2.yandexuid == card.yandexuid
    INNER JOIN $ultima AS ultima
    ON c2.yandexuid == ultima.yandexuid
);

$weights_info = (
    SELECT
        combined_weights.yandexuid AS yandexuid,
        combined_weights.icookie AS icookie,
        combined_weights.weight AS weight,
        COALESCE(ListLength(yuid_with_all_info.dates), 1) AS days_active,
        COALESCE(yandexuid_cryptaid.target_id, combined_weights.yandexuid) AS crypta_id,
    FROM $weights_by_yandexuid AS combined_weights
    LEFT JOIN `{{ yuid_with_all_info_table }}` AS yuid_with_all_info
    ON combined_weights.yandexuid == yuid_with_all_info.id
    LEFT JOIN `{{ yandexuid_cryptaid_table }}` AS yandexuid_cryptaid
    ON combined_weights.yandexuid == yandexuid_cryptaid.id
);

$yuids_cnt = (
    SELECT COUNT(*)
    FROM $weights_info
);

$weights_by_crypta_id = (
    SELECT
        crypta_id,
        SUM(weight * days_active) / SUM(days_active) AS weight,
    FROM $weights_info
    GROUP BY crypta_id
);

$output = (
    SELECT
        t1.weight AS weight,
        CASE
            WHEN t2.crypta_id == t2.yandexuid THEN Null
            ELSE t2.crypta_id
        END AS crypta_id,
        CAST(t2.yandexuid AS String) AS yandexuid,
        t2.icookie AS icookie,
        CAST((RANK() OVER w - 1.) * 100 AS Int64) / $yuids_cnt + 1 AS cluster,
    FROM $weights_by_crypta_id AS t1
    INNER JOIN $weights_info AS t2
    USING(crypta_id)
    WINDOW w AS (
        ORDER BY t1.weight
    )
);

$yuids_cnt_by_cluster = (
    SELECT
        cluster,
        CAST(COUNT(*) AS Double) AS yuids_cnt,
    FROM $output
    GROUP BY cluster
);

$get_prism_segment = ($cluster) -> {
    RETURN CASE
        WHEN $cluster BETWEEN 1 AND 7 THEN 'p1'
        WHEN $cluster BETWEEN 8 AND 51 THEN 'p2'
        WHEN $cluster BETWEEN 52 AND 91 THEN 'p3'
        WHEN $cluster BETWEEN 92 AND 99 THEN 'p4'
        WHEN $cluster == 100 THEN 'p5'
        ELSE CAST($cluster AS String)
    END
};

$output_with_segment = (
    SELECT
        output.*,
        $get_prism_segment(output.cluster) AS prism_segment,
        CAST(RANK() OVER w AS Double) / yuids_cnt AS rank_in_cluster,
    FROM $output AS output
    INNER JOIN $yuids_cnt_by_cluster AS yuids_cnt_by_cluster
    USING(cluster)
    WINDOW w AS (
        PARTITION BY output.cluster
        ORDER BY weight
    )
);

INSERT INTO `{{ clusters_output }}` WITH TRUNCATE
SELECT *
FROM $output_with_segment;

{% if sizes_table %}
INSERT INTO `{{ sizes_table }}` WITH TRUNCATE
SELECT
    COUNT(*) AS row_count,
    'input' AS table_name,
FROM $c2_input_with_antifraud
UNION ALL
SELECT
    COUNT(*) AS row_count,
    'frauds_output' AS table_name,
FROM $frauds_output
UNION ALL
SELECT
    COUNT(*) AS row_count,
    'clusters_output' AS table_name,
FROM $output_with_segment;
{% endif %}
'''


def calculate(
    yt_client,
    yql_client,
    transaction,
    date,
    clusters_output_table,
    frauds_output_table,
    lookalike_table=None,
    yandexuid_profile_export_table=None,
    antifraud_daily_table=None,
    yuid_with_all_info_table=config.YUID_WITH_ALL_INFO_TABLE,
    yandexuid_cryptaid_table=config.YANDEXUID_CRYPTAID_MATCHING_TABLE,
):
    """Render and run ``clusters_query``, then verify its row counts.

    The query writes ``clusters_output_table`` and ``frauds_output_table`` and
    stores the sizes of its input and both outputs in a temporary table, which
    is read back here to check that no input row was lost or duplicated.

    Args:
        yt_client: client used to create the temporary sizes table and read it.
        yql_client: client whose ``execute`` runs the rendered query.
        transaction: object exposing ``transaction_id``; its string form is
            passed to ``yql_client.execute``.
        date: value used to build the default input table paths.
        clusters_output_table: destination path for the clusters output.
        frauds_output_table: destination path for the frauds output.
        lookalike_table: lookalike input; defaults to
            ``config.PRISM_LAL_DIR`` joined with ``date``.
        yandexuid_profile_export_table: profiles input; defaults to
            ``config.YANDEXUID_PROFILES_EXPORT_DIR`` joined with ``date``.
        antifraud_daily_table: antifraud input; defaults to
            ``config.ATNIFRAUD_EXPORT_BY_DATE_TABLE`` formatted with ``date``.
        yuid_with_all_info_table: table joined to obtain activity dates.
        yandexuid_cryptaid_table: table joined to obtain crypta ids.

    Raises:
        AssertionError: if the input row count differs from the sum of the
            frauds and clusters output row counts.
    """
    yandexuid_profile_export_table = yandexuid_profile_export_table or os.path.join(config.YANDEXUID_PROFILES_EXPORT_DIR, date)
    lookalike_table = lookalike_table or os.path.join(config.PRISM_LAL_DIR, date)
    antifraud_daily_table = antifraud_daily_table or config.ATNIFRAUD_EXPORT_BY_DATE_TABLE.format(date)

    with yt_client.TempTable() as sizes_table:
        yql_client.execute(
            templater.render_template(
                clusters_query,
                vars={
                    'yandexuid_profile_export_table': yandexuid_profile_export_table,
                    'lookalike_table': lookalike_table,
                    'antifraud_daily_table': antifraud_daily_table,
                    'clusters_output': clusters_output_table,
                    'frauds_output': frauds_output_table,
                    'yuid_with_all_info_table': yuid_with_all_info_table,
                    'yandexuid_cryptaid_table': yandexuid_cryptaid_table,
                    'sizes_table': sizes_table,
                    'is_production': environment.ENVIRONMENT == 'production',
                },
            ),
            title='YQL Prism calculate clusters',
            transaction=str(transaction.transaction_id),
        )

        row_counts = {row['table_name']: int(row['row_count']) for row in yt_client.read_table(sizes_table)}

        # Raised explicitly rather than via `assert`, so the consistency check
        # is not stripped when the interpreter runs with -O. The exception type
        # is kept as AssertionError for callers that already rely on it.
        if row_counts['input'] != row_counts['frauds_output'] + row_counts['clusters_output']:
            raise AssertionError(
                'Expected filtered input and output to be the same length'
                ' (input={}, frauds_output={}, clusters_output={})'.format(
                    row_counts['input'],
                    row_counts['frauds_output'],
                    row_counts['clusters_output'],
                )
            )


def calculate_by_date(yt_client, yql_client, date, custom_output_dir=None):
    """Run the cluster calculation for ``date`` inside a Nirvana transaction.

    Table paths come from ``table_paths.resolve(custom_output_dir, date)``.
    When no custom output directory is given, TTLs are set on the lookalike,
    clusters and frauds tables after the calculation.
    """
    with NirvanaTransaction(yt_client) as transaction:
        resolved = table_paths.resolve(custom_output_dir, date)

        calculate(
            yt_client=yt_client,
            yql_client=yql_client,
            transaction=transaction,
            date=date,
            clusters_output_table=resolved['clusters'],
            frauds_output_table=resolved['frauds'],
            lookalike_table=resolved['lookalike'],
        )

        # Custom output locations keep whatever TTL they already have.
        if custom_output_dir is not None:
            return

        lookalike_ttl = timedelta(days=config.PRISM_LOOKALIKE_TTL_DAYS)
        output_ttl = timedelta(days=config.OFFLINE_WEIGHTING_OUTPUT_TTL_DAYS)
        ttl_by_table = (
            ('lookalike', lookalike_ttl),
            ('clusters', output_ttl),
            ('frauds', output_ttl),
        )
        for table_key, ttl in ttl_by_table:
            yt_helpers.set_ttl(resolved[table_key], ttl, yt_client=yt_client)
