import collections
from functools import partial
import os

from yt.wrapper import with_context

from crypta.lib.python import templater
from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools import (
    application_utils,
    training_utils,
)
from crypta.lib.python.yt import yt_helpers

# Ten decile bands, highest band ('90_100') first. Each key "<lower>_<upper>"
# names a band in percent; each value is the (start, end) pair
# (1 - upper / 100, 1 - lower / 100), i.e. (0.0, 0.1) for '90_100' down to
# (0.9, 1.0) for '0_10'. How the pair is applied to a table is decided by
# application_utils.get_slices_for_percentile_segmentation (presumably
# fractions of the sorted table -- confirm there).
default_percentiles = collections.OrderedDict(
    ('{}_{}'.format(90 - 10 * band, 100 - 10 * band), (band / 10.0, (band + 1) / 10.0))
    for band in range(10)
)

# YQL epilogue appended to application_utils.apply_catboost_common_query in
# apply_model_to_profiles. It is filled with str.format, hence the doubled
# braces around the lambda bodies. It reads `$processed` (presumably defined by
# that common prefix -- verify there), applies exp() to each row's raw
# `Result` list and divides by the sum (a softmax), then zips the probabilities
# with the class names given in {schema}. Rows written to {output_table}:
# id = PassThrough cast to Uint64, id_type = 'yandexuid', model_predictions = a
# serialized YSON dict of class name -> probability.
apply_catboost_classification_query = """
$schema = AsList({schema});
$exponentiation = ($raw_weights) -> {{RETURN ListMap($raw_weights, ($weight) -> {{ RETURN Math::Exp($weight); }})}};
$normalize = ($weights) -> {{RETURN ListMap($weights, ($weight) -> {{ RETURN $weight / ListSum($weights); }})}};

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    Cast(PassThrough AS Uint64) AS id,
    'yandexuid' AS id_type,
    Yson::Serialize(Yson::FromDoubleDict(ToDict(
        ListZip(
            $schema,
            $normalize($exponentiation(Result))
        )
    ))) AS model_predictions
FROM $processed;
"""

# Jinja-style template rendered by templater.render_template in
# get_modeled_segments. For every entry of `parameters` it writes
# `<output_directory>/modeled_<target_type>`, containing the `users_number`
# rows of `input_table` with the highest probability for target_type
# 'positive' or the lowest for 'negative'. `id` is renamed to `yandexuid`.
# Any other target_type gets no ORDER BY, so the LIMIT keeps arbitrary rows.
modeled_segments_query_template = """
{% for segment_parameters in parameters %}

INSERT INTO `{{output_directory}}/modeled_{{segment_parameters.target_type}}`
WITH TRUNCATE

SELECT
    id AS yandexuid,
    probability,
FROM `{{input_table}}`
{% if segment_parameters.target_type == 'positive' %}
ORDER BY probability DESC
{% elif segment_parameters.target_type == 'negative' %}
ORDER BY probability
{% endif %}
LIMIT {{segment_parameters.users_number}};

{% endfor %}
"""

# Jinja-style template rendered by templater.render_template in
# get_initial_segments. For every label in `target_types` it writes up to
# `users_max_number` yandexuids (cast to String) whose `segment_name` equals
# the label into `<output_directory>/initial_<label>`. There is no ORDER BY,
# so the query does not determine which rows survive the LIMIT.
initial_segments_query_template = """
{% for target_type in target_types %}

INSERT INTO `{{output_directory}}/initial_{{target_type}}`
WITH TRUNCATE

SELECT
    CAST(yandexuid AS String) AS yandexuid,
FROM `{{input_table}}`
WHERE segment_name == '{{target_type}}'
LIMIT {{users_max_number}};
{% endfor %}
"""

# str.format template used by get_percentiles_for_currently_computed_segments.
# It writes the distinct `segment_name` values of {input_table} into
# {output_table}.
get_currently_computed_segments_query = """
INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT DISTINCT segment_name
FROM `{input_table}`;
"""


