from abc import (
    ABCMeta,
    abstractmethod,
    abstractproperty,
)

import pandas as pd

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools.features import MakeCatboostFeatures
from crypta.lib.python.custom_ml.train_helpers.base_train_helper import BaseModelTrainHelper
from crypta.profile.lib.bb_helpers import keyword_name_to_bb_keyword_id


class CustomModelTrainHelper(BaseModelTrainHelper):
    __metaclass__ = ABCMeta

    def __init__(self, yt, logger, date, vector_size=training_config.VECTOR_SIZE):
        """
        Initialise the base helper, remember the date and start the feature maker.

        :param yt: YT client, forwarded to the base helper
        :param logger: logger, forwarded to the base helper
        :param date: date stored on the instance as ``self.date``
        :param vector_size: vector size, forwarded to the base helper
        """
        super(CustomModelTrainHelper, self).__init__(yt=yt, logger=logger, vector_size=vector_size)
        self.date = date

        # The feature maker is configured from attributes read off the instance
        # after the base initialiser has run.
        feature_maker_kwargs = dict(
            yt=self.yt,
            segment_feature_types=training_config.segment_feature_types,
            vector_size=self.vector_size,
        )
        self.feature_maker = MakeCatboostFeatures(**feature_maker_kwargs)
        self.feature_maker.start()

    @abstractproperty
    def graphite_group_name(self):
        """
        :return: 'custom_classification' or 'custom_regression'
        """
        pass

    @abstractproperty
    def model_params(self):
        """
        :return: CustomRegressionParams or CustomClassificationParams class
        """
        pass

    @abstractmethod
    def train_catboost_models(self, *args, **kwargs):
        """
        :return: catboost_model and catboost_metrics
        """
        pass

    @abstractmethod
    def get_model_and_metrics(self):
        """
        Builds train sample, train model and calculate metrics
        :return: model and metrics
        """
        pass

    def download_sample_with_vectors(self, source_table, additional_columns, max_sample_size=None):
        """
        Download sample_with_vectors and convert every row into a feature row.

        :param source_table: path of the table to read
        :param additional_columns: dict with name and data type of the columns
            to read in addition to ``training_config.user_data_features``
        :param max_sample_size: optional upper bound on the number of rows to
            read; ``None`` means no bound
        :return: tuple ``(features, additional_columns)`` where ``features`` is
            a list with one feature row per record and ``additional_columns``
            is the container produced by ``self.init_additional_columns`` and
            filled from the downloaded rows
        """
        columns_to_read = training_config.user_data_features + list(additional_columns.keys())
        source_table = self.yt.TablePath(
            source_table,
            end_index=max_sample_size,
            columns=columns_to_read,
        )

        n_records = self.yt.row_count(source_table)
        if max_sample_size is not None:
            n_records = min(n_records, max_sample_size)

        n_features = self.feature_maker.n_features
        features = [[0] * n_features for _ in range(n_records)]

        self.logger.info('Columns to read: {}'.format(columns_to_read))
        self.logger.info('Downloading {} records from {}'.format(n_records, source_table))
        self.logger.info('number of features: {}'.format(n_features))

        additional_columns = self.init_additional_columns(additional_columns=additional_columns, n_records=n_records)

        # Progress is reported roughly every 5% of the sample. For small samples
        # the rounded step is 0, which would make the modulo below raise
        # ZeroDivisionError, so the step is clamped to at least one row.
        log_step = max(1, int(round(float(n_records) * 0.05)))

        # The name -> id mapping does not change between rows, so fetch it once,
        # and only when the column that needs it was actually requested.
        if 'segment_name' in additional_columns:
            segment_name_to_id = self.model_params.segment_name_to_id
        else:
            segment_name_to_id = None

        for i, row in enumerate(self.yt.read_table(source_table)):
            if not (i % log_step):
                self.logger.info('%.1f%%' % (100 * (i + 1) / float(n_records)))

            features[i][:] = self.feature_maker.get_feature_row(row)

            for column in additional_columns:
                if column == 'segment_name':
                    # Segment names are stored as their ids.
                    additional_columns[column][i] = segment_name_to_id[row[column]]
                else:
                    additional_columns[column][i] = row[column]

        self.logger.info('Downloading finished!')

        return features, additional_columns

    def get_feature_importances_df(self, model):
        """
        Build a table of feature importances annotated with lab segment names.

        :param model: model exposing ``feature_importances_``
        :return: pandas.DataFrame with columns ``segment_name``, ``keyword``,
            ``keyword_id``, ``segment_id`` and ``feature_importance``, sorted by
            ``feature_importance`` in descending order with a fresh index
        """
        feature_importances = model.feature_importances_

        features_to_dataframe = []
        offset = 0
        for feature_group in self.feature_maker.features_order:
            if feature_group == 'vector':
                # Vector components map to themselves: component i sits at position i.
                features = {index: index for index in range(self.vector_size)}
            else:
                features = self.feature_maker.cat_features_dicts[feature_group]

            # The keyword id depends only on the group, so resolve it once per
            # group instead of once per segment.
            if feature_group in keyword_name_to_bb_keyword_id:
                keyword_id = str(keyword_name_to_bb_keyword_id[feature_group])
            else:
                keyword_id = None

            for segment_id in features:
                features_to_dataframe.append(
                    {
                        'keyword': feature_group,
                        'keyword_id': keyword_id,
                        'segment_id': str(segment_id),
                        # features[segment_id] is the position inside the group,
                        # offset is where the group starts in the importances.
                        'feature_importance': feature_importances[offset + features[segment_id]],
                    }
                )

            offset += len(features)

        lab_segments_df = self.get_lab_segments_df()

        # Explicit columns keep the merge keys present even when no features
        # were collected, so an empty feature set gives an empty result.
        features_df = pd.DataFrame(
            features_to_dataframe,
            columns=['keyword', 'keyword_id', 'segment_id', 'feature_importance'],
        )

        features_df_named = pd.merge(
            features_df,
            lab_segments_df,
            how='left',
            left_on=['keyword_id', 'segment_id'],
            right_on=['exportKeywordId', 'exportSegmentId'],
        )

        feature_importances_df = features_df_named[[
            'name', 'keyword', 'keyword_id', 'segment_id', 'feature_importance']
        ].rename(
            columns={'name': 'segment_name'},
        ).sort_values(
            by='feature_importance',
            ascending=False,
        ).reset_index(
            drop=True,
        )

        return feature_importances_df

    @staticmethod
    def get_metrics_group_name(model_name):
        return '{}_segments'.format(model_name)

    def log_feature_importances(self, model):
        feature_importances_df = self.get_feature_importances_df(model)
        self.logger.info(feature_importances_df.to_string())

    def get_top_features(self, model, number_of_features=20):
        feature_importances_df = self.get_feature_importances_df(model)
        top_features = []

        for _, row in feature_importances_df.iterrows():
            if not pd.isnull(row['segment_name']):
                top_features.append(row['segment_name'])
            else:
                if row['keyword'] != 'vector':
                    top_features.append(row['keyword'])

        self.logger.info('\n'.join(top_features[:number_of_features]))
        return top_features[:number_of_features]
