import os

import numpy as np

from crypta.lib.python import templater
from crypta.profile.utils.config import config


# Values of the `os` column for which metrics are reported:
# calculate_prism_quality reads every result table once per key listed here.
# 'total' is the synthetic key that the queries in this module emit for the
# aggregate over all operating systems.
OPERATING_SYSTEMS = ('android', 'ios', 'macos', 'other', 'total', 'windows')

# YQL fragment, filled in with str.format, that builds one cost metric table.
# Placeholders:
#   {idx}                 - suffix that keeps named expressions unique when
#                           several fragments are placed into one query;
#   {joined}, {sum_cost}  - names of the two intermediate named expressions;
#   {metric_input_table}  - table joined by yandexuid, providing `os` and `cost`;
#   {output_table}        - destination table, overwritten (WITH TRUNCATE).
# The fragment expects $prism{idx} to be defined earlier in the same query
# (prism_quality_query_template does that). It sums cost per (os, cluster),
# adds the same aggregation under the synthetic os 'total', and writes for each
# os the running sum of cost over clusters taken in descending order, divided
# by the overall cost of that os. The output is ordered by os, cluster.
prism_cost_metrics_query = """
${joined}{idx} = (
    SELECT
        metric.os AS os,
        CAST(prism.cluster AS Uint64) AS cluster,
        SUM(metric.cost) AS cost,
    FROM $prism{idx} AS prism
    INNER JOIN `{metric_input_table}` AS metric
    USING(yandexuid)
    GROUP BY metric.os, prism.cluster
UNION ALL
    SELECT
        'total' AS os,
        CAST(prism.cluster AS Uint64) AS cluster,
        SUM(metric.cost) AS cost,
    FROM $prism{idx} AS prism
    INNER JOIN `{metric_input_table}` AS metric
    USING(yandexuid)
    GROUP BY prism.cluster
);

${sum_cost}{idx} = (
    SELECT
        os,
        CAST(SUM(cost) AS Double) AS cost,
    FROM ${joined}{idx}
    GROUP BY os
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    joined.os AS os,
    joined.cluster AS cluster,
    SUM(joined.cost) OVER w / sum_cost.cost AS cost_ratio_cumsum,
FROM ${joined}{idx} AS joined
INNER JOIN ${sum_cost}{idx} AS sum_cost
USING(os)
WINDOW w AS (
    PARTITION BY joined.os
    ORDER BY joined.cluster DESC
)
ORDER BY os, cluster;
"""

# YQL fragment, filled in with str.format, that builds the 'share' table.
# Placeholders:
#   {idx}                         - suffix of the $prism / $prism_share named
#                                   expressions;
#   {yandex_google_visits_table}  - table joined by yandexuid, providing `os`,
#                                   `google_visits` and `yandex_visits`;
#   {output_table}                - destination table, overwritten
#                                   (WITH TRUNCATE).
# For every (os, prism_segment) pair, and additionally for the synthetic os
# 'total', it computes
#   google_share = SUM(google_visits) / (SUM(google_visits) + SUM(yandex_visits))
# and writes the rows ordered by os, prism_segment.
prism_share_query = """
$prism_share{idx} = (
    SELECT
        share.os AS os,
        prism.prism_segment AS prism_segment,
        CAST(SUM(share.google_visits) AS Double) / (SUM(share.google_visits) + SUM(share.yandex_visits)) AS google_share,
    FROM $prism{idx} AS prism
    INNER JOIN `{yandex_google_visits_table}` AS share
    USING(yandexuid)
    GROUP BY share.os, prism.prism_segment
UNION ALL
    SELECT
        'total' AS os,
        prism.prism_segment AS prism_segment,
        CAST(SUM(share.google_visits) AS Double) / (SUM(share.google_visits) + SUM(share.yandex_visits)) AS google_share,
    FROM $prism{idx} AS prism
    INNER JOIN `{yandex_google_visits_table}` AS share
    USING(yandexuid)
    GROUP BY prism.prism_segment
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT *
FROM $prism_share{idx}
ORDER BY os, prism_segment;
"""

