import json
import numpy as np
import os
import pandas as pd

from yt import wrapper as yt_wrapper

from crypta.lib.python import (
    classification_thresholds,
    templater,
)
from crypta.lib.python.bigb_catboost_applier import (
    features_mapping as features_mapping_utils,
    train_sample as train_sample_utils,
)
from crypta.lib.python.yt import yt_helpers
from crypta.lib.python.yt.yt_helpers import tempdir
from crypta.rt_socdem.lib.python.model import (
    fields,
    utils,
)
from crypta.rt_socdem.lib.python.model.config import config


# YQL template (rendered with templater.render_template in get_pool) that
# prepares, for every socdem type, the rows later fed to the sample mapper.
#
# Template variables: bigb_udf_url, date, sampling_period_days,
# beh_hit_log_dir, beh_hit_log_table_suffix, chevent_log_dir, socdem_types,
# socdem_keywords, bigb_sampling_rate, russia_region_id,
# max_thresholds_pool_size, prepared_for_mapper_dir.
#
# $date_begin is $date_end minus sampling_period_days.
#
# For each socdem_type the query:
#   * defines one $filter_<mode>_<type> lambda per key of socdem_keywords; the
#     WHERE clause uses the `offline` and `realtime` ones, so socdem_keywords
#     must contain both modes;
#   * reads chevent tables via RANGE($date_begin, $date_end) and joins them by
#     hitlogid with sampled beh hit log rows taken from tables whose name ends
#     with beh_hit_log_table_suffix and whose date part (the text before 'T')
#     satisfies $date_begin < date <= $date_end;
#   * keeps rows with a non-NULL profile dump that has no item with the
#     offline keyword and at least one item with the realtime keyword, with
#     fraudbits == 0, placeid in (542, 1542) and a country equal to
#     russia_region_id, capped by LIMIT max_thresholds_pool_size;
#   * overwrites `<prepared_for_mapper_dir>/<socdem_type>` with yandexuid,
#     crypta_id, profile plus a constant target (0) and weight (1.0).
pool_query = '''
PRAGMA AnsiInForEmptyOrNullableItemsCollections;

PRAGMA File('bigb.so', '{{ bigb_udf_url }}');
PRAGMA udf('bigb.so');

$date_end = '{{ date }}';
$date_begin = CAST(CAST($date_end AS Date) - DateTime::IntervalFromDays({{ sampling_period_days }}) AS String);

$filter_pool_tables = ($table_name) -> {
    $date = String::SplitToList($table_name, 'T')[0];
    RETURN $table_name LIKE '%{{ beh_hit_log_table_suffix }}'
        AND $date > $date_begin
        AND $date <= $date_end;
};

$country_id = ($region_id) -> {RETURN Geo::RoundRegionById(CAST($region_id AS Int32), "country").id};

{% for socdem_type in socdem_types %}

{% for socdem_mode in socdem_keywords %}
$filter_{{ socdem_mode }}_{{ socdem_type }} = ($item) -> {
    RETURN $item.keyword_id = {{ socdem_keywords[socdem_mode][socdem_type] }};
};
{% endfor %}

$profiles_{{ socdem_type }} = (
    SELECT
        beh.profile AS profile,
        chevent.cryptaidv2 AS crypta_id,
        chevent.uniqid AS yandexuid,
    FROM RANGE(
        '{{ chevent_log_dir }}',
        $date_begin,
        $date_end,
    ) AS chevent
    LEFT JOIN ANY (
        SELECT
            ProfileDump AS profile,
            Bigb::ParseProfile(ProfileDump) AS parsed_profile,
            HitLogID AS hitlogid,
        FROM FILTER(
            '{{ beh_hit_log_dir }}',
            $filter_pool_tables,
        ) SAMPLE {{ bigb_sampling_rate }}
    ) AS beh
    USING(hitlogid)
    WHERE beh.profile IS NOT NULL
        AND ListLength(ListFilter(parsed_profile.items, $filter_offline_{{ socdem_type }})) == 0
        AND ListLength(ListFilter(parsed_profile.items, $filter_realtime_{{ socdem_type }})) > 0
        AND chevent.fraudbits == 0
        AND chevent.placeid in (542, 1542)
        AND $country_id(chevent.regionid) == {{ russia_region_id }}
    LIMIT {{ max_thresholds_pool_size }}
);


INSERT INTO `{{ prepared_for_mapper_dir }}/{{ socdem_type }}`
WITH TRUNCATE

SELECT
    yandexuid,
    crypta_id,
    profile,
    0 AS {{ socdem_type }}, -- dummy values to run catboost
    1.0 AS {{ socdem_type }}_weight,
FROM $profiles_{{ socdem_type }};

{% endfor %}
'''

