from abc import (
    ABCMeta,
    abstractmethod,
)
import json
import os
import tempfile

import numpy as np
import pandas as pd
from yt.wrapper import create_table_switch

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools.features import load_features_mappings
from crypta.profile.lib import vector_helpers
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.utils import utils
from crypta.profile.utils.config import config


# Categorical feature families used by the non-mobile socdem models.
cat_feature_types = (
    'heuristic_common',
    'longterm_interests',
    'raw_site_weights',
)
# Categorical feature families used by the mobile socdem models.
cat_feature_mobile_types = (
    'model',
    'manufacturer',
    'main_region_obl',
    'categories',
)

# Keyword/prefix associated with each non-mobile categorical feature family.
# NOTE(review): get_features_description below parses non-site feature names as
# '<keyword>_<segment_id>', which suggests these are the prefixes used there — confirm.
categorical_feature_name_to_keyword = dict(
    heuristic_common='547',
    longterm_interests='601',
    raw_site_weights='site',
)


def get_feature_name(feature_type, value):
    """Build a flat feature name of the form '<feature_type>_<value>'."""
    template = '{ftype}_{val}'
    return template.format(ftype=feature_type, val=value)


def get_weight_column(socdem_type):
    """Return the weight column name for ``socdem_type``.

    The name is the segment that ``socdem_type`` maps to in
    ``socdem_config.socdem_type_to_segment_name``, suffixed with ``_weight``.
    """
    return '{}_weight'.format(socdem_config.socdem_type_to_segment_name[socdem_type])


def get_schema_dict_from_table(yt_client, table_path):
    """Map each column name of ``table_path`` to its type.

    The mapping is built from the table's ``schema`` attribute. If a column name
    appears more than once, the last occurrence wins.
    """
    schema = yt_client.get_attribute(table_path, 'schema')
    return {column['name']: column['type'] for column in schema}


def check_weight_column_in_table(yt_client, socdem_type, table_path):
    """Check whether ``table_path`` has the weight column for ``socdem_type``.

    Returns:
        True if the schema of ``table_path`` contains the column named by
        ``get_weight_column(socdem_type)``, False otherwise.
    """
    columns = get_schema_dict_from_table(yt_client, table_path)
    return get_weight_column(socdem_type) in columns


def download_cat_features_to_dict(yt, yt_folder_path, is_mobile=False):
    """Load categorical feature mappings and their reverse mappings from YT.

    Args:
        yt: YT client passed through to ``load_features_mappings``.
        yt_folder_path: folder passed through to ``load_features_mappings``.
        is_mobile: use ``cat_feature_mobile_types`` instead of ``cat_feature_types``.

    Returns:
        ``(cat_features_dicts, reverted_cat_features_dicts)`` as returned by
        ``load_features_mappings``. For mobile models, every feature type except
        ``'categories'`` also gets an ``'Other'`` entry in both dicts. Its index
        equals the mapping's size before the entry was added.
    """
    feature_keys = cat_feature_mobile_types if is_mobile else cat_feature_types
    forward, reverse = load_features_mappings(yt, yt_folder_path, feature_keys)

    if is_mobile:
        # Append an 'Other' slot right after the existing indices
        # (presumably a catch-all bucket for unmapped values).
        for feature_type in (key for key in feature_keys if key != 'categories'):
            other_index = len(forward[feature_type])
            forward[feature_type]['Other'] = other_index
            reverse[feature_type][other_index] = 'Other'

    return forward, reverse


def download_features_dict_from_sandbox(resource_type, released, file_name, resource_id=None):
    """Download a JSON feature dictionary from Sandbox and return the parsed object.

    The resource is saved through ``utils.download_file_from_sandbox`` to the path
    of a temporary file. The file is reopened by name, parsed with ``json.load``,
    and discarded when the temporary file is closed.
    """
    with tempfile.NamedTemporaryFile() as tmp:
        target_path = tmp.name
        utils.download_file_from_sandbox(resource_type, released, file_name, target_path, resource_id)
        with open(target_path, 'r') as downloaded:
            result = json.load(downloaded)
    return result


