import numpy as np

from yt.wrapper import with_context

from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.tools.features import (
    cat_feature_mobile_types,
    get_feature_name,
    MakeCatboostTrainingFeaturesBase,
)

# Column name -> YT column type name for the weighted mobile training sample table.
# Keyword arguments keep the same insertion (column) order as a literal would.
weighted_mobile_training_sample_schema = dict(
    # identifiers
    id='string',
    id_type='string',
    crypta_id='uint64',
    # socdem labels and the income weight column
    gender='string',
    age_segment='string',
    income_segment='string',
    income_segment_weight='double',
    # remaining per-id columns
    vector='string',
    model='string',
    manufacturer='string',
    main_region_obl='uint64',
    categories='any',
    additional_features='any',
)

# YQL that merges the age, gender and income CatBoost application results into one
# output row per (id, id_type).
#
# Placeholders: {output_table} and {update_time} (str.format-style braces; presumably
# filled with .format by the caller -- confirm).
# The template references named expressions it does not define: $age_processed,
# $gender_processed, $income_processed, $age_schema, $gender_schema, $income_schema and
# $get_model_predictions. They must be declared in YQL prepended by the caller.
# PassThrough is decoded as [id, id_type] (head = id, last = id_type), the same order
# MakeCatboostFeatures writes into its 'PassThrough' column.
# Both joins are INNER, so only ids present in all three results reach the output.
# NOTE(review): Yson::FromDoubleDict implies $get_model_predictions yields a
# Dict<String, Double>; verify against its definition.
catboost_application_result_processing_query_template = """
$age_result = (
    SELECT
        ListHead(Yson::ConvertToStringList(PassThrough)) AS id,
        ListLast(Yson::ConvertToStringList(PassThrough)) AS id_type,
        $get_model_predictions($age_schema, Result) AS user_age_6s,
    FROM $age_processed
);
$gender_result = (
    SELECT
        ListHead(Yson::ConvertToStringList(PassThrough)) AS id,
        ListLast(Yson::ConvertToStringList(PassThrough)) AS id_type,
        $get_model_predictions($gender_schema, Result) AS gender,
    FROM $gender_processed
);
$income_result = (
    SELECT
        ListHead(Yson::ConvertToStringList(PassThrough)) AS id,
        ListLast(Yson::ConvertToStringList(PassThrough)) AS id_type,
        $get_model_predictions($income_schema, Result) AS income_5_segments,
    FROM $income_processed
);
INSERT INTO `{output_table}`
SELECT
    t1.id AS id,
    t1.id_type AS id_type,
    Yson::Serialize(Yson::FromDoubleDict(t1.user_age_6s)) AS user_age_6s,
    Yson::Serialize(Yson::FromDoubleDict(t2.gender)) AS gender,
    Yson::Serialize(Yson::FromDoubleDict(t3.income_5_segments)) AS income_5_segments,
    CAST({update_time} AS Uint64) AS update_time
FROM $age_result AS t1
INNER JOIN $gender_result AS t2
ON t1.id == t2.id AND t1.id_type == t2.id_type
INNER JOIN $income_result AS t3
ON t1.id == t3.id AND t1.id_type == t3.id_type
"""


