import os

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.prism.lib.config import config
from crypta.prism.lib.nirvana import utils
from crypta.prism.services.offline_weighting.lib import table_paths

# YQL template (rendered with templater.render_template in `calculate`) that
# builds per-user weights and writes them to `{{ user_weights_output }}`
# WITH TRUNCATE.
#
# Template variables:
#   clusters_table                      - per-user rows read with columns
#                                         yandexuid, crypta_id, icookie, cluster,
#                                         prism_segment, rank_in_cluster
#   cluster_stats_table                 - per-cluster rows, LEFT JOINed on `cluster`;
#                                         serp_revenue_sum, active_search_users and
#                                         CPT_sum are read from it
#   longterm_norm_segment_weights_table - longterm_segment_* columns per prism_segment
#   user_sessions_pub_table             - its distinct `key` values are matched against
#                                         'y' || yandexuid to select the users that
#                                         define the averages
#   user_weights_output                 - output table
#
# Steps:
#   1. Collect the long-term segment metrics into a dict keyed by prism_segment
#      (NULL metric values become 0.).
#   2. Take every row of clusters_table plus one extra row with
#      cluster = prism_segment = 'Unknown' (remaining columns presumably NULL per
#      YQL UNION ALL semantics).
#      NOTE(review): looks like a guaranteed 'Unknown' fallback row; confirm intent.
#   3. LEFT JOIN cluster stats: norm_serp_revenue = serp_revenue_sum / active_search_users,
#      CPT = CPT_sum; attach the long-term metrics looked up by prism_segment.
#   4. Average CPT and norm_serp_revenue over users whose 'y' || yandexuid is
#      among the session keys.
#   5. Divide every user's CPT and norm_serp_revenue by those averages and write
#      all rows (not only the ones used for averaging) to the output.
user_weights_query = '''
PRAGMA DisableAnsiInForEmptyOrNullableItemsCollections;
PRAGMA yt.DataSizePerJob('100m');
PRAGMA yt.MinPublishedAvgChunkSize='4294967296';
PRAGMA yt.PublishedCompressionCodec='brotli_5';
PRAGMA yt.PublishedErasureCodec='lrc_12_2_2';

$longterm_table = '{{ longterm_norm_segment_weights_table }}';

$segment_to_longterm_list = (
    SELECT
        AGGREGATE_LIST(AsTuple(
            prism_segment,
            {
                'segment_week_m1': longterm_segment_week_m1 ?? 0.,
                'segment_week': longterm_segment_week ?? 0.,
                'segment_month': longterm_segment_month ?? 0.,
                'segment_2month': longterm_segment_2month ?? 0.,
                'segment_6month': longterm_segment_6month ?? 0.,
            }
        ))
    FROM $longterm_table
);

$segment_to_longterm = ToDict($segment_to_longterm_list);

$yuid_cluster = (
    SELECT
      yandexuid,
      crypta_id,
      icookie,
      CAST(cluster AS String) AS cluster,
      prism_segment,
      rank_in_cluster,
    FROM `{{ clusters_table }}`
    UNION ALL
    SELECT
        'Unknown' AS cluster,
        'Unknown' AS prism_segment
);
$cluster_to_metrics = SELECT * FROM `{{ cluster_stats_table }}`;

$yuid_to_metrics = (
    SELECT
        yuid_to_cluster.yandexuid AS yandexuid,
        yuid_to_cluster.crypta_id AS crypta_id,
        yuid_to_cluster.icookie AS icookie,
        yuid_to_cluster.cluster AS cluster,
        yuid_to_cluster.prism_segment AS prism_segment,
        yuid_to_cluster.rank_in_cluster AS rank_in_cluster,
        cluster_to_metrics.serp_revenue_sum / cluster_to_metrics.active_search_users AS norm_serp_revenue,
        cluster_to_metrics.CPT_sum AS CPT,
        $segment_to_longterm[yuid_to_cluster.prism_segment] AS longterm_metrics,
    FROM $yuid_cluster AS yuid_to_cluster
    LEFT JOIN $cluster_to_metrics AS cluster_to_metrics
    ON CAST(yuid_to_cluster.cluster AS String) == cluster_to_metrics.cluster
);

$active_clean_search_yuids = SELECT DISTINCT key FROM `{{ user_sessions_pub_table }}`;

$avg = (
    SELECT
        AsStruct(
            AVG(CPT) AS CPT,
            AVG(norm_serp_revenue) AS norm_serp_revenue,
        )
    FROM $yuid_to_metrics
    WHERE 'y' || yandexuid IN $active_clean_search_yuids
);

$yuid_to_norm_metrics = (
    SELECT
        yandexuid,
        crypta_id,
        icookie,
        cluster,
        prism_segment,
        rank_in_cluster,
        norm_serp_revenue / $avg.norm_serp_revenue AS norm_serp_revenue,
        CPT / $avg.CPT AS CPT,
        longterm_metrics,
    FROM $yuid_to_metrics
);

INSERT INTO `{{ user_weights_output }}` WITH TRUNCATE
SELECT * FROM $yuid_to_norm_metrics;
'''