# Jinja-style template of the whole query; check_quality renders it through
# templater.render_template. For every index of `input_tables` it defines
# $prism<idx> (yandexuid and cluster cast to Uint64, rows whose cluster equals
# 'Unknown' dropped) and then inlines the already formatted adv, gmv and share
# fragments stored under the same index in `prism_cost_adv_queries`,
# `prism_cost_gmv_queries` and `prism_share_queries`.
prism_quality_query_template = """
{% for idx in range(input_tables| length) %}
$prism{{idx}} = (
    SELECT
        CAST(yandexuid AS Uint64) AS yandexuid,
        CAST(cluster AS Uint64) AS cluster,
        prism_segment,
    FROM `{{input_tables[idx]}}`
    WHERE CAST(cluster AS String) != 'Unknown'
);

{{prism_cost_adv_queries[idx]}}
{{prism_cost_gmv_queries[idx]}}
{{prism_share_queries[idx]}}
{% endfor %}
"""


def calculate_prism_quality(yt_client, input_dir, date):
    """
    Turn the `adv`, `gmv` and `share` tables of one output directory into metric records.

    For each of the `adv` and `gmv` tables and for each key in OPERATING_SYSTEMS
    the `cost_ratio_cumsum` values of the rows with that key are collected, a
    leading 0 is added, the values are sorted in ascending order and integrated
    with the trapezoidal rule (NumPy default spacing).

    For the `share` table and each key in OPERATING_SYSTEMS the metric adds
    `google_share / 2` for every row of segments p4 and p5 and subtracts
    `google_share / 3` for every row of segments p1, p2 and p3; rows of any
    other segment are ignored. The metric stays 0 when no rows are found.

    Parameters:
        yt_client: client providing `TablePath` and `read_table`
        input_dir: yt directory containing the `adv`, `gmv` and `share` tables
        date: value stored in the `fielddate` field of every record

    Returns:
        list of dicts with keys `fielddate`, `os`, `metric_type` and `metric`,
        ordered as: all `adv` records, all `gmv` records, all `share` records,
        each group following the order of OPERATING_SYSTEMS.
    """
    # numpy.trapz is deprecated since NumPy 2.0 in favour of numpy.trapezoid;
    # resolve the new name first and fall back to the old one so that the
    # function works with both old and new NumPy versions.
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz

    def read_rows(table_name, operating_system):
        # Reads only the rows whose key equals the given operating system.
        table_path = yt_client.TablePath(os.path.join(input_dir, table_name), exact_key=operating_system)
        return yt_client.read_table(table_path)

    prism_quality = []
    for metric_type in ('adv', 'gmv'):
        for operating_system in OPERATING_SYSTEMS:
            # The leading zero is the starting point of the cumulative curve.
            cost_cumsum = [0]
            cost_cumsum.extend(row['cost_ratio_cumsum'] for row in read_rows(metric_type, operating_system))
            prism_quality.append({
                'fielddate': date,
                'os': operating_system,
                'metric_type': metric_type,
                'metric': float(integrate(sorted(cost_cumsum))),
            })

    for operating_system in OPERATING_SYSTEMS:
        metric = 0
        for row in read_rows('share', operating_system):
            # NOTE(review): with one row per segment this equals
            # mean(p4, p5) - mean(p1, p2, p3) of google_share - confirm intent.
            if row['prism_segment'] in ('p4', 'p5'):
                metric += row['google_share'] / 2
            elif row['prism_segment'] in ('p1', 'p2', 'p3'):
                metric -= row['google_share'] / 3
        prism_quality.append({
            'fielddate': date,
            'os': operating_system,
            'metric_type': 'share',
            'metric': metric,
        })

    return prism_quality