def voting_by_cryptaid(key, rows):
    """YT reducer: give every yandexuid of a crypta_id the group's combined predictions.

    Used with ``run_reduce(..., reduce_by='crypta_id')``. The per-class
    ``model_predictions`` of all rows in the group are summed. The sum is passed
    through ``training_utils.normalize_probabilities``, and each input row is
    emitted with its own ``id``/``id_type`` and that shared result.

    Args:
        key: the reduce key (unused).
        rows: iterable of rows with ``id``, ``id_type`` and a
            ``model_predictions`` mapping of class name -> probability.
    """
    rows = list(rows)  # iterated twice: once to sum, once to emit
    if not rows:
        return

    summed = collections.Counter()
    for row in rows:
        # Counter.update adds values and keeps zero entries. ``+=`` would
        # silently drop a class whose summed probability is 0.0 (e.g. after
        # exp() underflow), and later lookups such as
        # model_predictions['positive'] would then fail with KeyError.
        summed.update(row['model_predictions'])

    # The combined predictions are identical for every row of the group, so
    # they are computed once rather than per emitted row.
    combined = training_utils.normalize_probabilities(summed)
    for row in rows:
        yield {
            'id': row['id'],
            'id_type': row['id_type'],
            'model_predictions': combined,
        }


def extract_probability_for_segmentation(row, positive_name):
    """YT mapper: reduce a scored row to ``{id, id_type, probability}``.

    ``probability`` is the entry of ``row['model_predictions']`` for the class
    named ``positive_name``. ``apply_model_to_profiles`` binds ``positive_name``
    via functools.partial before percentile segmentation.
    """
    row_id, id_type = row['id'], row['id_type']
    yield dict(
        id=row_id,
        id_type=id_type,
        probability=row['model_predictions'][positive_name],
    )


def extract_integral_score_for_segmentation(row, class_order):
    """YT mapper: collapse ordered class predictions into a single score.

    With ``n = len(class_order)``, the class at position ``i`` gets weight
    ``n - i``, so the first class weighs ``n`` and the last weighs ``1``. The
    emitted ``probability`` is ``(n - weighted_sum) / (n - 1)``. Assuming the
    predictions sum to 1, this is 0.0 when all mass is on the first class and
    1.0 when all mass is on the last class.

    Note: a single-class ``class_order`` makes the denominator zero
    (ZeroDivisionError).
    """
    n_classes = len(class_order)
    weighted_sum = sum(
        (n_classes - position) * row['model_predictions'][name]
        for position, name in enumerate(class_order)
    )
    score = (n_classes - weighted_sum) / (n_classes - 1)
    yield {
        'id': row['id'],
        'id_type': row['id_type'],
        'probability': score,
    }


def get_most_appropriate_segment_with_probability(row, thresholds, segments):
    """YT mapper: turn a row's ``model_predictions`` into segment rows.

    Only classes listed in ``segments`` are considered.

    * ``thresholds`` truthy (mapping class name -> threshold): emit at most one
      row. It is for the class with the largest ``prediction / threshold``
      ratio among classes whose ratio is > 1.0; the argmax is taken by
      ``training_utils.get_item_with_max_value``. Nothing is emitted when no
      class exceeds its threshold. ``thresholds`` must contain every
      considered class found in the row (KeyError otherwise); a zero threshold
      raises ZeroDivisionError.
    * ``thresholds`` falsy (None or empty): emit one row per considered class.

    Emitted rows carry ``id`` (as str), ``id_type``, ``segment_name`` and
    ``probability`` (that class's prediction).
    """
    predictions = row['model_predictions']

    if not thresholds:
        for segment, value in predictions.items():
            if segment in segments:
                yield {
                    'id': str(row['id']),
                    'id_type': row['id_type'],
                    'segment_name': segment,
                    'probability': value,
                }
        return

    # Plain dict: values are float ratios (a defaultdict(dict) factory was
    # never used).
    ratios = {}
    for segment, value in predictions.items():
        if segment not in segments:
            continue
        ratio = value / thresholds[segment]
        if ratio > 1.0:
            ratios[segment] = ratio

    if not ratios:
        return

    # The argmax depends only on the final set of ratios, so it is computed
    # once after the loop instead of on every iteration.
    best_segment, _ = training_utils.get_item_with_max_value(ratios)
    yield {
        'id': str(row['id']),
        'id_type': row['id_type'],
        'segment_name': best_segment,
        'probability': predictions[best_segment],
    }