def calculate(
    yql_client,
    transaction,
    date,
    user_weights_output_table,
    clusters_table=None,
    cluster_stats_table=None,
    user_sessions_pub_table=None,
    longterm_norm_segment_weights_table=config.LONGTERM_NORM_SEGMENT_WEIGHTS_TABLE,
):
    """Render ``user_weights_query`` and execute it through ``yql_client``.

    Any of ``clusters_table``, ``cluster_stats_table`` or
    ``user_sessions_pub_table`` that is falsy is derived from ``config`` and
    ``date``:

    * ``clusters_table``          -> ``os.path.join(config.CLUSTERS_DIR, date)``
    * ``cluster_stats_table``     -> ``os.path.join(config.CLUSTER_STATS_DIR, date)``
    * ``user_sessions_pub_table`` -> ``config.USER_SESSIONS_NANO_BY_DATE_TABLE.format(date)``

    :param yql_client: client whose ``execute`` method runs the rendered query.
    :param transaction: object exposing ``transaction_id``; its ``str()`` form is
        passed to ``execute`` as ``transaction``.
    :param date: used as a path component / format argument when building the
        default input paths above.
    :param user_weights_output_table: table the query writes (WITH TRUNCATE).
    :param clusters_table: per-user cluster assignment table.
    :param cluster_stats_table: per-cluster metrics table.
    :param user_sessions_pub_table: table whose distinct ``key`` values select
        the users used for averaging.
    :param longterm_norm_segment_weights_table: per-segment long-term weights
        table; defaults to ``config.LONGTERM_NORM_SEGMENT_WEIGHTS_TABLE``.
    """
    if not clusters_table:
        clusters_table = os.path.join(config.CLUSTERS_DIR, date)
    if not cluster_stats_table:
        cluster_stats_table = os.path.join(config.CLUSTER_STATS_DIR, date)
    if not user_sessions_pub_table:
        user_sessions_pub_table = config.USER_SESSIONS_NANO_BY_DATE_TABLE.format(date)

    # Keys must match the `{{ ... }}` placeholders in user_weights_query.
    template_vars = dict(
        clusters_table=clusters_table,
        cluster_stats_table=cluster_stats_table,
        user_weights_output=user_weights_output_table,
        longterm_norm_segment_weights_table=longterm_norm_segment_weights_table,
        user_sessions_pub_table=user_sessions_pub_table,
    )
    rendered_query = templater.render_template(user_weights_query, vars=template_vars)

    yql_client.execute(
        rendered_query,
        title='YQL Prism calculate user weights',
        transaction=str(transaction.transaction_id),
    )


def calculate_by_date(yt_client, yql_client, date, nv_output_table_file=None, custom_output_dir=None):
    """Compute user weights for ``date`` inside one ``NirvanaTransaction``.

    Table paths come from ``table_paths.resolve(custom_output_dir, date)``. The
    ``user_weights`` table is produced by :func:`calculate` from the resolved
    ``clusters`` and ``cluster_stats`` tables.

    :param yt_client: YT client used for the transaction and attribute update.
    :param yql_client: client passed through to :func:`calculate`.
    :param date: date passed to path resolution and to :func:`calculate`.
    :param nv_output_table_file: if not None, ``utils.make_text_output`` is
        called with it and the user-weights table path.
    :param custom_output_dir: optional output directory override; when it is
        None the result table gets the attribute ``stage='full'``.
    """
    with NirvanaTransaction(yt_client) as transaction:
        paths = table_paths.resolve(custom_output_dir, date)
        user_weights_table = paths['user_weights']

        calculate(
            yql_client=yql_client,
            transaction=transaction,
            date=date,
            user_weights_output_table=user_weights_table,
            clusters_table=paths['clusters'],
            cluster_stats_table=paths['cluster_stats'],
        )

        # Only results written to the default (non-custom) location are tagged.
        writes_to_default_location = custom_output_dir is None
        if writes_to_default_location:
            yt_client.set_attribute(user_weights_table, 'stage', 'full')

        if nv_output_table_file is not None:
            utils.make_text_output(nv_output_table_file, user_weights_table)