def check_quality(yt_client, yql_client, tables_to_check, dates, output_dirs, transaction=None):
    """
    Function to check prism quality.
    Production metrics may be found at https://datalens.yandex-team.ru/uj0m556sqfoir-unifiedcryptastats?tab=81v.

    The i-th table of `tables_to_check` is evaluated against the logs of the
    i-th date and its intermediate `adv`, `gmv` and `share` tables are written
    to the i-th output directory. All tables are produced by a single YQL query.

    Parameters:
        yt_client: crypta/profile/utils/yt_utils.py
        yql_client: crypta/lib/python/yql/client/__init__.py
        tables_to_check: yt tables paths to check
            column names: yandexuid, cluster, prism_segment.
        dates: date strings, one per table to check; each is used as the table
            name inside the adv, gmv and share log directories from config
        output_dirs: yt directories paths, one per table to check, to save
            final data for metrics
        transaction: optionally passed external transaction; its
            `transaction_id` is forwarded to `yt_client.Transaction`

    Returns:
        list with one element per (date, output_dir) pair, each element being
        the list of records returned by calculate_prism_quality.

    Raises:
        ValueError: if tables_to_check, dates and output_dirs differ in length.
    """
    # The three arguments are parallel sequences that are traversed several
    # times (zip, template indexing), so materialize them and refuse a length
    # mismatch instead of silently dropping or mispairing entries.
    tables_to_check = list(tables_to_check)
    dates = list(dates)
    output_dirs = list(output_dirs)
    if not len(tables_to_check) == len(dates) == len(output_dirs):
        raise ValueError(
            'tables_to_check, dates and output_dirs must have equal lengths, got {}, {} and {}'.format(
                len(tables_to_check), len(dates), len(output_dirs),
            )
        )

    parent_transaction_id = transaction.transaction_id if transaction is not None else None
    with yt_client.Transaction(transaction_id=parent_transaction_id) as yt_transaction:
        prism_cost_adv_queries, prism_cost_gmv_queries, prism_share_queries = [], [], []
        for idx, (date, output_dir) in enumerate(zip(dates, output_dirs)):
            # idx keeps the YQL named expressions of different tables apart.
            prism_cost_adv_queries.append(prism_cost_metrics_query.format(
                idx=str(idx),
                joined='prism_adv',
                sum_cost='sum_adv_cost',
                metric_input_table=os.path.join(config.PRISM_CHEVENT_LOG_DIRECTORY, date),
                output_table=os.path.join(output_dir, 'adv'),
            ))

            prism_cost_gmv_queries.append(prism_cost_metrics_query.format(
                idx=str(idx),
                joined='prism_gmv',
                sum_cost='sum_gmv_cost',
                metric_input_table=os.path.join(config.PRISM_GMV_DIRECTORY, date),
                output_table=os.path.join(output_dir, 'gmv'),
            ))

            prism_share_queries.append(prism_share_query.format(
                idx=str(idx),
                yandex_google_visits_table=os.path.join(config.PRISM_YANDEX_GOOGLE_VISITS_DIRECTORY, date),
                output_table=os.path.join(output_dir, 'share'),
            ))

        query = templater.render_template(
            prism_quality_query_template,
            vars={
                'input_tables': tables_to_check,
                'prism_cost_adv_queries': prism_cost_adv_queries,
                'prism_cost_gmv_queries': prism_cost_gmv_queries,
                'prism_share_queries': prism_share_queries,
            },
        )
        yql_client.execute(
            query=query,
            transaction=yt_transaction.transaction_id,
            title='YQL Calculate prism quality metrics',
        )

        # The result tables are read while the transaction block is still
        # open, exactly as before; presumably so that the tables written by
        # the query are visible to the reader - confirm before moving this.
        return [calculate_prism_quality(yt_client=yt_client, input_dir=out_dir, date=date)
                for date, out_dir in zip(dates, output_dirs)]
