from os.path import join

import numpy as np
import pandas as pd

from crypta.lib.python import classification_thresholds
from crypta.lib.python import templater
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.utils.config import config

# Region ID substituted for the country filter when compute_socdem_distributions
# receives a falsy country_region_id (its docstring: None means all countries).
# Presumably the root of the Geo region tree, so Geo::IsRegionInRegion matches
# every country — TODO confirm against the Geo region hierarchy.
EARTH_REGION_ID = 0

# YQL template that builds, for every socdem segment type, a histogram of
# rounded segment probabilities.
#
# Template variables:
#   profiles_table        -- profiles with yandexuid, update_time and one Yson
#                            dict column (segment -> probability) per socdem
#                            segment type.
#   yuid_with_all_table   -- joined on yandexuid; its main_region_country is
#                            compared with russia_region_id.
#   socdem_segments_types -- mapping: segment type -> list of segment names.
#   round_precision       -- second argument of Math::Round applied to every
#                            probability.
#   russia_region_id      -- country filter value.
#   output_table_dir      -- one table per segment type is written here
#                            (WITH TRUNCATE, i.e. overwritten).
#
# Only profiles with the latest update_time are used, and only rows whose
# rounded probabilities add up to exactly 1. Each output table has one row
# per distinct combination of rounded probabilities plus its count (cnt).
histogram_query_template = """
$max_update_time = (
    SELECT MAX(update_time)
    FROM `{{profiles_table}}`
);

{% for socdem_segment_type in socdem_segments_types %}

$prepared_{{socdem_segment_type}} = (
    SELECT
        profiles.yandexuid AS yandexuid,
        {% for socdem_segment in socdem_segments_types[socdem_segment_type] %}
        Math::Round(Yson::ConvertToDoubleDict(profiles.{{socdem_segment_type}})['{{socdem_segment}}'], {{round_precision}}) AS `{{socdem_segment}}`,
        {% endfor %}
    FROM `{{profiles_table}}` AS profiles
    INNER JOIN `{{yuid_with_all_table}}` AS yuid_with_all
    USING(yandexuid)
    WHERE yuid_with_all.main_region_country == {{russia_region_id}} AND update_time == $max_update_time
);

INSERT INTO `{{output_table_dir}}/{{socdem_segment_type}}`
WITH TRUNCATE

SELECT
    `{{socdem_segments_types[socdem_segment_type]|join('`, `')}}`,
    COUNT(*) AS cnt
FROM $prepared_{{socdem_segment_type}}
WHERE `{{socdem_segments_types[socdem_segment_type]|join('` + `')}}` == 1
GROUP BY `{{socdem_segments_types[socdem_segment_type]|join('`, `')}}`
ORDER BY `{{socdem_segments_types[socdem_segment_type]|join('`, `')}}`;

{% endfor %}
"""


# YQL template that writes, per socdem type, the share of profiles having each
# value of that type.
#
# Template variables:
#   profiles_table             -- profiles with yandexuid, exact_socdem (Yson
#                                 string dict with 'gender', 'age_segment',
#                                 'income_5_segment') and update_time.
#   yuid_with_all_by_yandexuid -- joined on yandexuid to get main_region_country.
#   country_region_id          -- a profile is kept when
#                                 Geo::IsRegionInRegion(main_region_country,
#                                 country_region_id) holds.
#   socdem_types               -- column names of $profiles_socdem to aggregate;
#                                 the query defines gender, age and income.
#   distributions_table        -- output table, rewritten WITH TRUNCATE.
#   use_last_active_profiles   -- if true, only profiles with the latest
#                                 update_time are used.
#
# Output: one row per socdem type with columns socdem_type and distribution
# (Yson-serialized dict: value -> share of all selected profiles).
# NOTE(review): profiles missing a key in exact_socdem form a NULL group; how
# CAST(... AS Dict<String, Double>) treats that entry is not visible here —
# verify.
socdem_distributions_query_template = """
{% if use_last_active_profiles %}
$max_update_time = SELECT MAX(update_time) FROM `{{profiles_table}}`;
{% endif %}

$profiles_socdem = (
    SELECT
        profiles.yandexuid AS yandexuid,
        Yson::ConvertToStringDict(profiles.exact_socdem)['gender'] AS gender,
        Yson::ConvertToStringDict(profiles.exact_socdem)['age_segment'] AS age,
        Yson::ConvertToStringDict(profiles.exact_socdem)['income_5_segment'] AS income
    FROM `{{profiles_table}}` AS profiles
    INNER JOIN `{{yuid_with_all_by_yandexuid}}` AS yuid_with_all
    USING(yandexuid)
    WHERE Geo::IsRegionInRegion(CAST(yuid_with_all.main_region_country AS Int32), {{country_region_id}})
    {% if use_last_active_profiles %}
        AND update_time >= $max_update_time
    {% endif %}
);

$total_profiles_cnt = SELECT CAST(COUNT(*) AS Double) FROM $profiles_socdem;

{% for socdem_type in socdem_types %}
${{socdem_type}}_stats = (
    SELECT
        {{socdem_type}},
        COUNT(*) / $total_profiles_cnt AS p_cnt
    FROM $profiles_socdem
    GROUP BY {{socdem_type}}
);
{% endfor %}

INSERT INTO `{{distributions_table}}`
WITH TRUNCATE

{% for socdem_type in socdem_types %}
{% if loop.index != 1 %}
    UNION ALL
{% endif %}
SELECT
    '{{socdem_type}}' AS socdem_type,
    Yson::Serialize(Yson::FromDoubleDict(
        CAST(ToDict(AggregateList(AsTuple({{socdem_type}}, p_cnt))) AS Dict<String, Double>)
    )) AS distribution
FROM ${{socdem_type}}_stats
{% endfor %}
ORDER BY socdem_type;
"""


