import itertools
import json
import os
import tempfile

import numpy as np
from yt.yson import get_bytes

from crypta.lib.proto.user_data import user_data_pb2
from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools.utils import get_secrets
from crypta.lib.python.sandbox.client import get_sandbox_client
from crypta.profile.lib.bb_helpers import keyword_name_to_bb_keyword_id


def get_feature_name(first_name, last_name):
    """Build a feature key by joining the two parts with an underscore."""
    return '{first}_{last}'.format(first=first_name, last=last_name)


def normalize(x):
    """Scale ``x`` by the inverse of its Euclidean norm.

    The norm is computed as ``sqrt(dot(x, x))``; no special handling is done
    for a zero-norm input.
    """
    squared_norm = np.dot(x, x)
    euclidean_norm = np.sqrt(squared_norm)
    return x / euclidean_norm


def load_features_mappings(yt, yt_folder_path, feature_types):
    """Load feature <-> index mappings for every requested feature type.

    For each feature type the table ``<yt_folder_path>/<feature_type>`` is
    read if it exists; every row contributes ``row['feature']`` mapped to
    ``row['feature_index'] - 1``. When the table does not exist, the feature
    type itself is registered as the only feature, at index 0.

    Returns a pair ``(forward, backward)`` of dicts keyed by feature type:
    ``forward[type]`` maps feature -> index, ``backward[type]`` maps
    index -> feature.
    """
    forward = {}
    backward = {}

    for feature_type in feature_types:
        name_to_index = {}
        index_to_name = {}
        forward[feature_type] = name_to_index
        backward[feature_type] = index_to_name

        table_path = os.path.join(yt_folder_path, feature_type)
        if not yt.exists(table_path):
            # No table for this type: the type name stands in as its single feature.
            name_to_index[feature_type] = 0
            index_to_name[0] = feature_type
            continue

        for record in yt.read_table(table_path):
            # Indices in the table are shifted down by one before being stored.
            shifted_index = record['feature_index'] - 1
            feature = record['feature']
            name_to_index[feature] = shifted_index
            index_to_name[shifted_index] = feature

    return forward, backward


def get_features_dict_unraveled(yt, segment_feature_types, vector_size):
    """Flatten per-type categorical feature mappings into one global index space.

    Feature groups are laid out in sorted order of ``segment_feature_types``,
    followed by a ``'vector'`` group of ``vector_size`` positions. Each
    feature's index inside its group is shifted by the combined size of all
    preceding groups. Keys of the flat mapping are built with
    ``get_feature_name`` from the group's id in
    ``keyword_name_to_bb_keyword_id`` and the feature itself.

    Returns ``(flat_mapping, per_type_mappings)`` where ``per_type_mappings``
    is the feature -> index dict per type as returned by
    ``load_features_mappings``.
    """
    per_type_mappings, _ = load_features_mappings(
        yt=yt,
        yt_folder_path=training_config.CATEGORICAL_FEATURES_CUSTOM_ML_MATCHING_DIR,
        feature_types=segment_feature_types,
    )

    flat_mapping = {}
    group_start = 0

    for group in sorted(segment_feature_types) + ['vector']:
        if group == 'vector':
            # The vector group owns no named features; it only takes up space.
            group_start += vector_size
            continue

        group_mapping = per_type_mappings[group]
        for feature, local_index in group_mapping.items():
            flat_key = get_feature_name(
                first_name=keyword_name_to_bb_keyword_id[group],
                last_name=feature,
            )
            flat_mapping[flat_key] = group_start + local_index
        group_start += len(group_mapping)

    return flat_mapping, per_type_mappings


def download_features_dict_from_sandbox(resource_type, released, file_name):
    """Fetch a JSON file from the latest released Sandbox resource and parse it.

    The Sandbox token is taken from the ``SANDBOX_TOKEN`` environment variable
    when it is set to a non-empty value, otherwise from the
    ``ROBOT_CRYPTA_SANDBOX_OAUTH`` secret. The resource is downloaded into a
    temporary file, which is removed once the content has been loaded.

    Returns the object decoded from the downloaded JSON.
    """
    with tempfile.NamedTemporaryFile() as local_copy:
        token = os.getenv('SANDBOX_TOKEN') or get_secrets().get_secret('ROBOT_CRYPTA_SANDBOX_OAUTH')
        sandbox = get_sandbox_client(token)

        latest_resource_id = sandbox._get_last_released_resource_id_for_status(resource_type, released)
        sandbox.load_resource(latest_resource_id, local_copy.name, resource_path=file_name)

        # Re-open by name: the download writes to the path, not to our handle.
        with open(local_copy.name, 'r') as downloaded:
            return json.load(downloaded)