# YQL template (rendered with templater.render_template in find) that turns
# model predictions into per-socdem-type histograms of rounded probabilities.
#
# Template variables: socdem_types, segment_names_by_socdem_type,
# round_precision, predictions_on_pools_dir, histograms_dir.
#
# For each socdem_type the query:
#   * reads `<predictions_on_pools_dir>/<socdem_type>` and, for the i-th
#     (zero-based) segment name of that type, selects column
#     `Probability:Class=i` passed through Math::Round(..., -round_precision)
#     under the segment's name (presumably this rounds to round_precision
#     decimal digits -- confirm against the YQL Math::Round docs);
#   * groups by all segment columns, counts rows as `cnt`, and overwrites
#     `<histograms_dir>/<socdem_type>` ordered by the segment columns.
histogram_query = '''
{% for socdem_type in socdem_types %}
$prepared_{{ socdem_type }} = (
    SELECT
        {% for segment_name in segment_names_by_socdem_type[socdem_type] %}
        Math::Round(`Probability:Class={{ loop.index - 1 }}`, -{{ round_precision }}) AS `{{ segment_name }}`,
        {% endfor %}
    FROM `{{ predictions_on_pools_dir }}/{{ socdem_type }}`
);

{% set segment_names_list = '`' + (segment_names_by_socdem_type[socdem_type] | join('`, `')) + '`' %}

INSERT INTO `{{ histograms_dir }}/{{ socdem_type }}`
WITH TRUNCATE

SELECT
    {{ segment_names_list }},
    COUNT(*) AS cnt,
FROM $prepared_{{ socdem_type }}
GROUP BY {{ segment_names_list }}
ORDER BY {{ segment_names_list }};
{% endfor %}
'''


def get_pool(
    yt_client,
    yql_client,
    yesterday,
    pools_dir=config.THRESHOLDS_POOLS_DIR,
    features_mapping_table_path=config.FEATURES_MAPPING_TABLE,
    socdem_types=config.SOCDEM_TYPES,
):
    """Build the per-socdem-type pools used for threshold selection.

    The function renders and runs ``pool_query`` inside the transaction of a
    temporary YT directory, then starts one asynchronous map operation per
    socdem type that converts the prepared rows into ``key``/``value`` pool
    tables. All map operations are registered in a single
    ``OperationsTracker`` context.

    Args:
        yt_client: YT client used for the temporary directory, table creation
            and map operations.
        yql_client: client whose ``execute`` runs the rendered YQL query.
        yesterday: value substituted as ``date`` (the end of the sampling
            window) in the query template.
        pools_dir: directory in which ``<pools_dir>/<socdem_type>`` output
            tables are (re)created.
        features_mapping_table_path: table passed to
            ``features_mapping_utils.get_features_mapping``.
        socdem_types: iterable of socdem type names to build pools for.
    """
    with tempdir.YtTempDir(yt_client, config.COMMON_TMP_DIRECTORY) as staging_dir:
        # BEH_HIT_LOG_TABLE is a path template; the part before '/{}' is the
        # directory and the part after it is the common table-name suffix.
        hit_log_dir, hit_log_suffix = config.BEH_HIT_LOG_TABLE.split('/{}')

        template_vars = {
            'prepared_for_mapper_dir': staging_dir.path,
            'socdem_types': socdem_types,
            'chevent_log_dir': config.CHEVENT_LOG_DIR,
            'beh_hit_log_dir': hit_log_dir,
            'beh_hit_log_table_suffix': hit_log_suffix,
            'socdem_keywords': config.SOCDEM_CLASS_KEYWORDS,
            'russia_region_id': config.RUSSIA_REGION_ID,
            'bigb_udf_url': config.BIGB_UDF_URL,
            'bigb_sampling_rate': config.BIGB_SAMPLING_RATE,
            'max_thresholds_pool_size': config.MAX_THRESHOLDS_POOL_SIZE,
            'sampling_period_days': config.BIGB_SAMPLING_PERIOD_DAYS,
            'date': yesterday,
        }
        rendered_query = templater.render_template(
            template_text=pool_query,
            vars=template_vars,
        )
        yql_client.execute(
            query=rendered_query,
            transaction=str(staging_dir.transaction.transaction_id),
            title='YQL Realtime socdem: Get pool for thresholds',
        )

        # Only the mapping itself is needed here; the second element of the
        # returned pair is not used by this function.
        features_mapping, _float_features_description = features_mapping_utils.get_features_mapping(
            yt_client=yt_client,
            features_mapping_table_path=features_mapping_table_path,
        )

        with yt_wrapper.OperationsTracker() as operations:
            for socdem_type in socdem_types:
                pool_table = os.path.join(pools_dir, socdem_type)
                source_table = os.path.join(staging_dir.path, socdem_type)

                yt_helpers.create_empty_table(
                    yt_client=yt_client,
                    path=pool_table,
                    schema=[
                        {'name': 'key', 'type': 'string'},
                        {'name': 'value', 'type': 'string'},
                    ],
                )

                mapper = train_sample_utils.PrepareSamplesMapper(
                    features_mapping=features_mapping,
                    counters_to_features=utils.COUNTERS_TO_FEATURES,
                    keywords_to_features=utils.KEYWORDS_TO_FEATURES,
                    key_column=fields.YANDEXUID,
                    target_column=socdem_type,
                    weight_column='{}_weight'.format(socdem_type),
                    split=False,
                )

                # sync=False: the operation is started without blocking and
                # handed over to the tracker.
                operations.add(
                    yt_client.run_map(
                        mapper,
                        source_table,
                        pool_table,
                        sync=False,
                    )
                )