class MakeCatboostTrainingFeaturesBase:
    """Get nn models predictions and cat features for catboost models training.

    Used as a YT mapper (``start`` / ``__call__`` / ``finish``). Input rows are
    buffered and processed in batches of ``batch_size``. For every row it yields
    two things:

    - a table switch: 0 (train) unless ``crypta_id %
      VALIDATION_SAMPLE_PERCENTAGE == VALIDATION_SAMPLE_REST``, in which case 1
      (validation);
    - a ``{'key': ..., 'value': ...}`` record. ``value`` is a tab-separated line
      holding the target label index, the weight (only if ``has_weights``), the
      concatenated nn predictions, and the categorical/additional features.

    Subclasses must implement ``process_cat_features`` and ``get_id_with_id_type``.
    """
    # NOTE: ``__metaclass__`` only takes effect on Python 2; on Python 3 the
    # abstract methods below are not enforced at instantiation time.
    __metaclass__ = ABCMeta

    def __init__(self, socdem_type, models_list, flat_features_dict, has_weights=False,
                 additional_features_number=0, default_additional_value=-1, batch_size=4096):
        """
        Args:
            socdem_type: socdem type; mapped to its target segment column through
                ``socdem_config.socdem_type_to_segment_name``.
            models_list: models whose ``predict`` is applied to the vector features.
                Their outputs are concatenated column-wise.
            flat_features_dict: categorical feature mapping. Its length is the number
                of columns filled by ``process_cat_features``.
            has_weights: emit the row's weight column after the label.
            additional_features_number: number of extra feature columns taken from
                ``socdem_config.ADDITIONAL_FEATURES_COLUMN``.
            default_additional_value: value for the extra columns when that field is None.
            batch_size: number of rows buffered before a batch is processed.
        """
        self.socdem_type = socdem_type
        self.socdem_segment = socdem_config.socdem_type_to_segment_name[socdem_type]
        self.models_list = models_list
        self.flat_features_dict = flat_features_dict
        self.has_weight = has_weights
        self.additional_columns = ['crypta_id', self.socdem_segment]
        if self.has_weight:
            self.additional_columns.append(get_weight_column(self.socdem_type))
        self.batch_size = batch_size
        # Kept for backward compatibility only; the buffer actually used is ``current_batch``.
        self.batch = []
        # Also initialised here so the object works even if ``start`` is never called.
        self.current_batch = []
        self.additional_features_number = additional_features_number
        self.default_additional_value = default_additional_value

    def start(self):
        """Reset the row buffer before processing starts."""
        self.current_batch = []

    def __call__(self, row):
        """Buffer ``row`` and flush the buffer once it reaches ``batch_size`` rows."""
        self.current_batch.append(row)
        if len(self.current_batch) >= self.batch_size:
            for record in self.process_batch():
                yield record

    def finish(self):
        """Flush whatever rows are still buffered."""
        for record in self.process_batch():
            yield record

    @abstractmethod
    def process_cat_features(self, row):
        """Return ``len(flat_features_dict)`` categorical feature values for ``row``."""
        pass

    @abstractmethod
    def get_id_with_id_type(self, additional_columns):
        """Return the output ``key`` built from the row's selected columns."""
        pass

    def process_batch(self):
        """Convert the buffered rows into output records and clear the buffer.

        Yields a table switch followed by the output record for each buffered row.
        An empty buffer (e.g. ``finish`` right after a full batch was flushed)
        yields nothing, so the models never see an empty input array.
        """
        current_batch_size = len(self.current_batch)  # Final batch can be smaller
        if current_batch_size == 0:
            return

        flat_size = len(self.flat_features_dict)

        vector_features = np.zeros((current_batch_size, socdem_config.VECTOR_SIZE), dtype=np.float32)
        categorical_features = np.zeros(
            (current_batch_size, flat_size + self.additional_features_number), dtype=np.float32,
        )
        additional_fields = []
        for idx, row in enumerate(self.current_batch):
            vector_features[idx, :socdem_config.VECTOR_SIZE] = vector_helpers.vector_row_to_features(row)

            categorical_features[idx, :flat_size] = self.process_cat_features(row)
            if self.additional_features_number > 0 and socdem_config.ADDITIONAL_FEATURES_COLUMN in row:
                additional_values = row[socdem_config.ADDITIONAL_FEATURES_COLUMN]
                if additional_values is None:
                    # Scalar is broadcast over all additional feature columns.
                    additional_values = self.default_additional_value
                categorical_features[idx, flat_size:] = additional_values

            additional_fields.append({field: row[field] for field in row if field in self.additional_columns})

        nn_predictions = np.hstack([model.predict(vector_features) for model in self.models_list])

        # Loop-invariant lookups for the output stage.
        segment_names = socdem_config.segment_names_by_label_type[self.socdem_segment]
        weight_column = get_weight_column(self.socdem_type) if self.has_weight else None

        for idx in range(current_batch_size):
            fields = additional_fields[idx]
            train_row = [str(segment_names.index(fields[self.socdem_segment]))]
            if self.has_weight:
                train_row.append(str(fields[weight_column]))
            train_row.append('\t'.join(map(str, nn_predictions[idx])))
            train_row.append('\t'.join(map(str, categorical_features[idx])))

            output_row = {
                'key': self.get_id_with_id_type(fields),
                'value': '\t'.join(train_row),
            }
            if fields['crypta_id'] % training_config.VALIDATION_SAMPLE_PERCENTAGE \
                    != training_config.VALIDATION_SAMPLE_REST:
                yield create_table_switch(0)
            else:
                yield create_table_switch(1)
            yield output_row

        self.current_batch = []


