import numpy as np
from yt.wrapper import with_context, create_table_switch

from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.tools.features import (
    cat_feature_types,
    categorical_feature_name_to_keyword,
    get_feature_name,
    MakeCatboostTrainingFeaturesBase,
)

# Minimum total number of 'heuristic_common' + 'longterm_interests' ids a key must have
# for MakeCatboostFeatures to emit a CatBoost feature row; keys below the threshold are
# passed through to the raw profiles output instead.
MIN_SEGMENTS_COUNT = 10  # https://st.yandex-team.ru/CRYPTAUP-1507#5d1afb78501fd0001db8ef06
# Output table indices used with create_table_switch in MakeCatboostFeatures:
# index 0 receives the untouched neuro classification rows,
# index 1 receives the rows with the flat CatBoost feature vector.
raw_yandexuid_profiles_index = 0
catboost_features_index = 1
# Column name -> type name mapping.
# NOTE(review): presumably the YT schema of the weighted socdem training sample table;
# it is not used anywhere in this file, so confirm against its callers.
weighted_socdem_sample_schema = {
    'yandexuid': 'uint64',
    'crypta_id': 'uint64',
    'gender': 'string',
    'age_segment': 'string',
    'income_segment': 'string',
    'income_segment_weight': 'double',
    'vector': 'string',
    'heuristic_common': 'any',
    'heuristic_segments': 'any',
    'longterm_interests': 'any',
    'raw_site_weights': 'any',
    'main_region_country': 'uint64',
    'region': 'string',
    'additional_features': 'any',
}

# YQL template that merges the per-yandexuid age, gender and income model outputs into
# one table. It has two str.format placeholders: {output_table} and {update_time}.
# The template only *uses* $get_model_predictions, $age_schema / $gender_schema /
# $income_schema and $age_processed / $gender_processed / $income_processed; they are not
# defined here, so the caller has to define them in the query text placed before it.
# The three results are combined with INNER JOINs, so only yandexuids present in all of
# them end up in the output table.
catboost_application_results_processing_query_template = """
$age_result = (
    SELECT
        Cast(PassThrough AS Uint64) AS yandexuid,
        $get_model_predictions($age_schema, Result) AS user_age_6s,
    FROM $age_processed
);
$gender_result = (
    SELECT
        Cast(PassThrough AS Uint64) AS yandexuid,
        $get_model_predictions($gender_schema, Result) AS gender,
    FROM $gender_processed
);
$income_result = (
    SELECT
        Cast(PassThrough AS Uint64) AS yandexuid,
        $get_model_predictions($income_schema, Result) AS income_5_segments,
    FROM $income_processed
);
INSERT INTO `{output_table}`
SELECT
    t1.yandexuid AS yandexuid,
    Yson::Serialize(Yson::FromDoubleDict(t1.user_age_6s)) AS user_age_6s,
    Yson::Serialize(Yson::FromDoubleDict(t2.gender)) AS gender,
    Yson::Serialize(Yson::FromDoubleDict(t3.income_5_segments)) AS income_5_segments,
    CAST({update_time} AS Uint64) AS update_time
FROM $age_result AS t1
INNER JOIN $gender_result AS t2
ON t1.yandexuid == t2.yandexuid
INNER JOIN $income_result AS t3
ON t1.yandexuid == t3.yandexuid
"""