def calculate_thresholds(
    yt_client,
    yql_client,
    profiles_table,
    histograms_directory,
    needed_recalls,
    needed_total_recalls=None,
):
    """
    Function to calculate thresholds with given socdem predictions.

    A YQL query builds one histogram table per requested socdem type in
    ``histograms_directory``. Each histogram is then passed to
    ``classification_thresholds.find_thresholds``.

    Parameters
    ----------
    yt_client
        crypta/lib/python/yt/client
    yql_client
        crypta/lib/python/yql/client
    profiles_table: str
        Table with yandexuid, update_time and one column per requested socdem
        type (e.g. 'gender', 'user_age_6s', 'income_5_segments') holding a dict
        of segment probabilities. Only rows with the latest update_time whose
        main region country is Russia are used.
    histograms_directory: str
        Directory to save tables with histograms for selected socdem types.
        Existing tables with the same names are overwritten.
    needed_recalls: dict
        Possible keys are ('gender', 'user_age_6s', 'income_5_segments').
        Each value is a list of needed recalls for the classes. It must sum
        to 1, up to floating-point tolerance.
    needed_total_recalls: dict, optional
        Possible keys are ('gender', 'user_age_6s', 'income_5_segments').
        Each value is the overall recall in [0, 1] (defines the level of
        uncertainty). Socdem types missing here use 1.

    Returns
    -------
    dict
        Keys are the socdem types from ``needed_recalls`` that socdem_config
        knows about. Values are the thresholds returned by find_thresholds.

    Raises
    ------
    AssertionError
        If ``needed_recalls`` is empty, a recalls list does not sum to 1, or a
        total recall is outside [0, 1].
    """
    # Normalize first: the default is None, and the membership test below
    # would raise TypeError on it.
    needed_total_recalls = needed_total_recalls or {}

    assert len(needed_recalls) > 0, 'Needed recalls must be specified.'
    for socdem_segment_type, recalls_values in needed_recalls.items():
        # Tolerant comparison: float shares that sum to 1 mathematically may
        # not sum to exactly 1.0 in floating point.
        assert np.isclose(np.sum(recalls_values), 1.), 'Sum of the needed recalls must be 1.'
        if socdem_segment_type in needed_total_recalls:
            assert 0 <= needed_total_recalls[socdem_segment_type] <= 1, 'Total recall must be in range [0, 1].'

    # Segment names for the requested socdem types only; types unknown to
    # socdem_config are silently skipped.
    segment_names_by_label_type = {
        socdem_segment_type: segments_names
        for socdem_segment_type, segments_names in socdem_config.yet_another_segment_names_by_label_type.items()
        if socdem_segment_type in needed_recalls
    }

    with yt_client.Transaction() as transaction:
        histogram_query = templater.render_template(
            template_text=histogram_query_template,
            vars={
                'socdem_segments_types': segment_names_by_label_type,
                # Math::Round precision for probabilities; presumably rounds to
                # 0.01 — confirm against the YQL Math UDF docs.
                'round_precision': -2,
                'russia_region_id': config.RUSSIA_REGION_ID,
                'profiles_table': profiles_table,
                'yuid_with_all_table': config.YUID_WITH_ALL_BY_YANDEXUID_TABLE,
                'output_table_dir': histograms_directory,
            },
        )
        yql_client.execute(
            query=histogram_query,
            title='YQL get histograms to calculate classification thresholds',
            transaction=str(transaction.transaction_id),
        )

    # Histograms are read after the transaction has been exited.
    thresholds = {}
    for socdem_segment_type in segment_names_by_label_type:
        thresholds[socdem_segment_type] = classification_thresholds.find_thresholds(
            table=pd.DataFrame(yt_client.read_table(join(histograms_directory, socdem_segment_type))),
            segments=segment_names_by_label_type[socdem_segment_type],
            needed_recalls=np.array(needed_recalls[socdem_segment_type]),
            needed_recall=needed_total_recalls.get(socdem_segment_type, 1.),
        )

    return thresholds


def compute_socdem_distributions(
    yql_client,
    profiles_table,
    distributions_table,
    country_region_id=config.RUSSIA_REGION_ID,
    use_last_active_profiles=True,
):
    """
    Compute the share of profiles for every value of every socdem type.

    One row per type from ``socdem_config.SOCDEM_TYPES`` is written to
    ``distributions_table``. Each row holds a Yson dict mapping a socdem value
    to its share among the selected profiles.

    Parameters
    ----------
    yql_client
        crypta/lib/python/yql/client
    profiles_table: str
        Table with [yandexuid, exact_socdem] columns used to calculate the
        distributions (plus update_time when use_last_active_profiles is set).
    distributions_table: str
        Table where the results are saved; it is overwritten.
    country_region_id: int, optional
        Country ID used to filter profiles. The default value is Russia;
        None (or any falsy value) means all countries.
    use_last_active_profiles: bool, optional
        If true, only profiles with the latest update_time are considered.

    Returns
    -------
    None
        distributions_table will be created.
    """
    # A falsy country id means "no country filter": fall back to the Earth region.
    region_id = country_region_id if country_region_id else EARTH_REGION_ID

    template_vars = dict(
        profiles_table=profiles_table,
        yuid_with_all_by_yandexuid=config.YUID_WITH_ALL_BY_YANDEXUID_TABLE,
        country_region_id=region_id,
        socdem_types=socdem_config.SOCDEM_TYPES,
        distributions_table=distributions_table,
        use_last_active_profiles=use_last_active_profiles,
    )
    query_text = templater.render_template(
        template_text=socdem_distributions_query_template,
        vars=template_vars,
    )

    yql_client.execute(query=query_text, title='YQL Compute socdem distributions')