def read_features_from_table_to_dict(yt, yt_folder_path, dict_name, key_column, value_column):
    """Read a two-column mapping table from YT into forward and reverse dicts.

    Args:
        yt: YT client used to read ``<yt_folder_path>/<dict_name>``.
        yt_folder_path: folder containing the table.
        dict_name: name of the table inside ``yt_folder_path``.
        key_column: column whose values become the forward dict's keys.
        value_column: column whose values become the forward dict's values.

    Returns:
        ``(features_dict, reverted_features_dict)``:

        - ``features_dict`` maps ``row[key_column] -> row[value_column]``;
        - ``reverted_features_dict`` maps ``row[value_column] -> row[key_column]``.

        On duplicates the last row read wins.
    """
    features_dict = {}
    reverted_features_dict = {}
    for row in yt.read_table(os.path.join(yt_folder_path, dict_name)):
        key, value = row[key_column], row[value_column]
        features_dict[key] = value
        # Index by the row's value (previously the column *name* was used as the key).
        reverted_features_dict[value] = key

    return features_dict, reverted_features_dict


def get_nn_output_features_description():
    """Return the names of the neural-network output features.

    There is one ``'<segment_type>:<segment_value>'`` entry per value of each
    segment type in ``socdem_config.SOCDEM_SEGMENT_TYPES``. Entries follow the
    configuration order.
    """
    return [
        '{}:{}'.format(segment_type, segment_value)
        for segment_type in socdem_config.SOCDEM_SEGMENT_TYPES
        for segment_value in socdem_config.segment_names_by_label_type[segment_type]
    ]


def copy_features_dicts(yt_client, is_mobile, source_dir, destination_dir, transaction=None):
    """Copy the categorical feature mapping tables from ``source_dir`` to ``destination_dir``.

    Copies one table per entry of ``cat_feature_types``. For mobile models it copies
    ``cat_feature_mobile_types`` plus ``region_matching`` instead. Destination tables
    are overwritten (``force=True``). All copies run under
    ``yt_client.Transaction(transaction_id=...)``, which receives the id of
    ``transaction`` when one is given and None otherwise.
    """
    parent_transaction_id = None if transaction is None else transaction.transaction_id
    if is_mobile:
        table_names = list(cat_feature_mobile_types) + ['region_matching']
    else:
        table_names = cat_feature_types

    with yt_client.Transaction(transaction_id=parent_transaction_id):
        for table_name in table_names:
            source_path = os.path.join(source_dir, table_name)
            destination_path = os.path.join(destination_dir, table_name)
            yt_client.copy(source_path, destination_path, force=True)