@with_context
class MakeCatboostFeatures(object):
    """Reducer that turns the rows of one key into a flat CatBoost float feature vector.

    The vector layout is ``[socdem scores | categorical features | additional features]``:
    the socdem part takes ``socdem_offset`` slots, a categorical feature named ``name``
    lives at ``socdem_offset + cat_features_dict[name]`` and the additional features
    occupy the last ``additional_features_number`` slots.

    Rows are dispatched by ``context.table_index``:
      * 0 - neuro classification row; its socdem scores start the vector;
      * 1 - row with 'heuristic_common' / 'longterm_interests' ids, stored as 1;
      * 2 - row with 'raw_site_weights', stored with their weights;
      * 3 - row with an optional 'features' list for the additional features tail.

    Args:
        cat_features_dict: mapping from feature name to its index inside the
            categorical part of the vector.
        additional_features_number: number of slots reserved at the end of the vector.
        default_additional_value: value the reserved slots hold when table 3 gives
            no 'features' for the key.
    """

    def __init__(self, cat_features_dict, additional_features_number=0, default_additional_value=-1):
        self.cat_features_dict = cat_features_dict
        self.socdem_keys = socdem_config.yet_another_segment_names_by_label_type
        # NOTE(review): the offset is built from three hard-coded segment types while the
        # vector itself is filled by iterating socdem_config.SOCDEM_SEGMENT_TYPE_NAMES;
        # the two are assumed to describe the same set of types - confirm in socdem_config.
        self.socdem_offset = (
            len(self.socdem_keys['gender'])
            + len(self.socdem_keys['user_age_6s'])
            + len(self.socdem_keys['income_5_segments'])
        )
        self.additional_features_number = additional_features_number
        self.total_size = self.socdem_offset + len(self.cat_features_dict) + self.additional_features_number
        self.default_additional_value = default_additional_value

    def _init_features(self, row):
        """Create the feature vector and fill its socdem part from the table-0 row."""
        features = np.zeros(self.total_size, dtype=np.float32)
        if self.additional_features_number > 0:
            features[-self.additional_features_number:] = self.default_additional_value

        offset = 0
        for socdem_segment in socdem_config.SOCDEM_SEGMENT_TYPE_NAMES:
            segment_keys = self.socdem_keys[socdem_segment]
            for i, feature_name in enumerate(segment_keys):
                features[offset + i] = row[socdem_segment][feature_name]
            offset += len(segment_keys)
        return features

    def __call__(self, key, rows, context):
        """Process all rows of one key and yield a single output row.

        Keys with fewer than MIN_SEGMENTS_COUNT segment ids get their table-0 row
        re-emitted unchanged into the raw profiles output; other keys get a row with
        'PassThrough' (the yandexuid), 'FloatFeatures' and an empty 'CatFeatures' in the
        CatBoost features output. Nothing is yielded when a row arrives before any
        table-0 row has been seen for the key.
        """
        features = None
        neuro_classification_row = None
        segments_count = 0
        for row in rows:
            if context.table_index == 0:
                neuro_classification_row = row
                features = self._init_features(row)
            elif context.table_index == 1 and neuro_classification_row is not None:
                for cat_feature_type in ('heuristic_common', 'longterm_interests'):
                    cat_feature_ids = row[cat_feature_type]
                    if cat_feature_ids is not None:
                        segments_count += len(cat_feature_ids)
                        for cat_feature_id in cat_feature_ids:
                            feature_name = get_feature_name(
                                feature_type=categorical_feature_name_to_keyword[cat_feature_type],
                                value=cat_feature_id,
                            )
                            if feature_name in self.cat_features_dict:
                                features[self.socdem_offset + self.cat_features_dict[feature_name]] = 1
            elif context.table_index == 2 and neuro_classification_row is not None:
                # if there are few segments - neuro classification
                if segments_count < MIN_SEGMENTS_COUNT:
                    break

                raw_site_weights = row['raw_site_weights']
                if raw_site_weights is not None:
                    for site in raw_site_weights:
                        feature_name = get_feature_name(
                            feature_type=categorical_feature_name_to_keyword['raw_site_weights'],
                            value=site,
                        )
                        if feature_name in self.cat_features_dict:
                            features[self.socdem_offset + self.cat_features_dict[feature_name]] = raw_site_weights[site]
            elif context.table_index == 3 and neuro_classification_row is not None:
                # Without the "> 0" guard the slice would be features[-0:], i.e. the whole
                # vector, and the socdem/categorical parts would be overwritten.
                if (
                    self.additional_features_number > 0
                    and 'features' in row
                    and row['features'] is not None
                ):
                    features[-self.additional_features_number:] = row['features']
            else:
                return  # if there is no neuro classification - skip row

        if segments_count < MIN_SEGMENTS_COUNT:
            yield create_table_switch(raw_yandexuid_profiles_index)
            yield neuro_classification_row
        else:
            catboost_features_row = {
                'PassThrough': key['yandexuid'],
                # tolist() converts the float32 values to plain Python floats
                'FloatFeatures': features.tolist(),
                'CatFeatures': list(),
            }

            yield create_table_switch(catboost_features_index)
            yield catboost_features_row


class MakeCatboostTrainingFeatures(MakeCatboostTrainingFeaturesBase):
    """Training feature builder whose samples are identified by 'yandexuid'."""

    def __init__(
        self,
        socdem_type,
        models_list,
        flat_features_dict,
        has_weights=True,
        additional_features_number=0,
        batch_size=4096,
    ):
        """Forward every argument to the base class and request the 'yandexuid' column."""
        super(MakeCatboostTrainingFeatures, self).__init__(
            socdem_type=socdem_type,
            models_list=models_list,
            flat_features_dict=flat_features_dict,
            has_weights=has_weights,
            additional_features_number=additional_features_number,
            batch_size=batch_size,
        )
        self.additional_columns.append('yandexuid')

    def process_cat_features(self, row):
        """Return a float32 vector with one slot per entry of ``flat_features_dict``.

        Ids found in the row's categorical columns are written as 1, except for
        'raw_site_weights', whose stored weight is written instead. Names missing from
        ``flat_features_dict`` are ignored.
        """
        index_by_name = self.flat_features_dict
        vector = np.zeros(len(index_by_name), dtype=np.float32)

        for column in cat_feature_types:
            column_values = row[column]
            if column_values is None:
                continue

            is_weighted = column == 'raw_site_weights'
            for item in column_values:
                name = get_feature_name(
                    feature_type=categorical_feature_name_to_keyword[column],
                    value=item,
                )
                if index_by_name.get(name) is None:
                    continue

                if is_weighted:
                    vector[index_by_name[name]] = column_values[item]
                else:
                    vector[index_by_name[name]] = 1

        return vector

    def get_id_with_id_type(self, additional_columns):
        """Build the sample identifier string from the 'yandexuid' additional column."""
        yandexuid = additional_columns['yandexuid']
        return 'yandexuid_{}'.format(yandexuid)
