import collections

import numpy as np
import pandas as pd
from scipy.special import logit
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import MultiLabelBinarizer

from datacloud.dev_utils.data.data_utils import array_fromstring


class BaseTransformer(object):
    """Fit/transform template shared by the transformers in this module.

    Subclasses implement ``_fit`` and ``_transform``. The public ``fit`` and
    ``transform`` wrappers add the "has this been fitted?" bookkeeping.
    """

    def _fit(self, X, y):
        """Learn state from ``X`` (and ``y``); must be overridden."""
        raise NotImplementedError()

    def _transform(self, X):
        """Transform ``X`` with the fitted state; must be overridden."""
        raise NotImplementedError()

    @property
    def is_fitted(self):
        """``True`` once ``fit`` has completed, ``False`` otherwise."""
        try:
            return self._is_fitted
        except AttributeError:
            # ``fit`` has never completed on this instance.
            return False

    def fit(self, X, y=None):
        """Run the subclass ``_fit`` and mark the instance as fitted.

        :param X: training data handed to ``_fit``
        :param y: optional target handed to ``_fit``
        :return: ``self``, so calls can be chained
        """
        self._fit(X=X, y=y)
        # Only reached when ``_fit`` did not raise.
        self._is_fitted = True
        return self

    def transform(self, X):
        """Run the subclass ``_transform``; requires a prior ``fit``.

        :param X: data handed to ``_transform``
        :return: whatever ``_transform`` returns
        :raises AssertionError: if ``fit`` has not completed yet
        """
        assert self.is_fitted, 'Should fit first'
        return self._transform(X)


class DecodeAndFillNanToMeanAndHit(BaseTransformer):
    """Decode a column of byte strings into numeric columns with a hit flag.

    The non-missing values of ``column`` are concatenated, decoded with
    ``array_fromstring`` and reshaped to one row per value. Rows whose source
    value is missing are filled with the per-column means learned in ``fit``
    and get ``hit_flg == False``.
    """

    def __init__(self, column='features'):
        """:param column: name of the column holding the encoded byte strings"""
        self.column = column

    def _decode(self, X):
        """Return the decoded non-missing rows of ``self.column`` as a DataFrame.

        The result is indexed like the non-missing rows and has integer
        column labels.
        """
        encoded = X[self.column].dropna()
        # NOTE(review): array_fromstring presumably returns a flat numpy array
        # (it is reshaped right below) - confirm against its definition.
        flat = array_fromstring(b''.join(encoded.to_list()))
        # One row per encoded value. The width is inferred, so every value
        # must decode to the same number of elements.
        return pd.DataFrame(flat.reshape(encoded.shape[0], -1), index=encoded.index)

    def _fit(self, X, y):
        """Store the per-column means of the decoded non-missing rows."""
        self.means = self._decode(X).mean()

    def _transform(self, X):
        """Decode ``X[self.column]``, add ``hit_flg`` and fill gaps with the means.

        :return: DataFrame indexed like ``X`` with the decoded columns followed
            by a boolean ``hit_flg`` column
        """
        if X[self.column].notna().any():
            decoded = self._decode(X)
        else:
            # Nothing to decode, so ``reshape(0, -1)`` could not infer a width.
            # Take the column layout from the fitted means instead.
            decoded = pd.DataFrame(index=X.index[:0], columns=self.means.index, dtype=float)
        return (
            # ``X[[]]`` keeps only the index of X; the join restores rows whose
            # source value was missing as all-NaN rows.
            X[[]]
            .join(decoded)
            # Computed before ``fillna`` so the filled rows are still flagged.
            .assign(hit_flg=lambda frame: frame.notna().any(axis=1))
            .fillna(self.means)
        )