def get_additional_features_description(yt_client, yt_folder_path):
    """Return the additional feature names ordered by feature index.

    Reads the ``additional_features_list`` table (``feature_index -> feature_name``)
    from the *parent* directory of ``yt_folder_path``. Returns the names for
    indices ``0 .. N-1``; a missing index raises KeyError.
    """
    index_to_name = read_features_from_table_to_dict(
        yt=yt_client,
        yt_folder_path=os.path.dirname(yt_folder_path),
        dict_name='additional_features_list',
        key_column='feature_index',
        value_column='feature_name',
    )[0]

    return list(map(index_to_name.__getitem__, range(len(index_to_name))))


def get_features_description(yt_client, features_dict):
    """
    Get list of features extensive description for socdem models.

    Args:
        yt_client: YT client used to read ``config.LAB_SEGMENTS_INFO_TABLE``.
        features_dict: mapping ``feature_name -> column index``. Indices are
            expected to be exactly ``0 .. len(features_dict) - 1``; otherwise a
            KeyError or the final assertion fires.

    Returns:
        One description per index, in index order:

        - names starting with ``site_`` are kept as they are;
        - every other name ``'<keyword>_<segment_id>'`` becomes
          ``'<keyword>:<segment_id>:<segment name>'``.

        The segment name is empty when the pair is missing from the lab segments table.
    """
    reverted_features_dict = {features_dict[feature]: feature for feature in features_dict}
    features_description = [reverted_features_dict[idx] for idx in range(len(reverted_features_dict))]

    # Build the (keyword id, segment id) -> name lookup once. The previous version
    # filtered the whole table for every feature. The first matching row wins,
    # as before.
    segment_names = {}
    for segment in yt_client.read_table(yt_client.TablePath(
        config.LAB_SEGMENTS_INFO_TABLE,
        columns=('exportKeywordId', 'exportSegmentId', 'name'),
    )):
        key = (segment.get('exportKeywordId'), segment.get('exportSegmentId'))
        segment_names.setdefault(key, segment.get('name'))

    features_description_extended = []
    for feature_name in features_description:
        if feature_name.startswith('site_'):
            features_description_extended.append(feature_name)
        else:
            keyword, segment_id = feature_name.split('_')
            segment_name = segment_names.get((int(keyword), int(segment_id)), '')
            features_description_extended.append('{}:{}:{}'.format(keyword, segment_id, segment_name))

    assert len(features_description_extended) == len(features_dict), 'Some features are not in the description'

    return features_description_extended


def get_mobile_features_description(yt_client, features_dict):
    """
    Get list of features extensive description for mobile socdem models.

    Features are listed in index order (``features_dict`` maps name -> index).
    Names starting with ``main_region_obl_`` become ``'region:<name>'``:

    - ``<name>`` is looked up in the ``region_matching`` table (``feature -> en_name``)
      under ``config.PRESTABLE_CATEGORICAL_FEATURES_MATCHING_DIR``;
    - it falls back to ``'Other'`` for ids missing from that table;
    - it is ``'Other'`` for the ``'Other'`` bucket.

    All other names are kept unchanged.
    """
    index_to_feature = {features_dict[name]: name for name in features_dict}
    ordered_features = [index_to_feature[position] for position in range(len(index_to_feature))]
    region_names = read_features_from_table_to_dict(
        yt=yt_client,
        yt_folder_path=config.PRESTABLE_CATEGORICAL_FEATURES_MATCHING_DIR,
        dict_name='region_matching',
        key_column='feature',
        value_column='en_name',
    )[0]

    def describe(feature):
        # Non-region features are reported verbatim.
        if not feature.startswith('main_region_obl_'):
            return feature
        region_id = feature.split('_')[-1]
        if region_id == 'Other':
            return 'region:{}'.format(region_id)
        return 'region:{}'.format(region_names.get(int(region_id), 'Other'))

    described = [describe(feature) for feature in ordered_features]

    assert len(described) == len(features_dict), 'Some features are not in the description'

    return described
