#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import defaultdict

import luigi
from yt.wrapper import create_table_switch, with_context

from crypta.lib.python.yt import (
    schema_utils,
    yt_helpers,
)
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.utils.config import config
from crypta.profile.utils import loggers
from crypta.profile.utils.luigi_utils import (
    BaseYtTask,
    ExternalInput,
    YtDailyRewritableTarget,
)


def calculate_trust_score(category_scores):
    """
    Pick the best-scored category and describe how dominant it is.

    :param category_scores: mapping ``category -> numeric score``; may be
        ``None`` or empty.
    :return: dict with three keys:
        - ``label``: category with the highest score (``None`` if there is
          nothing to choose from);
        - ``absolute_score``: the highest score rounded to 2 digits;
        - ``relative_score``: share of the highest score in the sum of all
          scores, rounded to 2 digits.
    """
    no_label = {'label': None, 'absolute_score': 0.0, 'relative_score': 0.0}

    if not category_scores:
        return no_label

    sum_category_score = sum(category_scores.values())
    if sum_category_score == 0:
        # All scores are zero (or cancel out): the share is undefined and
        # dividing would raise ZeroDivisionError, so treat it as "no label".
        return no_label

    label, label_absolute_score = max(category_scores.items(), key=lambda x: x[1])
    label_relative_score = round(float(label_absolute_score) / sum_category_score, 2)

    return {
        'label': label,
        'absolute_score': round(label_absolute_score, 2),
        'relative_score': label_relative_score,
    }


@with_context
class FilterSocdemStorage(object):
    """
    Reducer that splits socdem storage records into 3 output tables:
    - socdem labels for correction (socdem_labels_output)
    - labels used directly for model training (socdem_labels_for_learning_output)
    - labels to be sampled for training (socdem_labels_for_sampling_output)

    Input table 0 supplies ``ip_activity_type`` and ``main_region_country`` for
    the key; input table 1 supplies the ``<socdem_type>_scores`` columns.
    Storage records are skipped unless a non-None ``ip_activity_type`` was seen
    for the key beforehand.

    A label passes a threshold level ('strong' or 'weak') when its absolute
    score reaches the per-type absolute threshold of that level and its
    relative score reaches the relative threshold.

    Every record having at least one weak label is written to the correction
    table.

    A training record is produced only for users with ``ip_activity_type``
    equal to 'active' and a known ``main_region_country``. It consists of the
    strong labels plus the weak labels of the socdem types for which the
    user's country is listed in ``underrepresented_countries``. A non-empty
    training record goes:
    - to the learning table if no such weak label was added, or if a weak
      income label was added;
    - to the sampling table otherwise.
    """

    # Indices of the reducer input tables.
    yandexuid_with_all_input = 0
    socdem_storage_input = 1

    # Indices of the reducer output tables.
    socdem_labels_output = 0
    socdem_labels_for_learning_output = 1
    socdem_labels_for_sampling_output = 2

    # Country ids for which weak labels of the given socdem type are accepted
    # into the training record.
    underrepresented_countries = {
        'gender': {983},
        'age': {983},
        'income': {149, 159, 187, 983},
    }

    # Value written to the 'region' column of the training record.
    country_id_to_name = {
        225: 'ru',
        187: 'ua',
        983: 'tr',
        149: 'by',
        159: 'kz',
    }

    def __init__(
        self,
        absolute_thresholds=config.SOCDEM_LABELING_ABSOLUTE_THRESHOLDS,
        relative_threshold=config.RELATIVE_THRESHOLD,
    ):
        """
        :param absolute_thresholds: mapping indexed as
            ``[threshold_level][socdem_type]`` with threshold levels
            'strong' and 'weak'.
        :param relative_threshold: minimal relative score of a label, shared
            by both threshold levels.
        """
        self.absolute_thresholds = absolute_thresholds
        self.relative_threshold = relative_threshold

    def __call__(self, key, rows, context):
        """Yield table switches and output rows for one reduce key."""
        country = None
        activity = None

        for row in rows:
            if context.table_index == self.yandexuid_with_all_input:
                activity = row['ip_activity_type']
                country = row['main_region_country']
                continue

            if context.table_index != self.socdem_storage_input or activity is None:
                continue

            # Score every socdem type up front, before any thresholding.
            trusted = {}
            for socdem_type in socdem_config.SOCDEM_TYPES:
                trusted[socdem_type] = calculate_trust_score(row[socdem_type + '_scores'])

            strong_labels = {}
            weak_labels = {}
            labels_by_level = (('strong', strong_labels), ('weak', weak_labels))

            for socdem_type, segment in zip(socdem_config.SOCDEM_TYPES, socdem_config.SOCDEM_SEGMENT_TYPES):
                score = trusted[socdem_type]
                for level, labels in labels_by_level:
                    passes_absolute = score['absolute_score'] >= self.absolute_thresholds[level][socdem_type]
                    if passes_absolute and score['relative_score'] >= self.relative_threshold:
                        labels[segment] = score['label']

            if activity == 'active' and country is not None:
                # The training row starts from the strong labels and is
                # extended in place with the accepted weak ones.
                training_row = strong_labels
                weak_label_used = False
                weak_income_used = False

                for socdem_type, segment in zip(socdem_config.SOCDEM_TYPES, socdem_config.SOCDEM_SEGMENT_TYPES):
                    if not weak_labels.get(segment):
                        continue
                    if country not in self.underrepresented_countries[socdem_type]:
                        continue
                    weak_label_used = True
                    if socdem_type == 'income':
                        weak_income_used = True
                    training_row[segment] = weak_labels[segment]

                if training_row:
                    training_row['yandexuid'] = row['yandexuid']
                    training_row['main_region_country'] = country
                    training_row['region'] = self.country_id_to_name.get(country)

                    if weak_income_used or not weak_label_used:
                        training_table = self.socdem_labels_for_learning_output
                    else:
                        training_table = self.socdem_labels_for_sampling_output

                    yield create_table_switch(training_table)
                    yield training_row

            if weak_labels:
                weak_labels['yandexuid'] = row['yandexuid']
                yield create_table_switch(self.socdem_labels_output)
                yield weak_labels