@with_context
class GetSegmentWithPercentile(object):
    """YT mapper (context-aware) that labels rows of percentile slices.

    The input tables are percentile slices. ``context.table_index`` selects the
    segment name from ``slice_to_segment_name_dict``. Despite its name, that
    argument is used as an index-addressable sequence: apply_model_to_profiles
    passes ``list(percentiles.keys())``.
    """

    def __init__(self, slice_to_segment_name_dict):
        # Sequence mapping input table index -> segment name.
        self.slice_to_segment_name_dict = slice_to_segment_name_dict

    def __call__(self, row, context):
        segment_names = self.slice_to_segment_name_dict
        table_index = context.table_index

        probability = row['probability']
        # With exactly two slices, rows of the first slice report the
        # complementary probability (1 - p) instead of p.
        if len(segment_names) == 2 and table_index == 0:
            probability = 1. - probability

        yield {
            'id': str(row['id']),
            'id_type': row['id_type'],
            'probability': round(probability, 6),
            'segment_name': segment_names[table_index],
        }


def apply_model_to_profiles(
    yt_client,
    yql_client,
    output_path,
    yt_model_path,
    ordered_thresholds=None,
    percentiles=default_percentiles,
):
    """Score profiles with a two-class CatBoost model and write segments to ``output_path``.

    Steps:
      1. YQL: apply the model referenced by ``yt_model_path`` to
         ``training_config.CATBOOST_FEATURES``. Raw outputs are softmax-normalized
         into ``{'negative': p, 'positive': q}`` with ``id_type = 'yandexuid'``.
      2. YQL: split the scores, using ``training_config.YANDEXUID_CRYPTAID_TABLE``,
         into rows with and without a crypta_id.
      3. YT reduce by ``crypta_id`` with ``voting_by_cryptaid`` to combine the
         predictions of rows sharing a crypta_id.
      4. (Re)create ``output_path`` with columns id, id_type, segment_name and
         probability (``force=True``).
      5. YT map into ``output_path``:
         * ``percentiles`` is not None: take the positive probability, sort by
           (probability, id), and label the slices returned by
           ``application_utils.get_slices_for_percentile_segmentation``;
         * ``percentiles`` is None: assign segments per row with
           ``get_most_appropriate_segment_with_probability`` and
           ``ordered_thresholds``.

    Intermediate tables are created with ``yt_client.TempTable()`` and scoped to
    this call.

    Returns:
        ``output_path``.
    """
    with yt_client.TempTable() as scored, \
            yt_client.TempTable() as scored_with_cryptaid, \
            yt_client.TempTable() as scored_without_cryptaid, \
            yt_client.TempTable() as scored_voted, \
            yt_client.TempTable() as positive_probabilities:
        # Shared CatBoost application prologue + classification epilogue.
        scoring_query = application_utils.apply_catboost_common_query.format(
            number_of_classes=2,
            catboost_model=training_utils.get_model_tag(yt_path=yt_model_path),
            input_table=training_config.CATBOOST_FEATURES,
        ) + apply_catboost_classification_query.format(
            schema="'negative', 'positive'",
            output_table=scored,
        )
        yql_client.execute(query=scoring_query, title='YQL Apply test model to profiles')

        split_query = application_utils.split_by_cryptaid_presence_query.format(
            catboost_applied_table=scored,
            yandexuid_cryptaid_table=training_config.YANDEXUID_CRYPTAID_TABLE,
            output_with_cryptaid_table=scored_with_cryptaid,
            output_without_cryptaid_table=scored_without_cryptaid,
        )
        yql_client.execute(query=split_query, title='YQL Split by crypta_id presence query')

        yt_client.run_reduce(
            voting_by_cryptaid,
            scored_with_cryptaid,
            scored_voted,
            reduce_by='crypta_id',
        )

        output_schema = {
            'id': 'string',
            'id_type': 'string',
            'segment_name': 'string',
            'probability': 'double',
        }
        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=output_path,
            schema=output_schema,
            additional_attributes={'optimize_for': 'scan'},
            force=True,
        )

        # Rows without a crypta_id and the crypta_id-voted rows are segmented
        # together from here on.
        scored_tables = [scored_without_cryptaid, scored_voted]

        if percentiles is None:
            threshold_mapper = partial(
                get_most_appropriate_segment_with_probability,
                thresholds=ordered_thresholds,
                segments=['positive', 'negative'],
            )
            yt_client.run_map(threshold_mapper, scored_tables, output_path)
        else:
            probability_mapper = partial(extract_probability_for_segmentation, positive_name='positive')
            yt_client.run_map(probability_mapper, scored_tables, positive_probabilities)
            yt_client.run_sort(positive_probabilities, sort_by=['probability', 'id'])

            percentile_mapper = GetSegmentWithPercentile(slice_to_segment_name_dict=list(percentiles.keys()))
            percentile_slices = application_utils.get_slices_for_percentile_segmentation(
                yt=yt_client,
                table=positive_probabilities,
                percentiles=percentiles,
            )
            yt_client.run_map(percentile_mapper, percentile_slices, output_path)

    return output_path


