import os

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.prism.lib.config import config
from crypta.prism.lib.utils import metrics
from crypta.profile.lib import date_helpers


# Jinja-templated YQL query that builds the prism priors table for one date.
#
# Template variables used by the query text:
#   date, prior_sample_period              - define the [$start_date, $end_date] table range
#                                            ($days tables ending at `date`, inclusive);
#   features_and_cluster_by_user_dir       - directory whose per-date tables are read via RANGE;
#   min_prior_size                         - minimal row count for a feature combination to be kept;
#   prism_segment_bounds                   - list of lower cluster bounds, one per segment;
#   prism_segment_weights                  - looked up by the 1-based Jinja `loop.index`;
#   output_table                           - destination table, overwritten (WITH TRUNCATE).
#
# Stages (mirroring the "N / 4" markers inside the query):
#   1. Read the date range, drop rows whose cluster is NULL or '-1', expand every
#      feature into a list ($extend / Geo::GetParents), flatten those lists and compute
#      the average cluster and row count per feature combination.
#   2. For every (icookie, date) take the avg_cluster of the matching combination
#      with the smallest count (MIN_BY over the LEFT JOIN ANY result).
#   3. Rank the per-user averages and turn the rank into a bucket in 1..100
#      ((rank - 1) * 100 / total + 1); collect min/max average per bucket.
#   4. Map each combination's raw average to a bucket through the per-bucket lower
#      bounds, derive the prism segment and weight from that bucket and write the result
#      ordered by `count`.
#
# In $get_prism_segment / $get_prism_weight every segment but the last covers
# [bound, next_bound - 1]; the last segment is open-ended (>= bound).
#
# NOTE(review): `$extend` is not defined in this string; presumably it is provided by
# the query text that the caller prepends before execution - confirm.
# NOTE(review): $final_cluster takes ListLast over a list produced by AGGREGATE_LIST;
# this looks like it relies on the list being ordered by cluster - confirm that the
# order is guaranteed (or sort the list explicitly).
priors_query = '''
$end_date = '{{ date }}';
$days = {{ prior_sample_period }};
$start_date = CAST(CAST($end_date AS Date) - DateTime::IntervalFromDays($days - 1) AS String);

-- 1 / 4: Get features by date
$features_and_cluster_by_user_by_date = (
    SELECT
        features_and_cluster_by_user.*,
    FROM RANGE(
        `{{ features_and_cluster_by_user_dir }}`,
        $start_date,
        $end_date,
    ) AS features_and_cluster_by_user
);

$features_and_cluster_by_user_by_date_extended = (
    SELECT
        icookie,
        `date`,
        cluster,
        Geo::GetParents(CAST(region AS Int32)) AS region,
        $extend(browser) AS browser,
        $extend(device_vendor) AS device_vendor,
        $extend(os_family) AS os_family,
        $extend(device_name) AS device_name,
        $extend(screen_info) AS screen_info,
    FROM $features_and_cluster_by_user_by_date
    WHERE
        cluster IS NOT NULL AND
        cluster != '-1'
);

$features_and_cluster_by_user_by_date_extended = (
    SELECT *
    FROM $features_and_cluster_by_user_by_date_extended
    FLATTEN LIST BY (
        browser,
        device_vendor,
        os_family,
        device_name,
        region,
        screen_info,
    )
);

$info_cluster_avg_raw = (
    SELECT
        AVG(CAST(cluster AS Double)) AS avg_cluster,
        COUNT(*) AS `count`,
        browser,
        device_vendor,
        os_family,
        device_name,
        region,
        screen_info,
    FROM $features_and_cluster_by_user_by_date_extended
    GROUP BY
        browser,
        device_vendor,
        os_family,
        device_name,
        region,
        screen_info
    HAVING COUNT(*) >= {{ min_prior_size }}
);

-- 2 / 4: Build icookie to raw avg value
$icookie_date_raw_avg = (
    SELECT
        user.icookie AS icookie,
        user.`date` AS `date`,
        MIN_BY(priors.avg_cluster, priors.`count`) AS avg_cluster,
    FROM $features_and_cluster_by_user_by_date_extended AS user
    LEFT JOIN ANY $info_cluster_avg_raw AS priors
    USING(
        browser,
        device_vendor,
        os_family,
        device_name,
        region,
        screen_info
    )
    GROUP BY user.icookie, user.`date`
);

-- 3 / 4: Calc percentiles; match raw avg value to ranked avg prior cluster
$cnt = (
    SELECT COUNT(*)
    FROM $icookie_date_raw_avg
);

$avg_to_cluster_raw = (
    SELECT
        icookie,
        avg_cluster,
        CAST((RANK() OVER w - 1.) * 100 AS Int64) / $cnt + 1 AS cluster,
    FROM $icookie_date_raw_avg
    WINDOW w AS (
        ORDER BY avg_cluster
    )
);

$avg_to_cluster = (
    SELECT
        cluster,
        MIN(avg_cluster) AS min_avg,
        MAX(avg_cluster) AS max_avg,
    FROM $avg_to_cluster_raw
    GROUP BY cluster
);

-- 4 / 4: Build ranked avg_screen priors.
$cluster_lower_bounds = (
    SELECT AGGREGATE_LIST(AsTuple(cluster, min_avg))
    FROM $avg_to_cluster
);

$final_cluster = ($avg_cluster) -> {
    $filtered = ListFilter(
        $cluster_lower_bounds,
        ($x) -> {
            RETURN $avg_cluster > $x.1;
        },
    );
    RETURN IF(
        ListLength($filtered) > 0,
        ListLast($filtered).0,
        1,
    );
};

$get_prism_segment = ($cluster) -> {
    RETURN CASE
        {% for segment_bound in prism_segment_bounds %}
        {% if loop.index < prism_segment_bounds | length %}
        {# loop.index starts with 1; so prism_segment_bounds[loop.index] is the next segment bound #}
        WHEN $cluster BETWEEN {{ segment_bound }} AND {{ prism_segment_bounds[loop.index] - 1 }}
        {% else %}
        WHEN $cluster >= {{ segment_bound }}
        {% endif %}
        THEN {{ loop.index }}
        {% endfor %}
        ELSE NULL
    END
};

$get_prism_weight = ($cluster) -> {
    RETURN CASE
        {% for segment_bound in prism_segment_bounds %}
        WHEN
        {% if loop.index < prism_segment_bounds | length %}
        $cluster BETWEEN {{ segment_bound }} AND {{ prism_segment_bounds[loop.index] - 1 }}
        {% else %}
        $cluster >= {{ segment_bound }}
        {% endif %}
        THEN {{ prism_segment_weights[loop.index] }}
        {% endfor %}
        ELSE NULL
    END
};

INSERT INTO `{{ output_table }}` WITH TRUNCATE
SELECT
    browser,
    device_name,
    device_vendor,
    os_family,
    region,
    screen_info,
    `count`,
    $final_cluster(avg_cluster) AS prior_cluster,
    $get_prism_segment($final_cluster(avg_cluster)) AS prior_segment,
    $get_prism_weight($final_cluster(avg_cluster)) AS prior_weight,
FROM $info_cluster_avg_raw
ORDER BY `count`
'''


