import json

from crypta.lib.python import templater
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.rt_socdem.lib.python.model.config import config


# Jinja-style template of the YQL query that estimates how predictions are
# distributed over segments for every socdem type.
#
# Template variables:
#   segment_names_by_socdem_type   - socdem type -> ordered segment names; the
#                                    position of a name is used as the class
#                                    index of the `Probability:Class=<i>` column.
#   thresholds_name_by_socdem_type - socdem type -> key to look up in `thresholds`.
#   thresholds                     - thresholds key -> per-class threshold values
#                                    (indexed by the same class index).
#   predictions_on_pools_dir       - directory with one predictions table per
#                                    socdem type.
#   output_table                   - table that is overwritten with the result.
#
# Socdem types whose thresholds key is missing from `thresholds` are skipped.
# For every remaining socdem type the query builds:
#   $ratios_<type>               - each class probability divided by its threshold,
#                                  one column per segment name;
#   $row_count_<type>            - number of prediction rows, as Double;
#   $distribution_<type>         - dict: segment -> share of rows where the segment's
#                                  ratio is >= 1 and not lower than any other
#                                  segment's ratio (a row with tied top ratios is
#                                  counted for each tied segment), plus 'UNKNOWN' ->
#                                  share of rows where every ratio is < 1;
#   $flatten_distribution_<type> - the dict unfolded into rows with columns
#                                  `segment`, `percentage` and `socdem`.
# The flattened per-type results are combined with UNION ALL and written to
# `output_table` WITH TRUNCATE.
distribution_query = '''
{% for socdem_type in segment_names_by_socdem_type %}
{% if thresholds_name_by_socdem_type[socdem_type] in thresholds %}
$ratios_{{ socdem_type }} = (
    SELECT
        {% for segment_name in segment_names_by_socdem_type[socdem_type] %}
        `Probability:Class={{ loop.index - 1 }}` / {{ thresholds[thresholds_name_by_socdem_type[socdem_type]][loop.index - 1] }} AS `{{ segment_name }}`,
        {% endfor %}
    FROM `{{ predictions_on_pools_dir }}/{{ socdem_type }}`
);

$row_count_{{ socdem_type }} = (
    SELECT CAST(COUNT(*) AS Double) AS row_count
    FROM $ratios_{{ socdem_type }}
);

$distribution_{{socdem_type}} = (
    SELECT
        AsDict(
            {% for segment_name in segment_names_by_socdem_type[socdem_type] %}
            AsTuple('{{ segment_name }}', COUNT_IF(
                ratios.`{{ segment_name }}` >= 1
                {% for another_segment_name in segment_names_by_socdem_type[socdem_type] %}
                AND ratios.`{{ segment_name }}` >= ratios.`{{ another_segment_name }}`
                {% endfor %}
            ) / $row_count_{{ socdem_type }}),
            {% endfor %}
            AsTuple('UNKNOWN', COUNT_IF(
                {% for segment_name in segment_names_by_socdem_type[socdem_type] %}
                {% if loop.index > 1 %}
                AND
                {% endif %}
                ratios.`{{ segment_name }}` < 1
                {% endfor %}
            ) / $row_count_{{ socdem_type }}),
        ) AS distribution,
    FROM $ratios_{{ socdem_type }} AS ratios
);

$flatten_distribution_{{socdem_type}} = (
    SELECT
        segment_name_and_percentage.0 AS segment,
        segment_name_and_percentage.1 AS percentage,
        '{{socdem_type}}' AS socdem,
    FROM $distribution_{{socdem_type}}
    FLATTEN DICT BY distribution AS segment_name_and_percentage
);
{% endif %}
{% endfor %}

$all_distributions = (
{% for socdem_type in segment_names_by_socdem_type if thresholds_name_by_socdem_type[socdem_type] in thresholds %}
{% if loop.index > 1 %}
UNION ALL
{% endif %}
    SELECT *
    FROM $flatten_distribution_{{socdem_type}}
{% endfor %}
);

INSERT INTO `{{ output_table }}` WITH TRUNCATE
SELECT *
FROM $all_distributions;
'''


def estimate(
    yt_client,
    yql_client,
    transaction,
    predictions_on_pools_dir=config.THRESHOLDS_PREDICTIONS_ON_POOLS_DIR,
    thresholds=None,
):
    """Run the distribution query and return its result rows.

    The query rendered from ``distribution_query`` is executed into a
    temporary table, which is then read back in full.

    Args:
        yt_client: YT client used to create the temporary table, to read the
            thresholds file and to read the query result.
        yql_client: client used to execute the rendered YQL query.
        transaction: object whose ``transaction_id`` is passed (as a string)
            to the YQL execution.
        predictions_on_pools_dir: directory with per-socdem prediction tables.
        thresholds: thresholds mapping passed to the template; when ``None``
            it is loaded as JSON from ``config.THRESHOLDS_FILE``.

    Returns:
        A list with all rows read from the temporary output table.
    """
    with yt_client.TempTable() as output_table:
        if thresholds is None:
            raw_thresholds = yt_client.read_file(config.THRESHOLDS_FILE).read()
            thresholds = json.loads(raw_thresholds)

        template_vars = {
            'predictions_on_pools_dir': predictions_on_pools_dir,
            'thresholds': thresholds,
            'segment_names_by_socdem_type': config.SEGMENT_NAMES_BY_SOCDEM_TYPE,
            'output_table': output_table,
            'thresholds_name_by_socdem_type': config.THRESHOLDS_NAME_BY_SOCDEM_TYPE,
        }
        rendered_query = templater.render_template(
            template_text=distribution_query,
            vars=template_vars,
        )

        yql_client.execute(
            query=rendered_query,
            transaction=str(transaction.transaction_id),
            title='YQL Realtime socdem: Estimate predictions distribution',
        )

        # Materialize the rows while the temporary table still exists.
        return list(yt_client.read_table(output_table))


def send_metrics(
    yt_client,
    yql_client,
    date_for_metrics,
):
    """Estimate the predictions distribution and write it as stats to YT.

    Everything runs inside a ``NirvanaTransaction`` opened on ``yt_client``.

    Args:
        yt_client: YT client used for the transaction, the estimation and the
            stats write.
        yql_client: client passed to ``estimate`` to execute the YQL query.
        date_for_metrics: value passed as ``date`` to
            ``yt_helpers.write_stats_to_yt``.
    """
    with NirvanaTransaction(yt_client) as transaction:
        stats_table = config.DATALENS_REALTIME_SOCDEM_DISTRIBUTIONS_TABLE
        distribution_rows = estimate(yt_client, yql_client, transaction)
        stats_schema = {
            'socdem': 'string',
            'segment': 'string',
            'percentage': 'double',
        }

        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=stats_table,
            data_to_write=distribution_rows,
            schema=stats_schema,
            date=date_for_metrics,
        )