class OneHotTransformer(BaseTransformer):
    """One-hot encode the given columns using the values seen during ``fit``.

    Each column becomes ``len(unique values)`` boolean columns, ordered by the
    sorted fitted values. A value that is not found in the fitted lookup
    encodes as an all-False row.
    """

    def __init__(self, columns):
        """:param columns: iterable of column names to encode"""
        self.columns = columns

    def _fit(self, X, y):
        """Build, per column, a value -> boolean indicator vector lookup."""
        self.decoders = {}
        self.dimensions = {}
        for column in self.columns:
            # NOTE(review): ``sorted`` needs mutually comparable values, so a
            # column mixing e.g. strings and NaN will fail here - confirm the
            # inputs are clean.
            values = sorted(X[column].unique())
            width = len(values)
            indicator_rows = np.eye(width, dtype=bool)
            self.dimensions[column] = width
            self.decoders[column] = {
                value: indicator_rows[position]
                for position, value in enumerate(values)
            }

    def _transform(self, X):
        """Return the concatenated one-hot blocks, indexed like ``X``.

        Every block carries its own 0..width-1 integer column labels, so the
        labels repeat across blocks in the concatenated result.
        """
        one_hot_features = []
        for column in self.columns:
            decoder = self.decoders[column]
            width = self.dimensions[column]
            # Shared fallback for values missing from the lookup, built once
            # per column rather than once per row.
            unseen = np.zeros(width, dtype=bool)
            rows = [decoder.get(value, unseen) for value in X[column]]
            # The explicit reshape keeps the fitted width for zero-row input,
            # where ``np.array([])`` alone would drop the second axis.
            encoded = np.array(rows, dtype=bool).reshape(len(rows), width)
            one_hot_features.append(pd.DataFrame(encoded, index=X.index))
        return pd.concat(one_hot_features, axis=1)


class SelectColumnAndLogitAndFillNanToMeanAndHit(BaseTransformer):
    """Apply ``logit`` to the selected column(s), fill NaN with the fitted mean.

    The output also carries ``hit_flg``, which is True for rows where at least
    one of the selected raw values is present.
    """

    def __init__(self, column):
        """:param column: a column name or a list of column names"""
        self.column = column

    def _fit(self, X, y):
        """Store the mean of the logit-transformed column(s)."""
        as_float = X[self.column].astype(float)
        self.mean = as_float.apply(logit).mean()

    def _transform(self, X):
        """Return the logit features plus ``hit_flg``, indexed like ``X``."""
        if isinstance(self.column, list):
            columns = self.column
        else:
            columns = [self.column]
        selected = X[columns]
        # True only where every selected raw value is missing.
        all_missing = selected.isna().min(axis=1)
        transformed = selected.astype(float).apply(logit)
        # The flag is taken from the raw values, so a NaN that logit itself
        # produces does not clear it.
        return transformed.fillna(self.mean).assign(hit_flg=~all_missing)


class YuidDaysTransformer(BaseTransformer):
    """Scale ``days_from_first_yuid`` into [0, 1] and flag missing values.

    Values are clipped to ``[MIN_DAYS, MAX_DAYS]`` and divided by ``MAX_DAYS``.
    Missing values are replaced with the mean of the scaled training values.
    """

    MAX_DAYS = 540
    MIN_DAYS = 0

    def _fit(self, X, y):
        """Store the mean of the clipped and scaled non-missing values."""
        observed = X['days_from_first_yuid'].dropna()
        clipped = observed.clip(lower=self.MIN_DAYS, upper=self.MAX_DAYS)
        self.mean = clipped.div(self.MAX_DAYS).mean()

    def _transform(self, X):
        """Return the scaled column plus a boolean ``hit_flg`` column."""
        # Taken from the raw column, before missing values are filled.
        is_present = X['days_from_first_yuid'].notna()
        scaled = (
            X[['days_from_first_yuid']]
            .clip(lower=self.MIN_DAYS, upper=self.MAX_DAYS)
            .div(self.MAX_DAYS)
        )
        return scaled.fillna(self.mean).assign(hit_flg=is_present)


