import collections
from abc import (
    ABCMeta,
    abstractproperty,
)

import numpy as np

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.train_helpers.custom_model_train_helper import CustomModelTrainHelper


# Parameters describing one custom regression model; consumed through
# ``CustomRegressionModelTrainHelper.model_params``.
CustomRegressionParams = collections.namedtuple(
    'CustomRegressionParams',
    [
        'train_sample_path',
        'resource_type',
        'model_tag',
        'metrics_group_name',
        'model_description_in_sandbox',
        'make_train_sample_query',
    ],
)
# Make every field optional, defaulting to None.
# ``__defaults__`` is the portable spelling: it exists on Python 2.6+ and on
# Python 3. ``func_defaults`` is Python-2-only; on Python 3 assigning it only
# creates an unrelated attribute and leaves all fields required.
CustomRegressionParams.__new__.__defaults__ = (None,) * len(CustomRegressionParams._fields)


class CustomRegressionModelTrainHelper(CustomModelTrainHelper):
    """
    Base helper for training custom CatBoost regression models.

    Subclasses must implement ``model_params``, returning a
    ``CustomRegressionParams`` instance. Other attributes used here
    (``yt``, ``logger``, ``feature_maker``, ``download_sample_with_vectors``,
    ``get_train_test_datasets``, ``log_feature_importances``) are presumably
    supplied by ``CustomModelTrainHelper`` -- not visible from this file.
    """
    __metaclass__ = ABCMeta

    @abstractproperty
    def model_params(self):
        """
        :return: CustomRegressionParams class
        """
        pass

    @property
    def graphite_group_name(self):
        return 'custom_regression'

    def get_metrics_to_send(self, best_validation_scores):
        """
        Build the metric dict reported for a trained model.

        :param best_validation_scores: mapping containing 'R2' and 'RMSE'
            keys, e.g. CatBoost's ``get_best_score()['validation']``.
        :return: dict with 'r2', 'rmse' and 'train_sample_size' (row count of
            ``model_params.train_sample_path`` as reported by ``self.yt``).
        """
        return {
            'r2': best_validation_scores['R2'],
            'rmse': best_validation_scores['RMSE'],
            'train_sample_size': self.yt.row_count(self.model_params.train_sample_path),
        }

    def train_catboost_models(self, train_dataset, test_dataset, iterations=7500, learning_rate=0.01):
        """
        Fit a CatBoost regressor on the train split, validating on the test split.

        :param train_dataset: 3-tuple whose first two items are features and
            target; the third item is ignored here.
        :param test_dataset: same layout as ``train_dataset``. It is used as the
            eval set for overfitting detection and best-model selection.
        :param iterations: maximum number of boosting iterations.
        :param learning_rate: CatBoost learning rate.
        :return: (fitted model, metrics dict from ``get_metrics_to_send``).
        """
        # Imported lazily so the module can be imported without catboost.
        from catboost import CatBoostRegressor
        from catboost import Pool

        def get_catboost_model(x_train, x_test, y_train, y_test, weight_train=None):
            catboost_model = CatBoostRegressor(
                # Honour the caller's settings instead of hard-coded values.
                iterations=iterations,
                learning_rate=learning_rate,
                loss_function='RMSE',
                custom_metric='R2',
                random_seed=42,
                # Stop after 40 iterations without eval-set improvement.
                od_type='Iter',
                od_wait=40,
            )
            catboost_model.fit(
                # Weights go on the Pool itself. CatBoost rejects a non-None
                # ``sample_weight`` argument when the training data is a Pool.
                Pool(
                    data=x_train,
                    label=y_train,
                    weight=weight_train,
                    cat_features=self.feature_maker.cat_features_indexes,
                ),
                eval_set=(x_test, y_test),
                use_best_model=True,
                verbose=50,
            )
            return catboost_model

        # NOTE(review): the third tuple item may be sample weights; it is
        # currently discarded -- confirm against get_train_test_datasets.
        x_train, y_train, _ = train_dataset
        x_test, y_test, _ = test_dataset

        model = get_catboost_model(
            x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test,
        )

        self.logger.info('catboost model feature importances:')
        self.log_feature_importances(model)

        return model, self.get_metrics_to_send(model.get_best_score()['validation'])

    def get_model_and_metrics(self):
        """
        Download the train sample, train the regressor and format its metrics.

        :return: (model, list of dicts shaped like
            ``{'labels': {'metric': name, 'model': metrics_group_name}, 'value': value}``),
            with one entry per metric returned by ``train_catboost_models``.
        """
        features, columns = self.download_sample_with_vectors(
            source_table=self.model_params.train_sample_path,
            additional_columns={
                'target': np.float32,
                'yandexuid': np.uint64,
                'crypta_id': np.uint64,
            },
            max_sample_size=training_config.MAX_SAMPLE_SIZE,
        )

        datasets = self.get_train_test_datasets(
            features=features,
            columns=columns,
            target_columns='target',
            target_values_range=None,
        )

        model, metrics_values = self.train_catboost_models(
            datasets['train'], datasets['test'],
        )

        metrics_to_send = [
            {
                'labels': {
                    'metric': metric_name,
                    'model': self.model_params.metrics_group_name,
                },
                'value': metric_value,
            }
            for metric_name, metric_value in metrics_values.items()
        ]

        return model, metrics_to_send
