# -*- coding: utf-8 -*-
import inspect
import numpy as np
from sklearn.model_selection import GridSearchCV

from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.json.json_utils import json_load_byteified
from datacloud.ml_utils.grid_search_wrapper.models_with_default_params import get_proba_logistic, get_ridge
from datacloud.ml_utils.grid_search_wrapper.cv_with_default_params import get_stratified_k_fold, get_k_fold

# Registry of estimator factories, keyed by the short name accepted in
# ``GridSearchRunner(estimator=...)``.
ESTIMATOR_TYPE2CLASS = dict(
    logistic=get_proba_logistic,
    ridge=get_ridge,
)

# Scoring names accepted by ``GridSearchRunner(scoring=...)``; each name maps
# to the identically named scoring string handed to ``GridSearchCV``.
SCORING_TYPES = {name: name for name in ('accuracy', 'roc_auc', 'f1')}

# Cross-validation splitter factories, keyed by the name accepted in
# ``GridSearchRunner(cv=...)``.
CV_TYPES = dict(
    StratifiedKFold=get_stratified_k_fold,
    KFold=get_k_fold,
)


class GridSearchRunner(object):
    """Configurable wrapper around sklearn's ``GridSearchCV``.

    The estimator, scoring and cross-validation strategy are given as short
    names (or small dicts) that are resolved through the module-level
    ``ESTIMATOR_TYPE2CLASS``, ``SCORING_TYPES`` and ``CV_TYPES`` registries,
    so a runner can be built from plain params or JSON (``from_params`` /
    ``from_json``).

    Invalid configuration is reported with an explicitly raised
    ``AssertionError`` (not an ``assert`` statement, so the checks survive
    ``python -O``); ``from_params`` relies on this to log the problem and
    return ``None``.
    """

    # Key in a dict-style estimator spec that names the estimator factory.
    TYPE_KEY = 'type'
    # CV registry key used when ``cv`` is given as a plain number of folds.
    DEFAULT_CV_TYPE = 'StratifiedKFold'

    def __init__(
            self,
            estimator='logistic',
            param_grid=None,
            scoring='roc_auc',
            cv=5,
            logger=None,
            **kwargs):
        """Resolve and store the grid-search configuration.

        :param estimator: key of ``ESTIMATOR_TYPE2CLASS``, or a dict holding
            that key under ``TYPE_KEY`` plus keyword arguments for the
            estimator factory.
        :param param_grid: ``dict`` or ``list`` passed to ``GridSearchCV``;
            ``None`` means ``{'C': np.logspace(-3, 3, 21)}`` (21 log-spaced
            values from 1e-3 to 1e3).
        :param scoring: key of ``SCORING_TYPES``.
        :param cv: ``int`` number of splits for ``DEFAULT_CV_TYPE``, or a key
            of ``CV_TYPES``; dict specs are not implemented yet.
        :param logger: logger to use; a basic module logger if falsy.
        :param kwargs: extra keyword arguments forwarded to ``GridSearchCV``.
        :raises AssertionError: on an invalid configuration value.
        """
        self.logger = logger or get_basic_logger(__name__)

        if param_grid is None:
            # Built per call: a dict default value would be shared by (and
            # mutable through) every runner instance.
            param_grid = {'C': np.logspace(-3, 3, 21)}

        self.estimator = self.parse_estimator(estimator)
        self.param_grid = self.parse_param_grid(param_grid)
        self.scoring = self.parse_scoring(scoring)
        self.cv = self.parse_cv(cv)
        self.grid_search_kwargs = kwargs

        # Set by ``fit``; the result properties below require it.
        self.grid_search_cv = None

    def __str__(self):
        """Render as ``ClassName(arg=value, ...)`` over ``__init__``'s named args."""
        # getargspec is deprecated in Python 3 and removed in 3.11, while
        # Python 2 only provides getargspec.
        argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        arg_names = [name for name in argspec_fn(self.__init__)[0] if name != 'self']
        params_str = ', '.join(
            '{}={}'.format(name, getattr(self, name)) for name in arg_names
        )
        return '{}({})'.format(self.__class__.__name__, params_str)

    def __repr__(self):
        return self.__str__()

    @classmethod
    def from_params(cls, params):
        """Build a runner from a kwargs dict.

        Returns ``None`` (after logging at CRITICAL level) when the
        configuration is rejected with an ``AssertionError``.
        """
        try:
            return cls(**params)
        except AssertionError as er:
            logger = get_basic_logger(name=__name__)
            logger.critical(er)
            return None

    @classmethod
    def from_json(cls, json_data):
        """Build a runner from a JSON string of kwargs; see ``from_params``."""
        dict_params = json_load_byteified(json_data)
        return cls.from_params(dict_params)

    def parse_estimator(self, estimator):
        """Instantiate the estimator described by a name or a dict spec."""
        if isinstance(estimator, basestring):
            return ESTIMATOR_TYPE2CLASS[estimator]()

        if isinstance(estimator, dict):
            # Copy so the caller's spec is not destroyed (it may be reused).
            estimator_params = dict(estimator)
            # Default of None: a missing key must reach the check below
            # instead of escaping as a KeyError.
            est_type = estimator_params.pop(self.TYPE_KEY, None)
            if est_type is None:
                raise AssertionError('No estimator type given!')
            return ESTIMATOR_TYPE2CLASS[est_type](**estimator_params)

        raise AssertionError('Looking for `str` or `dict` in estimator!')

    def parse_param_grid(self, param_grid):
        """Validate that ``param_grid`` is a ``dict`` or ``list`` and return it."""
        if not isinstance(param_grid, (dict, list)):
            raise AssertionError('Looking for `dict` or `list` in param_grid!')
        return param_grid

    def parse_scoring(self, scoring):
        """Map a scoring name to the value passed to ``GridSearchCV``."""
        if not isinstance(scoring, basestring):
            raise AssertionError('Looking for `str` in scoring!')
        return SCORING_TYPES[scoring]

    def parse_cv(self, cv):
        """Build a CV splitter from a fold count or a registry name."""
        if isinstance(cv, int):
            return CV_TYPES[self.DEFAULT_CV_TYPE](n_splits=cv)

        if isinstance(cv, basestring):
            return CV_TYPES[cv]()

        if isinstance(cv, dict):
            raise NotImplementedError('Dict params for CV are not implemented yet!')

        raise AssertionError('Looking for `int`, `str` or `dict` in cv!')

    def fit(self, X, y):
        """Run the grid search on ``X``, ``y`` and return the fitted ``GridSearchCV``."""
        self.logger.info('Starting Grid Search fit...')

        self.grid_search_cv = GridSearchCV(
            estimator=self.estimator,
            param_grid=self.param_grid,
            scoring=self.scoring,
            cv=self.cv,
            **self.grid_search_kwargs
        )
        self.grid_search_cv.fit(X, y)

        return self.grid_search_cv

    def _fitted_search(self, message='Fit grid search first!'):
        """Return the fitted ``GridSearchCV``; raise ``AssertionError`` if ``fit`` was not run."""
        search = getattr(self, 'grid_search_cv', None)
        if search is None:
            raise AssertionError(message)
        return search

    @property
    def cv_results_(self):
        return self._fitted_search().cv_results_

    @property
    def best_estimator_(self):
        return self._fitted_search().best_estimator_

    @property
    def best_score_(self):
        return self._fitted_search().best_score_

    @property
    def best_params_(self):
        return self._fitted_search('Run grid search first!').best_params_