class PhoneWatchLogTransformer(BaseTransformer):
    """Clip the ``avito_category_NN`` columns to [0, 1] and fill gaps with means.

    ``hit_flg`` is derived from ``avito_category_01`` alone.
    """

    def _fit(self, X, y):
        """Store the per-column means of the clipped category columns."""
        categories = X.filter(regex=r'^avito_category_\d{2}$')
        self.means = categories.clip(lower=0, upper=1).mean()

    def _transform(self, X):
        """Return the clipped, mean-filled category columns plus ``hit_flg``."""
        clipped = X.filter(regex=r'^avito_category_\d{2}$').clip(lower=0, upper=1)
        # NOTE(review): only the first category column decides the flag;
        # presumably all category columns are missing together - confirm.
        has_data = X['avito_category_01'].notna()
        return clipped.fillna(self.means).assign(hit_flg=has_data)


class SelectColumnsAndFillNanToMeanAndHit(BaseTransformer):
    """Select columns, fill NaN with the fitted means and add ``hit_flg``.

    ``columns`` may be a list of column names, or a single name whose values
    are expanded into several columns with ``pd.Series``.
    """

    def __init__(self, columns):
        """:param columns: list of column names, or one column name to expand"""
        self.columns = columns

    def _select(self, X):
        """Return the feature frame that both ``_fit`` and ``_transform`` use."""
        picked = X[self.columns]
        if isinstance(self.columns, list):
            return picked
        # A single column: expand each value into its own row of columns.
        return picked.apply(pd.Series)

    def _fit(self, X, y):
        """Store the per-column means of the selected features."""
        self.means = self._select(X).mean()

    def _transform(self, X):
        """Return the mean-filled features plus a boolean ``hit_flg`` column."""
        features = self._select(X)
        # True where at least one feature is present; computed before filling.
        hit = (~features.isna()).max(axis=1)
        return features.assign(hit_flg=hit).fillna(self.means)


class SelectColumnsAndFillNan(BaseTransformer):
    """Select the keys of ``fill_values`` as columns and fill NaN per column."""

    def __init__(self, fill_values):
        """:param fill_values: mapping of column name -> value used for NaN"""
        # The output column order is the sorted key order.
        self.columns = sorted(fill_values.keys())
        self.fill_values = fill_values

    def _fit(self, X, y):
        """Nothing is learned from the data."""
        return None

    def _transform(self, X):
        """Return the selected columns with NaN replaced by ``fill_values``."""
        selected = X[self.columns]
        return selected.fillna(self.fill_values)


class TopMultiLabelBinarizer(BaseTransformer):
    """Multi-label binarizer restricted to the most frequent labels.

    ``fit`` counts the labels in ``col``, keeps the ``n_to_save`` most common
    ones in ``top_set`` and fits a ``MultiLabelBinarizer`` on the label lists
    reduced to that set.
    """

    def __init__(self, col, n_to_save):
        """
        :param col: name of the column holding the label lists
        :param n_to_save: number of most frequent labels to keep
        """
        self.col = col
        self.n_to_save = n_to_save
        self.multi_label_binarizer = MultiLabelBinarizer()

    def _select_top_objects(self, X):
        """Reduce each row to its labels that are in ``top_set``.

        Values that are not ``list`` instances become an empty list.
        """
        return X[self.col].map(lambda x: list(set(x) & self.top_set) if isinstance(x, list) else [])

    def _fit(self, X, y):
        """Collect the top labels and fit the underlying binarizer."""
        label_counts = collections.Counter()
        labels_column = X[self.col]
        # A plain loop: the counter update is a side effect, so there is no
        # result worth collecting into a Series.
        # NOTE(review): every non-missing value is counted, while
        # ``_select_top_objects`` only recognises ``list`` instances - confirm
        # the column holds lists only.
        for labels in labels_column[labels_column.notna()]:
            label_counts.update(labels)
        self.top_set = {label for label, _ in label_counts.most_common(self.n_to_save)}

        self.multi_label_binarizer.fit(self._select_top_objects(X))

    def _transform(self, X):
        """Return the binarizer output for the rows of ``X`` reduced to top labels."""
        return self.multi_label_binarizer.transform(self._select_top_objects(X))