class MakeCatboostFeatures(object):
    """Callable that converts a user-data row into CatBoost float/categorical features.

    The feature vector consists of one slot per entry of the flat categorical
    mapping followed by the normalized user vector. Gender, age and income
    slots are emitted as categorical (string) features; everything else is
    emitted as float features.

    ``start()`` must be called before rows are processed, because the reusable
    protobuf message is created there rather than in ``__init__``
    (presumably ``start()`` is invoked by the YT job runner - confirm).
    """

    # Flat-mapping keys of the slots that are emitted as categorical features.
    _CAT_FEATURE_NAMES = ('174_gender', '175_age_segments', '614_income_5_segments')

    def __init__(self, yt, segment_feature_types, vector_size=training_config.VECTOR_SIZE, pass_through='yuid',
                 training_mode=True):
        """Prepare the feature layout.

        Args:
            yt: client used to load the per-type mappings (training mode only).
            segment_feature_types: names of the segment feature groups.
            vector_size: number of positions reserved for the user vector.
            pass_through: name of the row column copied to ``'PassThrough'``.
            training_mode: when true the flat mapping is built from the YT
                tables; otherwise it is downloaded from the stable released
                Sandbox resource.

        Raises:
            KeyError: if the flat mapping lacks any of the gender/age/income keys.
        """
        self.vector_size = vector_size
        self.pass_through = pass_through
        self.segment_feature_types = segment_feature_types
        self.features_order = sorted(self.segment_feature_types) + ['vector']

        if training_mode:
            self.cat_features_dict_unraveled, self.cat_features_dicts = get_features_dict_unraveled(
                yt=yt,
                segment_feature_types=self.segment_feature_types,
                vector_size=self.vector_size,
            )
        else:
            self.cat_features_dict_unraveled = download_features_dict_from_sandbox(
                resource_type='CRYPTA_FEATURES_MAPPING_FOR_TRAINABLE_SEGMENTS',
                released='stable',
                file_name='',
            )

        self.n_features = len(self.cat_features_dict_unraveled) + self.vector_size
        self.cat_features_indexes = [
            self.cat_features_dict_unraveled[cat_feature] for cat_feature in self._CAT_FEATURE_NAMES
        ]

    def get_feature_row(self, row):
        """Build the full feature list for one row.

        Reads the serialized ``'Vectors'``, ``'Attributes'`` and ``'Segments'``
        columns; the latter two may be ``None``.

        Returns a list whose first ``n_features - vector_size`` entries are the
        mapped slots (gender/age/income values, 1 for every present segment
        found in the mapping, 0 otherwise) followed by the normalized vector.
        """
        # The message is reused between rows and ParseFromString() resets only
        # the sub-message it is called on, so without this a row whose
        # Attributes/Segments are None would inherit the previous row's values.
        self.user_data.Clear()

        self.user_data.Vectors.ParseFromString(get_bytes(row['Vectors']))
        attributes = row['Attributes']
        if attributes is not None:
            self.user_data.Attributes.ParseFromString(get_bytes(attributes))
        segments = row['Segments']
        if segments is not None:
            self.user_data.Segments.ParseFromString(get_bytes(segments))

        mapping = self.cat_features_dict_unraveled
        features = [0] * (self.n_features - self.vector_size)
        features += list(normalize(self.user_data.Vectors.Vector.Data))

        features[mapping['174_gender']] = int(self.user_data.Attributes.Gender)
        features[mapping['175_age_segments']] = int(self.user_data.Attributes.Age)
        features[mapping['614_income_5_segments']] = int(self.user_data.Attributes.Income)

        for segment_info in self.user_data.Segments.Segment:
            feature_name = get_feature_name(first_name=segment_info.Keyword, last_name=segment_info.ID)
            # Segments absent from the mapping are ignored.
            feature_index = mapping.get(feature_name)
            if feature_index is not None:
                features[feature_index] = 1

        return features

    def start(self):
        """Create the protobuf message that is reused for every row."""
        self.user_data = user_data_pb2.TUserData()

    def __call__(self, row):
        """Yield one output record with float and categorical features split apart.

        Categorical features are the gender/age/income slots, converted to
        strings and emitted in ascending slot order; all remaining slots are
        converted to floats.
        """
        features = self.get_feature_row(row)

        # Builtin bool instead of the np.bool alias, which was removed in NumPy 1.24.
        is_float_feature = np.ones(self.n_features, dtype=bool)
        is_float_feature[self.cat_features_indexes] = False

        float_features = list(map(float, itertools.compress(features, is_float_feature)))
        cat_features = list(map(str, itertools.compress(features, ~is_float_feature)))

        yield {
            'PassThrough': row[self.pass_through],
            'FloatFeatures': float_features,
            'CatFeatures': cat_features,
        }