def calculate(
    yt_client,
    yql_client,
    date,
    output_table=None,
    features_and_cluster_by_user_dir=config.FEATURES_AND_CLUSTER_BY_USER_DIR,
):
    """Render the priors query for ``date`` and execute it inside a Nirvana transaction.

    Args:
        yt_client: client used to open the transaction, list the input
            directory and read ``config.PRISM_CLUSTER_MAPPING_TABLE``.
        yql_client: client whose ``execute`` runs the rendered query.
        date: date string; it names the last input table of the sample
            period and, by default, the output table.
        output_table: destination table path; when falsy, defaults to
            ``config.PRISM_PRIORS_DIR`` joined with ``date``.
        features_and_cluster_by_user_dir: directory with one input table
            per date.

    Raises:
        AssertionError: if any of the ``config.PRIOR_SAMPLE_PERIOD_DAYS``
            expected dates is absent from the input directory listing.
    """
    with NirvanaTransaction(yt_client) as transaction:
        output_table = output_table or os.path.join(config.PRISM_PRIORS_DIR, date)

        # Every day of the sample period (including `date` itself, days=0)
        # must have an input table.
        expected_dates = {
            date_helpers.get_date_from_past(date, days=days)
            for days in range(config.PRIOR_SAMPLE_PERIOD_DAYS)
        }
        existing_dates = set(yt_client.list(features_and_cluster_by_user_dir))
        missing_dates = expected_dates - existing_dates
        if missing_dates:
            # Raised explicitly instead of using `assert`, so the check is not
            # stripped when Python runs with -O; the exception type and message
            # are the same as before.
            raise AssertionError('Some dates are missing: {}'.format(missing_dates))

        # Segment number -> weight, e.g. a row with prism_segment 'p3' yields key 3.
        # The query template looks the weight up by the 1-based segment index.
        prism_segment_weights = {
            int(row['prism_segment'][1:]): row['norm_serp_revenue']
            for row in yt_client.read_table(config.PRISM_CLUSTER_MAPPING_TABLE)
            if row['prism_segment'].startswith('p')
        }

        template_vars = {
            'date': date,
            'min_prior_size': config.MIN_PRIOR_SIZE,
            'prior_sample_period': config.PRIOR_SAMPLE_PERIOD_DAYS,
            'prism_segment_bounds': config.PRISM_SEGMENT_LOWER_CLUSTER_BOUNDS,
            'prism_segment_weights': prism_segment_weights,
            'features_and_cluster_by_user_dir': features_and_cluster_by_user_dir,
            'output_table': output_table,
        }

        # The shared prior utilities query is prepended to the rendered priors query.
        yql_client.execute(
            '\n'.join([
                metrics.prior_utils_query,
                templater.render_template(priors_query, vars=template_vars),
            ]),
            title='YQL Prism priors calculation',
            transaction=str(transaction.transaction_id),
        )