def get_modeled_segments(
    yql_client,
    predictions_table_path,
    positive_output_segment_size,
    negative_output_segment_size,
):
    """Write the best- and worst-scored users of a predictions table as segments.

    Renders ``modeled_segments_query_template`` and runs it via ``yql_client``.
    With ``<dir>`` the directory of ``predictions_table_path``:

    * ``<dir>/segments/modeled_positive`` receives the
      ``positive_output_segment_size`` rows with the highest probability;
    * ``<dir>/segments/modeled_negative`` receives the
      ``negative_output_segment_size`` rows with the lowest probability.
    """
    segment_sizes = (
        ('positive', positive_output_segment_size),
        ('negative', negative_output_segment_size),
    )
    segments_directory = os.path.join(os.path.dirname(predictions_table_path), 'segments')

    query = templater.render_template(
        template_text=modeled_segments_query_template,
        vars={
            'input_table': predictions_table_path,
            'parameters': [
                {'target_type': target_type, 'users_number': size}
                for target_type, size in segment_sizes
            ],
            'output_directory': segments_directory,
        },
    )
    yql_client.execute(query=query, title='YQL get modeled positive and negative segments')


def get_initial_segments(yql_client, train_sample_by_yuid_table):
    """Copy labelled training users into one segment table per label.

    Renders ``initial_segments_query_template`` and runs it via ``yql_client``.
    For every label in ``training_config.labels``, up to
    ``training_config.MAX_SEGMENT_SIZE`` yandexuids whose ``segment_name``
    equals the label are written to ``<dir>/segments/initial_<label>``, where
    ``<dir>`` is the directory of ``train_sample_by_yuid_table``.
    """
    segments_directory = os.path.join(os.path.dirname(train_sample_by_yuid_table), 'segments')
    template_vars = dict(
        input_table=train_sample_by_yuid_table,
        target_types=training_config.labels,
        output_directory=segments_directory,
        users_max_number=training_config.MAX_SEGMENT_SIZE,
    )
    yql_client.execute(
        query=templater.render_template(
            template_text=initial_segments_query_template,
            vars=template_vars,
        ),
        title='YQL get initial positive and negative segments',
    )


def get_percentiles_for_currently_computed_segments(yt_client, yql_client, segments_table):
    """Rebuild a percentile map from the segment names present in ``segments_table``.

    Each distinct ``segment_name`` is expected to end with ``<lower>_<upper>``
    (percent bounds, e.g. ``..._90_100``). The result maps ``"<lower>_<upper>"``
    to ``(1 - upper / 100, 1 - lower / 100)``, the same shape as
    ``default_percentiles``. Bands are ordered numerically, highest band first.

    Returns:
        collections.OrderedDict of band name -> (start, end).

    Raises:
        ValueError: if a segment name does not end with two numeric
            ``_``-separated parts.
    """
    with yt_client.TempTable() as segment_names_table:
        yql_client.execute(
            query=get_currently_computed_segments_query.format(
                input_table=segments_table,
                output_table=segment_names_table,
            ),
            title='YQL get currently computed segments for model',
        )

        # Keep only the trailing "<lower>_<upper>" parts of each segment name.
        segments = [row['segment_name'].split('_')[-2:] for row in yt_client.read_table(segment_names_table)]
        # Sort numerically, highest band first (as in default_percentiles).
        # A plain string sort misorders bounds with different digit counts,
        # e.g. it would place "5_10" above "10_20".
        segments.sort(key=lambda bounds: [float(bound) for bound in bounds], reverse=True)

        percentiles = collections.OrderedDict()
        for lower_limit, upper_limit in segments:
            percentiles['{}_{}'.format(lower_limit, upper_limit)] = (
                1. - float(upper_limit) / 100,
                1. - float(lower_limit) / 100,
            )

        return percentiles