@with_context
def join_with_active_yandexuid(key, rows, context):
    """
    Reducer keeping rows of input table 1 for keys present in input table 0.

    Rows of table 1 are yielded unchanged, but only when a (non-empty) row of
    table 0 has already been seen for the same key; rows of table 0 are never
    emitted.
    """
    matched_row = None

    for current_row in rows:
        source_index = context.table_index
        if source_index == 0:
            matched_row = current_row
            continue
        if source_index == 1 and matched_row:
            yield current_row


class GetLabeledSocdem(BaseYtTask):
    """
    Builds two tables from the yandexuid socdem storage:
    - ``config.SOCDEM_LABELS_FOR_LEARNING_TABLE``: labels for learning,
      including a sample of the rows routed to the sampling output;
    - ``config.SOCDEM_LABELS_TABLE``: socdem labels.

    Both tables end up sorted by ``yandexuid`` and get the ``generate_date``
    attribute set to the task date.
    """

    date = luigi.Parameter()
    priority = 100
    task_group = 'import_socdem_data'

    def requires(self):
        """The task reads the externally produced socdem storage table."""
        return ExternalInput(config.YANDEXUID_SOCDEM_STORAGE_TABLE)

    def output(self):
        """Return the daily-rewritable targets of both result tables."""
        return {
            'socdem_labels_for_learning': YtDailyRewritableTarget(config.SOCDEM_LABELS_FOR_LEARNING_TABLE, self.date),
            'socdem_labels': YtDailyRewritableTarget(config.SOCDEM_LABELS_TABLE, self.date),
        }

    @staticmethod
    def _socdem_labels_for_learning_schema():
        """Column name -> type for the learning (and sampling) table."""
        return {
            'yandexuid': 'uint64',
            'gender': 'string',
            'age_segment': 'string',
            'income_segment': 'string',
            'main_region_country': 'uint64',
            'region': 'string',
        }

    @staticmethod
    def _socdem_labels_schema():
        """Column name -> type for the socdem labels table."""
        return {
            'yandexuid': 'uint64',
            'gender': 'string',
            'age_segment': 'string',
            'income_segment': 'string',
        }

    def run(self):
        """
        Run the whole pipeline inside a single YT transaction:
        create empty outputs, reduce the storage with FilterSocdemStorage,
        pass the sampling table through ``make_sample_with_size``, merge it
        into the learning table, then sort and stamp both outputs.
        """
        with loggers.TimeTracker(monitoring_name=self.__class__.__name__):
            with self.yt.Transaction():
                # The sampling table shares the schema of the learning table
                # because it is merged into it later on.
                sampling_attributes = {
                    'schema': schema_utils.yt_schema_from_dict(self._socdem_labels_for_learning_schema()),
                }

                with self.yt.TempTable(attributes=sampling_attributes) as sampling_table:
                    targets = self.output()
                    learning_table = targets['socdem_labels_for_learning'].table
                    labels_table = targets['socdem_labels'].table

                    self.yt.create_empty_table(
                        learning_table,
                        schema=self._socdem_labels_for_learning_schema(),
                    )
                    self.yt.create_empty_table(
                        labels_table,
                        schema=self._socdem_labels_schema(),
                    )

                    # Output order must match the output indices declared on
                    # FilterSocdemStorage (labels, learning, sampling).
                    reduce_sources = [
                        config.YUID_WITH_ALL_BY_YANDEXUID_TABLE,
                        self.input().table,
                    ]
                    reduce_destinations = [
                        labels_table,
                        learning_table,
                        sampling_table,
                    ]
                    self.yt.run_reduce(
                        FilterSocdemStorage(),
                        reduce_sources,
                        reduce_destinations,
                        reduce_by='yandexuid',
                    )

                    # The sample overwrites the sampling table in place.
                    yt_helpers.make_sample_with_size(
                        yt_client=self.yt,
                        source_table=sampling_table,
                        destination_table=sampling_table,
                        size=config.SAMPLED_LABELS_FOR_TRAINING_COUNT,
                    )

                    self.yt.run_merge(
                        [learning_table, sampling_table],
                        learning_table,
                        mode='unordered',
                    )

                    result_tables = (learning_table, labels_table)
                    for result_table in result_tables:
                        self.yt.run_sort(result_table, sort_by='yandexuid')
                    for result_table in result_tables:
                        self.yt.set_attribute(result_table, 'generate_date', self.date)
