from abc import (
    ABCMeta,
    abstractproperty,
)
import collections

from cached_property import cached_property
import numpy as np

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools.metrics import (
    calculate_metrics,
    calculate_roc_auc_scores,
)
from crypta.lib.python.custom_ml.train_helpers.custom_model_train_helper import CustomModelTrainHelper


# Parameters a concrete CustomClassificationModelTrainHelper exposes through its
# ``model_params`` property. How this module reads them:
#   train_sample_path  -- table the training sample is downloaded from; its row
#                         count is also reported as the 'train_sample_size' metric
#   metrics_group_name -- value of the 'model' label on every emitted metric
#   segment_id_to_name -- maps a segment id to the name used in metric labels
#   segment_name_to_id -- only its length (the number of classes) is read here
#   ordered_thresholds -- None, an OrderedDict, or a sequence of per-class thresholds
# resource_type, model_tag, model_description_in_sandbox and
# make_train_sample_query are not read in this module; presumably they are
# consumed by other training/publishing code -- verify before changing them.
CustomClassificationParams = collections.namedtuple(
    'CustomClassificationParams',
    (
        'train_sample_path',
        'resource_type',
        'model_tag',
        'metrics_group_name',
        'model_description_in_sandbox',
        'segment_id_to_name',
        'segment_name_to_id',
        'ordered_thresholds',
        'make_train_sample_query',
    ),
)


