#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict

from cached_property import cached_property

from crypta.lib.python.custom_ml.classification import (
    CustomClassificationModelTrainHelper,
    CustomClassificationParams,
)
from crypta.lib.python.custom_ml.classification.model_training import make_common_classification_train_sample_query
from crypta.lib.python.custom_ml.tools.training_utils import get_model_tag
from crypta.profile.runners.segments.lib.coded_segments.ml_tools.utils import resources
from crypta.profile.utils.config import config
from crypta.profile.utils.utils import revert_dict


# Query fragment (SQL-like; presumably YQL) that is substituted into the common
# classification train-sample query as its `additional_conditions` part.
# It inner-joins the user-data table on yandexuid and keeps a row only when the
# user's Gender/Age attributes match the condition listed for its segment_name.
# NOTE(review): the Gender and Age values are opaque numeric codes here
# (presumably gender ids and age-bucket ids) — confirm against the user-data schema.
_ADDITIONAL_CONDITIONS_TEMPLATE = """
INNER JOIN `{user_data_table}` AS user_cond
ON indevice_yandexuid.yandexuid == CAST(user_cond.yuid AS Uint64)
WHERE (segment_name == 'mammies' AND user_cond.Attributes.Gender == 2 AND user_cond.Attributes.Age in (3, 4, 5, 6)) OR
    (segment_name == 'regional_youth' AND user_cond.Attributes.Age in (1, 2)) OR
    (segment_name == 'men_artisans' AND user_cond.Attributes.Gender == 1 AND
        user_cond.Attributes.Age in (3, 4, 5, 6)) OR
    (segment_name == 'advanced_youth' AND user_cond.Attributes.Age in (1, 2)) OR
    (segment_name == 'active_women' AND user_cond.Attributes.Gender == 2 AND
        user_cond.Attributes.Age in (3, 4, 5, 6)) OR
    (segment_name == 'successful_men' AND user_cond.Attributes.Gender == 1 AND
        user_cond.Attributes.Age in (3, 4, 5, 6)) OR
    (segment_name == 'age_conservatives' AND user_cond.Attributes.Age in (5, 6))
"""

# The only placeholder is the user-data table path taken from the project config.
additional_conditions = _ADDITIONAL_CONDITIONS_TEMPLATE.format(
    user_data_table=config.USER_DATA_TABLE,
)


class MarketModelTrainHelper(CustomClassificationModelTrainHelper):
    """Train helper configured for the 'market' custom classification model."""

    def __init__(self, yt, logger, date, storage_sample_path, train_sample_path):
        """Initialise the base helper and remember the sample table paths.

        Args:
            yt: forwarded unchanged to the base helper.
            logger: forwarded unchanged to the base helper.
            date: forwarded unchanged to the base helper.
            storage_sample_path: table used as ``input_table`` of the
                train-sample query.
            train_sample_path: table used as ``output_table`` of the
                train-sample query and as the params' ``train_sample_path``.
        """
        super(MarketModelTrainHelper, self).__init__(yt=yt, logger=logger, date=date)
        self.model_name = 'market'
        self.storage_sample_path = storage_sample_path
        self.train_sample_path = train_sample_path

    @cached_property
    def model_params(self):
        """Build the ``CustomClassificationParams`` for the market model.

        The value is computed on first access and cached by ``cached_property``.
        As a side effect, the first access also sets ``self.segment_id_to_name``.

        Returns:
            CustomClassificationParams populated with the per-segment
            thresholds, the name/id mappings, the resource type and the
            formatted train-sample query.
        """
        # Insertion order is significant: it is handed over as `ordered_thresholds`.
        thresholds = OrderedDict((
            ('mammies', 0.32),
            ('regional_youth', 0.12),
            ('men_artisans', 0.31),
            ('advanced_youth', 0.5),
            ('active_women', 0.4),
            ('successful_men', 0.32),
            ('age_conservatives', 0.18),
        ))

        name_to_id = self.get_segment_name_to_id(thresholds)
        # Stored on the instance as well as being passed to the params below.
        self.segment_id_to_name = revert_dict(name_to_id)

        model_resource = resources[self.model_name]
        tag = get_model_tag(model_resource)
        metrics_group = self.get_metrics_group_name(self.model_name)

        # Fill the shared query template with this model's tables and the
        # module-level join/filter fragment.
        train_sample_query = make_common_classification_train_sample_query.format(
            input_table=self.storage_sample_path,
            output_table=self.train_sample_path,
            additional_conditions=additional_conditions,
        )

        return CustomClassificationParams(
            train_sample_path=self.train_sample_path,
            resource_type=model_resource,
            model_tag=tag,
            metrics_group_name=metrics_group,
            model_description_in_sandbox='Model for market segments classification',
            segment_id_to_name=self.segment_id_to_name,
            segment_name_to_id=name_to_id,
            ordered_thresholds=thresholds,
            make_train_sample_query=train_sample_query,
        )