class TopKMeansOnMLB(BaseTransformer):
    """Run ``MiniBatchKMeans`` on top of a ``TopMultiLabelBinarizer`` encoding."""

    def __init__(self, col, n_to_save, n_clusters, random_state=42):
        """
        :param col: column handed to the ``TopMultiLabelBinarizer``
        :param n_to_save: number of top labels the binarizer keeps
        :param n_clusters: number of k-means clusters
        :param random_state: seed passed to ``MiniBatchKMeans``
        """
        self.col = col
        self.n_to_save = n_to_save
        self.n_clusters = n_clusters
        self.random_state = random_state
        self.multi_label_binarizer = TopMultiLabelBinarizer(col, n_to_save)
        self.kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=random_state)

    def _encode(self, X):
        """Binarize ``X`` with the fitted multi-label binarizer."""
        return self.multi_label_binarizer.transform(X)

    def _fit(self, X, y):
        """Fit the binarizer first, then k-means on its output."""
        self.multi_label_binarizer.fit(X, y)
        self.kmeans.fit(self._encode(X))

    def _transform(self, X):
        """Return ``kmeans.transform`` of the binarized ``X``."""
        return self.kmeans.transform(self._encode(X))


def _zero():
    return 0


class LocationsTransformer(BaseTransformer):
    """Build location features from the country/region/city columns.

    ``transform`` maps the ``mode_*`` id columns to the ids kept during ``fit``
    (anything else becomes 0) and adds a ``was_abroad`` flag. It also appends
    the k-means and multi-label-binarizer encodings of ``country_ids`` and
    ``region_ids``, and drops those two list columns.
    """

    def __init__(self,
                 country_kmeans_top=60,
                 country_kmeans_clusters=20,
                 country_mlb_top=30,

                 region_kmeans_top=250,
                 region_kmeans_clusters=10,
                 region_mlb_top=50,

                 city_min_count=1000,
                 native_country_code=225):
        """
        :param country_kmeans_top: top labels kept by the country k-means encoder
        :param country_kmeans_clusters: clusters of the country k-means encoder
        :param country_mlb_top: top labels kept by the country binarizer
        :param region_kmeans_top: top labels kept by the region k-means encoder
        :param region_kmeans_clusters: clusters of the region k-means encoder
        :param region_mlb_top: top labels kept by the region binarizer
        :param city_min_count: a city id / city type is kept only when it
            occurs in more than this many training rows
        :param native_country_code: country id treated as "home" by ``was_abroad``
        """
        self.country_kmeans = TopKMeansOnMLB('country_ids', country_kmeans_top, country_kmeans_clusters)
        self.country_mlb = TopMultiLabelBinarizer('country_ids', country_mlb_top)

        self.region_kmeans = TopKMeansOnMLB('region_ids', region_kmeans_top, region_kmeans_clusters)
        self.region_mlb = TopMultiLabelBinarizer('region_ids', region_mlb_top)

        self.city_min_count = city_min_count
        self.native_country_code = native_country_code

    @staticmethod
    def get_transform_dict_from_list(ls, default=None):
        """Return a ``defaultdict`` mapping every item of ``ls`` to itself.

        :param ls: iterable of hashable items
        :param default: default factory for missing keys (``None`` means none)
        """
        result = collections.defaultdict(default)
        for item in ls:
            result[item] = item
        return result

    @staticmethod
    def get_transform_dict(df, key, border, default=None):
        """Return an identity mapping for the ``key`` values that are frequent.

        A value is kept when it occurs in more than ``border`` rows of ``df``.
        """
        data = df.groupby(key).size().sort_values()
        data = data[data > border]
        return LocationsTransformer.get_transform_dict_from_list(data.index, default)

    @staticmethod
    def add_cols(df, data, name_pattern):
        """Add the columns of ``data`` to ``df`` in place.

        :param df: DataFrame that is modified in place
        :param data: 2-D array with one row per row of ``df``
        :param name_pattern: format string receiving the column position
        :raises AssertionError: if a generated column name already exists
        """
        cols = [name_pattern.format(i) for i in range(data.shape[1])]
        for col in cols:
            assert col not in df.columns, 'Column exists'
            # Create the column first so the block assignment below has a target.
            df[col] = 0.0
        df.loc[:, cols] = data

    @staticmethod
    def _map_ids(series, mapping):
        """Map ``series`` through ``mapping`` and cast the result to int64.

        ``Series.map`` indexes a dict that defines ``__missing__`` directly,
        and a ``defaultdict`` stores every key it is asked for. Mapping through
        a throwaway copy gives the same result without growing the fitted
        lookup on every call.
        """
        return series.map(mapping.copy()).astype(np.int64)

    def was_abroad(self, countries_ids):
        """Check whether the user has been abroad.

        A list containing any id other than ``native_country_code`` counts as
        abroad, even when ``native_country_code`` itself is not in the list.

        :param countries_ids: list of ids
        :return: True if user has been abroad, else False (also False for
            anything that is not a list)
        """
        if isinstance(countries_ids, list):
            return len(set(countries_ids + [self.native_country_code])) > 1
        return False

    def _fit(self, X, y):
        """Fit the four encoders and build the id lookups (unknown ids -> 0)."""
        self.country_kmeans.fit(X, y)
        self.country_mlb.fit(X, y)
        self.region_kmeans.fit(X, y)
        self.region_mlb.fit(X, y)

        # Countries and regions reuse the top sets of the k-means encoders.
        self.country_transformer = self.get_transform_dict_from_list(self.country_kmeans.multi_label_binarizer.top_set,
                                                                     _zero)
        self.region_transformer = self.get_transform_dict_from_list(self.region_kmeans.multi_label_binarizer.top_set,
                                                                    _zero)
        self.city_transformer = self.get_transform_dict(X, 'mode_city_id', self.city_min_count, _zero)
        self.city_type_transformer = self.get_transform_dict(X, 'mode_city_type', self.city_min_count, _zero)

    def _transform(self, X):
        """Return a copy of ``X`` with the location features added.

        The ``country_ids`` and ``region_ids`` columns are dropped and any
        remaining NaN is replaced with 0.
        """
        X = X.copy()
        X['mode_country_id'] = self._map_ids(X['mode_country_id'], self.country_transformer)
        X['mode_region_id'] = self._map_ids(X['mode_region_id'], self.region_transformer)
        X['mode_city_id'] = self._map_ids(X['mode_city_id'], self.city_transformer)
        X['mode_city_type'] = self._map_ids(X['mode_city_type'], self.city_type_transformer)

        X['was_abroad'] = X['country_ids'].map(self.was_abroad).astype(np.uint8)

        self.add_cols(X, self.country_kmeans.transform(X), 'country_ids_kmeans_{}')
        self.add_cols(X, self.country_mlb.transform(X), 'country_ids_mlb_{}')
        self.add_cols(X, self.region_kmeans.transform(X), 'region_ids_kmeans_{}')
        self.add_cols(X, self.region_mlb.transform(X), 'region_ids_mlb_{}')

        X.drop(columns=['country_ids', 'region_ids'], inplace=True)
        return X.fillna(0)


# Lookup table from class name to transformer class.
TRANSFORMERS = {
    transformer_cls.__name__: transformer_cls
    for transformer_cls in (
        DecodeAndFillNanToMeanAndHit,
        OneHotTransformer,
        SelectColumnAndLogitAndFillNanToMeanAndHit,
        YuidDaysTransformer,
        PhoneWatchLogTransformer,
        SelectColumnsAndFillNanToMeanAndHit,
        SelectColumnsAndFillNan,
        TopMultiLabelBinarizer,
        TopKMeansOnMLB,
        LocationsTransformer,
    )
}