@with_context
class MakeCatboostFeatures(object):
    """YT reducer that builds one CatBoost input row per reduce key.

    The dense float32 vector produced for each key is laid out as::

        [socdem predictions | categorical features | additional features]
         socdem_offset        len(cat_features_dict)  additional_features_number

    Rows are dispatched by ``context.table_index``:

    * 0 -- socdem predictions, read as ``row[segment][label]``. Such a row starts a
      fresh vector and is required for the key to produce any output.
    * 1 -- mobile categorical columns listed in ``cat_feature_mobile_types``.
    * 2 -- optional ``features`` column copied into the additional-features tail.

    A row from table 1 or 2 seen before any table-0 row ends processing of the key
    without output. This therefore relies on table-0 rows coming first within a key
    (NOTE(review): confirm the reduce input ordering guarantees this).
    """

    def __init__(self, cat_features_dict, additional_features_number=0, default_additional_value=-1):
        """
        :param cat_features_dict: maps a ``get_feature_name(...)`` result to its index
            inside the categorical section of the vector.
        :param additional_features_number: length of the trailing additional-features
            section (0 disables it).
        :param default_additional_value: value the additional-features section holds
            when no table-2 row overrides it.
        """
        self.cat_features_dict = cat_features_dict
        self.socdem_keys = socdem_config.yet_another_segment_names_by_label_type
        # NOTE(review): the offset counts these three label types while __call__ walks
        # SOCDEM_SEGMENT_TYPE_NAMES; the two are assumed to name the same segments.
        self.socdem_offset = len(self.socdem_keys['gender']) + len(self.socdem_keys['user_age_6s']) + \
            len(self.socdem_keys['income_5_segments'])
        self.additional_features_number = additional_features_number
        self.total_size = self.socdem_offset + len(self.cat_features_dict) + self.additional_features_number
        self.default_additional_value = default_additional_value

    def __call__(self, key, rows, context):
        """Yield the CatBoost row for ``key``, or nothing if the key has no socdem row.

        :param key: reduce key; ``key['id']`` and ``key['id_type']`` go to ``PassThrough``.
        :param rows: input rows sharing ``key``.
        :param context: operation context supplied because of ``@with_context``;
            ``context.table_index`` identifies the source table of the current row.
        :yields: ``{'PassThrough': [id, id_type], 'FloatFeatures': [...], 'CatFeatures': []}``.
        """
        features = None
        has_neuro_classification_row = False
        for row in rows:
            table_index = context.table_index
            if table_index == 0:
                has_neuro_classification_row = True
                features = np.zeros(self.total_size, dtype=np.float32)
                if self.additional_features_number > 0:
                    # Scalar assignment broadcasts over the whole tail slice.
                    features[-self.additional_features_number:] = self.default_additional_value
                offset = 0
                for socdem_segment in socdem_config.SOCDEM_SEGMENT_TYPE_NAMES:
                    segment_labels = self.socdem_keys[socdem_segment]
                    for i, feature_name in enumerate(segment_labels):
                        features[offset + i] = row[socdem_segment][feature_name]
                    offset += len(segment_labels)
            elif table_index == 1 and has_neuro_classification_row:
                for cat_feature_type in cat_feature_mobile_types:
                    cat_feature_value = row[cat_feature_type]
                    if cat_feature_value is None:
                        continue
                    if cat_feature_type == 'categories':
                        # Mapping category -> value; each category's slot gets its value.
                        # An unknown category raises KeyError (no 'Other' fallback here).
                        for category, value in cat_feature_value.items():
                            feature_name = get_feature_name(feature_type=cat_feature_type, value=category)
                            features[self.socdem_offset + self.cat_features_dict[feature_name]] = value
                    else:
                        # One-hot; values missing from the dictionary fall back to 'Other'.
                        feature_name = get_feature_name(feature_type=cat_feature_type, value=cat_feature_value)
                        if feature_name not in self.cat_features_dict:
                            feature_name = get_feature_name(feature_type=cat_feature_type, value='Other')
                        features[self.socdem_offset + self.cat_features_dict[feature_name]] = 1
            elif table_index == 2 and has_neuro_classification_row:
                # The > 0 check is essential: with n == 0, features[-0:] is the WHOLE
                # vector, so the tail assignment would clobber every feature (or raise).
                if (
                    self.additional_features_number > 0
                    and 'features' in row
                    and row['features'] is not None
                ):
                    features[-self.additional_features_number:] = row['features']
            else:
                return

        if features is None:
            # No rows at all for this key: nothing to emit.
            return

        yield {
            'PassThrough': [key['id'], key['id_type']],
            # tolist() converts each float32 element to a Python float.
            'FloatFeatures': features.tolist(),
            'CatFeatures': list(),
        }


class MakeCatboostTrainingFeatures(MakeCatboostTrainingFeaturesBase):
    """Mobile flavour of the CatBoost training featurizer.

    Supplies the mobile categorical encoding (``process_cat_features``) and the
    ``<id_type>_<id>`` identifier format. It also asks the base class for the ``id``
    and ``id_type`` columns on top of the base's own ``additional_columns``.
    """

    def __init__(
        self,
        socdem_type,
        models_list,
        flat_features_dict,
        has_weights=False,
        additional_features_number=0,
        batch_size=4096,
    ):
        base_kwargs = dict(
            socdem_type=socdem_type,
            models_list=models_list,
            flat_features_dict=flat_features_dict,
            has_weights=has_weights,
            additional_features_number=additional_features_number,
            batch_size=batch_size,
        )
        super(MakeCatboostTrainingFeatures, self).__init__(**base_kwargs)
        # Presumably these columns are what get_id_with_id_type later receives.
        self.additional_columns.extend(('id', 'id_type'))

    def process_cat_features(self, row):
        """Encode the mobile categorical columns of ``row`` as a dense float32 vector.

        The vector has ``len(self.flat_features_dict)`` slots; ``flat_features_dict``
        maps a ``get_feature_name(...)`` result to its slot.
        Columns whose value is None are skipped. The ``categories`` column is a
        mapping category -> value written as-is; an unknown category raises KeyError.
        Every other column is one-hot, and values missing from the dictionary fall
        back to the ``'Other'`` feature.
        """
        slot_of = self.flat_features_dict
        encoded = np.zeros(len(slot_of), dtype=np.float32)

        for column in cat_feature_mobile_types:
            raw = row[column]
            if raw is None:
                continue

            if column != 'categories':
                name = get_feature_name(feature_type=column, value=raw)
                if name not in slot_of:
                    name = get_feature_name(feature_type=column, value='Other')
                encoded[slot_of[name]] = 1
                continue

            for category, weight in raw.items():
                name = get_feature_name(feature_type=column, value=category)
                encoded[slot_of[name]] = weight

        return encoded

    def get_id_with_id_type(self, additional_columns):
        """Return ``'<id_type>_<id>'`` built from the given column mapping."""
        id_type = additional_columns['id_type']
        identifier = additional_columns['id']
        return '{0}_{1}'.format(id_type, identifier)