class CustomClassificationModelTrainHelper(CustomModelTrainHelper):
    """Abstract helper that trains a CatBoost classifier on a labelled sample.

    Concrete subclasses provide ``model_params`` (a ``CustomClassificationParams``
    instance). ``get_model_and_metrics`` downloads the sample, splits it into
    train/test datasets, trains a ``MultiClass`` CatBoost model and returns the
    model together with a list of metric dicts.
    """
    # Python 2 metaclass syntax; on Python 3 this is an ordinary class attribute
    # and does not by itself make ``model_params`` abstract.
    __metaclass__ = ABCMeta
    # Default binary id -> name mapping. NOTE(review): the methods below read
    # ``self.model_params.segment_id_to_name`` instead; confirm who uses this one.
    segment_id_to_name = collections.OrderedDict([
        (0, 'negative'),
        (1, 'positive'),
    ])

    @abstractproperty
    def model_params(self):
        """Parameters of the concrete model.

        :return: CustomClassificationParams instance
        """
        pass

    @cached_property
    def graphite_group_name(self):
        """Return the fixed ``'custom_classification'`` graphite group name."""
        return 'custom_classification'

    @staticmethod
    def get_segment_name_to_id(ordered_thresholds):
        """Map every item of ``ordered_thresholds`` to its position index.

        Iterating a mapping yields its keys, so an ``OrderedDict`` of
        ``segment_name -> threshold`` produces ``{segment_name: index}`` with
        indices following insertion order.
        """
        return {segment_name: segment_id for segment_id, segment_name in enumerate(ordered_thresholds)}

    def _resolve_ordered_thresholds(self):
        """Return ``model_params.ordered_thresholds`` normalised for ``calculate_metrics``.

        * ``None`` -> ``0.5`` for every class in ``segment_name_to_id``;
        * an ``OrderedDict`` (or subclass) -> a list of its values in insertion order;
        * anything else -> passed through unchanged.
        """
        thresholds = self.model_params.ordered_thresholds
        if thresholds is None:
            return [0.5] * len(self.model_params.segment_name_to_id)
        if isinstance(thresholds, collections.OrderedDict):
            # list(): on Python 3 ``values()`` is a view, which is not indexable
            # and would differ in type from the list returned above.
            return list(thresholds.values())
        return thresholds

    def train_catboost_models(self, train_dataset, test_dataset, iterations=7500, learning_rate=0.01):
        """Train a CatBoost multiclass model and compute its quality metrics.

        :param train_dataset: ``(x, y, weights)`` triple; ``y`` has one column per
            class and its per-row argmax is used as the training label.
            NOTE(review): the weights element is currently ignored.
        :param test_dataset: same layout as ``train_dataset``; used as the CatBoost
            eval set (overfitting detector / best-model selection) and for test metrics.
        :param iterations: maximum number of boosting iterations.
        :param learning_rate: CatBoost learning rate.
        :return: ``(model, metrics)`` -- the fitted ``CatBoostClassifier`` and the
            metrics from ``calculate_metrics`` concatenated with
            ``calculate_roc_auc_scores``.
        """
        # Imported lazily so the module can be imported without catboost installed.
        from catboost import CatBoostClassifier

        def get_catboost_model(x_train, x_test, y_train, y_test, weight_train=None):
            """Fit a classifier and return it with class probabilities for train and test."""
            catboost_model = CatBoostClassifier(
                iterations=iterations,
                learning_rate=learning_rate,
                loss_function='MultiClass',
                random_seed=42,
                # Stop after 40 iterations without improvement on the eval set.
                od_type='Iter',
                od_wait=40,
            )
            catboost_model.fit(
                x_train, y_train,
                cat_features=self.feature_maker.cat_features_indexes,
                eval_set=(x_test, y_test),
                sample_weight=weight_train,
                use_best_model=True,
                verbose=50,
            )
            y_predicted_train = catboost_model.predict_proba(x_train)
            y_predicted_test = catboost_model.predict_proba(x_test)
            return catboost_model, y_predicted_train, y_predicted_test

        x_train, y_train, _ = train_dataset
        x_test, y_test, _ = test_dataset

        # CatBoost expects class indices, so collapse the per-class columns with argmax;
        # the original 2-D targets are still used for the metrics below.
        model, y_pred_train, y_pred_test = get_catboost_model(
            x_train=x_train, x_test=x_test, y_train=np.argmax(y_train, axis=1), y_test=np.argmax(y_test, axis=1),
        )

        self.logger.info('catboost model feature importances:')
        self.log_feature_importances(model)

        ordered_thresholds = self._resolve_ordered_thresholds()

        catboost_ml_metrics = calculate_metrics(
            target_type=self.model_params.metrics_group_name,
            y_predicted_train=y_pred_train,
            y_train=y_train,
            y_predicted_test=y_pred_test,
            y_test=y_test,
            ordered_thresholds=ordered_thresholds,
            logger=self.logger,
        )
        catboost_ml_metrics += calculate_roc_auc_scores(
            y_test,
            y_pred_test,
            self.model_params.metrics_group_name,
            self.model_params.segment_id_to_name,
        )

        return model, catboost_ml_metrics

    def get_model_and_metrics(self):
        """Download the training sample, train the model and collect metrics.

        :return: ``(model, metrics_to_send)``; ``metrics_to_send`` is a list of
            ``{'labels': {...}, 'value': ...}`` dicts containing the class share
            of every segment in the downloaded sample, the metrics returned by
            ``train_catboost_models`` and the row count of ``train_sample_path``.
        """
        features, columns = self.download_sample_with_vectors(
            source_table=self.model_params.train_sample_path,
            additional_columns={
                'segment_name': np.float32,
                'yandexuid': np.uint64,
                'crypta_id': np.uint64,
            },
            max_sample_size=training_config.MAX_SAMPLE_SIZE,
        )

        # Segment ids arrive as float32; looking them up in a dict keyed by ints
        # still works because equal numbers hash and compare equal.
        segment_names = columns['segment_name']
        sample_size = len(segment_names)
        target_counter = collections.Counter(segment_names)
        metrics_to_send = [
            {
                'labels': {
                    'metric': 'train_distribution',
                    'model': self.model_params.metrics_group_name,
                    'segment': self.model_params.segment_id_to_name[segment],
                },
                'value': float(count) / sample_size,
            }
            for segment, count in target_counter.items()
        ]

        datasets = self.get_train_test_datasets(
            features=features,
            columns=columns,
            target_columns='segment_name',
            target_values_range=range(len(self.model_params.segment_name_to_id)),
        )
        train_dataset, test_dataset = datasets['train'], datasets['test']

        model, metrics = self.train_catboost_models(
            train_dataset, test_dataset,
        )
        metrics_to_send += metrics
        # NOTE(review): this is the row count of the whole source table, which may
        # exceed the number of rows actually used if the download was capped by
        # MAX_SAMPLE_SIZE -- confirm this is the intended meaning.
        metrics_to_send.append({
            'labels': {
                'metric': 'train_sample_size',
                'model': self.model_params.metrics_group_name,
            },
            'value': self.yt.row_count(self.model_params.train_sample_path),
        })

        return model, metrics_to_send