def find(
    yt_client,
    yql_client,
    date_for_metrics=None,
    predictions_on_pools_dir=config.THRESHOLDS_PREDICTIONS_ON_POOLS_DIR,
    socdem_types=config.SOCDEM_TYPES,
    round_precision=config.THRESHOLDS_ROUND_PRECISION,
    needed_recalls=config.THRESHOLDS_NEEDED_RECALLS,
    json_local_file_output=None,
    json_yt_output=config.THRESHOLDS_FILE,
):
    """Compute classification thresholds for every requested socdem type.

    Runs ``histogram_query`` inside the transaction of a temporary YT
    directory, reads each resulting histogram into a DataFrame and passes it
    to ``classification_thresholds.find_thresholds``.

    Args:
        yt_client: YT client used for the temporary directory, table reads
            and optional outputs.
        yql_client: client whose ``execute`` runs the rendered YQL query.
        date_for_metrics: when not ``None``, one row per (socdem type,
            segment) is written to the thresholds stats table for this date.
        predictions_on_pools_dir: directory holding
            ``<dir>/<socdem_type>`` prediction tables.
        socdem_types: iterable of socdem type names to process.
        round_precision: value substituted as ``round_precision`` in the
            query template.
        needed_recalls: mapping from socdem type to the recalls passed
            (as a numpy array) to ``find_thresholds``.
        json_local_file_output: when not ``None``, the thresholds are dumped
            as JSON to this local path and that file is uploaded to
            ``json_yt_output``.
        json_yt_output: YT file path for the JSON upload; only used when
            ``json_local_file_output`` is given.

    Returns:
        dict keyed by ``config.THRESHOLDS_NAME_BY_SOCDEM_TYPE[socdem_type]``
        whose values are the results of ``find_thresholds``.
    """
    with tempdir.YtTempDir(yt_client, config.COMMON_TMP_DIRECTORY) as histograms_dir:
        yql_client.execute(
            query=templater.render_template(
                template_text=histogram_query,
                vars={
                    'round_precision': round_precision,
                    'socdem_types': socdem_types,
                    'segment_names_by_socdem_type': config.SEGMENT_NAMES_BY_SOCDEM_TYPE,
                    'predictions_on_pools_dir': predictions_on_pools_dir,
                    'histograms_dir': histograms_dir.path,
                },
            ),
            transaction=str(histograms_dir.transaction.transaction_id),
            title='YQL Realtime socdem: Get histograms for thresholds',
        )

        thresholds = {}
        for socdem_type in socdem_types:
            thresholds[config.THRESHOLDS_NAME_BY_SOCDEM_TYPE[socdem_type]] = classification_thresholds.find_thresholds(
                table=pd.DataFrame(yt_client.read_table(os.path.join(histograms_dir.path, socdem_type))),
                segments=config.SEGMENT_NAMES_BY_SOCDEM_TYPE[socdem_type],
                needed_recalls=np.array(needed_recalls[socdem_type]),
                constant_for_full_coverage=config.THRESHOLDS_CONSTANT_FOR_FULL_COVERAGE,
            )

        if date_for_metrics is not None:
            # Iterate over the requested socdem_types (not over every key of
            # config.SEGMENT_NAMES_BY_SOCDEM_TYPE): `thresholds` only has
            # entries for the types processed above, so walking the whole
            # config would raise KeyError when a subset was requested.
            stats_rows = [
                {
                    'socdem': socdem_type,
                    'segment': segment,
                    'threshold': threshold,
                }
                for socdem_type in socdem_types
                for segment, threshold in zip(
                    config.SEGMENT_NAMES_BY_SOCDEM_TYPE[socdem_type],
                    thresholds[config.THRESHOLDS_NAME_BY_SOCDEM_TYPE[socdem_type]],
                )
            ]
            yt_helpers.write_stats_to_yt(
                yt_client=yt_client,
                table_path=config.DATALENS_REALTIME_SOCDEM_CLASSIFICATION_THRESHOLDS_TABLE,
                data_to_write=stats_rows,
                schema={
                    'socdem': 'string',
                    'segment': 'string',
                    'threshold': 'double',
                },
                date=date_for_metrics,
            )

        if json_local_file_output is not None:
            with open(json_local_file_output, 'w') as f:
                f.write(json.dumps(thresholds))

            # The YT upload goes through the local file, so it only happens
            # when a local output path was provided.
            with open(json_local_file_output, 'rb') as f:
                yt_client.write_file(json_yt_output, f)

        return thresholds
