import numpy as np
import pandas as pd
from scipy.special import logit
from datacloud.dev_utils.data.data_utils import array_fromstring


class BasicTransformer(object):
    """Base class providing a minimal ``fit`` / ``transform`` interface.

    Subclasses override ``_transform`` and, when they need to learn state
    from the data, ``_fit``.  The public wrappers track whether the
    transformer has been fitted via the ``is_fitted`` attribute.
    """

    def _fit(self, X, y):
        """Learn state from ``X`` (and optionally ``y``); a no-op by default."""
        pass

    def _transform(self, X):
        """Return the transformed ``X``; must be overridden by subclasses."""
        raise NotImplementedError()

    def fit(self, X, y=None):
        """Run ``_fit``, mark the transformer as fitted and return ``self``."""
        self._fit(X, y=y)
        self.is_fitted = True
        return self

    def transform(self, X):
        """Return ``_transform(X)``.

        Raises:
            AssertionError: if ``is_fitted`` is missing or falsy.
        """
        # Raised explicitly instead of using `assert`, so the guard is not
        # stripped under `python -O`.  The exception type and message are
        # the same as the assert statement produced.
        if not getattr(self, 'is_fitted', False):
            raise AssertionError('Should fit first')
        # NOTE: checks that the result keeps X's index and contains no NaNs
        # are currently disabled; the result is returned unchecked.
        return self._transform(X)


class DecodeAndFillNanToMeanAndHit(BasicTransformer):
    """Decode a byte-string feature column into one numeric column per position.

    Rows whose entry is missing get the per-position means learned in
    ``fit``; a boolean ``hit_flg`` column is False for rows where every
    decoded position is missing.
    """

    def __init__(self, column='features'):
        self.column = column

    def _decode(self, X):
        """Return the non-null entries of ``self.column`` decoded into a frame.

        All byte strings are concatenated, parsed with a single
        ``array_fromstring`` call and reshaped to one row per entry, so every
        entry has to decode to the same number of values.  The result keeps
        the index labels of the non-null rows and has default integer column
        labels.
        """
        encoded = X[self.column].dropna()
        values = array_fromstring(b''.join(encoded.to_list()))
        return pd.DataFrame(
            values.reshape(encoded.shape[0], -1),
            index=encoded.index,
        )

    def _fit(self, X, y):
        """Store the mean of every decoded position over the non-null rows."""
        self.means = self._decode(X).mean()

    def _transform(self, X):
        """Return decoded features aligned to ``X.index`` plus ``hit_flg``."""
        if X[self.column].isna().all():
            # Nothing to decode: reshape(0, -1) cannot infer the row width of
            # an empty array and would raise, so build an empty frame that
            # carries the fitted column labels instead.  The join below then
            # yields all-NaN rows that are imputed with the means.
            decoded = pd.DataFrame(
                index=X.index[:0],
                columns=self.means.index,
                dtype=self.means.dtype,
            )
        else:
            decoded = self._decode(X)
        # Left-join onto a column-less copy of X so that rows without a
        # vector are present as all-NaN rows.
        aligned = X[[]].join(decoded)
        hit = ~aligned.isna().all(axis=1)
        return aligned.assign(hit_flg=hit).fillna(self.means)


class OneHotTransformer(BasicTransformer):
    """One-hot encode the given columns using the values seen during ``fit``.

    A value that was not seen during ``fit`` is encoded as an all-False row.
    Each source column contributes a block of boolean columns with default
    integer labels, so labels repeat across blocks in the concatenated
    result.
    """

    def __init__(self, columns):
        self.columns = columns

    def _fit(self, X, y):
        """Build a value -> one-hot vector mapping for every column."""
        self.decoders = {}
        self.dimensions = {}
        for column in self.columns:
            values = sorted(X[column].unique())
            width = len(values)
            # Row i of the boolean identity matrix is the one-hot vector of
            # the i-th value in sorted order.
            identity = np.eye(width, dtype=bool)
            self.dimensions[column] = width
            self.decoders[column] = {
                value: identity[position]
                for position, value in enumerate(values)
            }
        return self

    def _transform(self, X):
        """Return the concatenated one-hot blocks, indexed like ``X``."""
        blocks = []
        for column in self.columns:
            decoder = self.decoders[column]
            width = self.dimensions[column]
            # Shared fallback for unseen values; built once per column
            # rather than once per row.
            unseen = np.zeros(width, dtype=bool)
            rows = [decoder.get(value, unseen) for value in X[column]]
            # The explicit reshape keeps the block `width` columns wide even
            # when X has no rows (a bare np.array([]) would be 1-D).
            encoded = np.array(rows, dtype=bool).reshape(len(rows), width)
            blocks.append(pd.DataFrame(encoded, index=X.index))
        return pd.concat(blocks, axis=1)


class SelectColumnAndLogitAndFillNanToMeanAndHit(BasicTransformer):
    """Apply ``logit`` to one column, impute NaNs with the fitted mean logit.

    The output also carries ``hit_flg``, which is True where the source
    column was not null.
    """

    def __init__(self, column):
        self.column = column

    def _fit(self, X, y):
        """Store the mean logit of the non-null values of the column."""
        observed = X[self.column].astype(float).dropna()
        self.mean = observed.apply(logit).mean()

    def _transform(self, X):
        """Return a frame with the logit-transformed column and ``hit_flg``."""
        present = ~X[self.column].isna()
        logits = X[[self.column]].astype(float).apply(logit)
        imputed = logits.fillna(self.mean)
        return imputed.assign(hit_flg=present)


class YuidDaysTransformer(BasicTransformer):
    """Scale ``days_from_first_yuid`` by ``MAX_DAYS`` after clipping.

    Values are clipped to ``[MIN_DAYS, MAX_DAYS]`` and divided by
    ``MAX_DAYS``; NaNs are replaced with the fitted mean of the scaled
    values, and ``hit_flg`` is True where the source value was not null.
    """

    MAX_DAYS = 540
    MIN_DAYS = 0

    def _fit(self, X, y):
        """Store the mean of the clipped, scaled non-null values."""
        observed = X['days_from_first_yuid'].dropna()
        clipped = observed.clip(lower=self.MIN_DAYS, upper=self.MAX_DAYS)
        self.mean = clipped.div(self.MAX_DAYS).mean()

    def _transform(self, X):
        """Return the scaled column with NaNs imputed, plus ``hit_flg``."""
        present = ~X['days_from_first_yuid'].isna()
        clipped = X[['days_from_first_yuid']].clip(
            lower=self.MIN_DAYS,
            upper=self.MAX_DAYS,
        )
        scaled = clipped.div(self.MAX_DAYS)
        return scaled.fillna(self.mean).assign(hit_flg=present)


class PhoneWatchLogTransformer(BasicTransformer):
    """Clip the ``avito_category_NN`` columns to [0, 1] and impute their means.

    ``hit_flg`` in the output is True where ``avito_category_01`` is not
    null.
    """

    def _clipped_categories(self, X):
        """Return the columns named ``avito_category_<two digits>``, clipped to [0, 1]."""
        categories = X.filter(regex=r'^avito_category_\d{2}$')
        return categories.clip(lower=0, upper=1)

    def _fit(self, X, y):
        """Store the per-column means of the clipped category columns."""
        self.means = self._clipped_categories(X).mean()

    def _transform(self, X):
        """Return the clipped columns with NaNs imputed, plus ``hit_flg``."""
        present = ~X['avito_category_01'].isna()
        imputed = self._clipped_categories(X).fillna(self.means)
        return imputed.assign(hit_flg=present)


class SelectColumnsAndFillNanToMeanAndHit(BasicTransformer):
    """Select columns and replace NaNs with the per-column means from ``fit``.

    ``hit_flg`` in the output is True for rows where at least one of the
    selected columns is not null.
    """

    def __init__(self, columns):
        self.columns = columns

    def _fit(self, X, y):
        """Store the mean of every selected column."""
        selected = X[self.columns]
        self.means = selected.mean()

    def _transform(self, X):
        """Return the selected columns with NaNs imputed, plus ``hit_flg``."""
        selected = X[self.columns]
        present = ~selected.isna()
        any_present = present.max(axis=1)
        return selected.fillna(self.means).assign(hit_flg=any_present)


class SelectColumnsAndFillNan(BasicTransformer):
    """Select the keys of ``fill_values`` as columns and fill NaNs with its values.

    Nothing is learned from data, so the instance is marked as fitted on
    construction and ``transform`` can be called without ``fit``.
    """

    def __init__(self, fill_values):
        # Marked fitted up front: the fill values are given, not learned.
        self.is_fitted = True
        self.fill_values = fill_values
        # Sorted keys give a deterministic column order in the output.
        self.columns = sorted(fill_values.keys())

    def _transform(self, X):
        """Return ``X`` restricted to ``self.columns`` with NaNs filled."""
        selected = X[self.columns]
        return selected.fillna(self.fill_values)


class ExpandListFeatureAndFillNanToMeanAndHit(BasicTransformer):
    """Expand a list-valued column into one float column per position.

    Rows whose entry is missing get the per-position means learned in
    ``fit``; ``hit_flg`` is True where the source entry was not null.
    """

    def __init__(self, column):
        self.column = column

    def _expand(self, X):
        """Return the non-null entries of ``self.column`` expanded to floats.

        Each entry is turned into a row via ``pd.Series``; the result keeps
        the index labels of the non-null rows.
        """
        present = X[self.column].dropna()
        return present.apply(pd.Series).astype(float)

    def _fit(self, X, y):
        """Store the mean of every expanded position over the non-null rows."""
        # Uses the same float conversion as `_transform` so the means are
        # computed on the values that are later imputed.
        self.means = self._expand(X).mean()

    def _transform(self, X):
        """Return expanded features aligned to ``X.index`` plus ``hit_flg``."""
        missing = X[self.column].isna()
        if missing.all() and isinstance(self.means, pd.Series):
            # With nothing to expand, `apply(pd.Series)` on the empty Series
            # returns a Series rather than a frame, and the join below would
            # produce a single column named after the feature.  Build an
            # empty frame with the fitted column labels instead so the
            # output keeps its usual columns.
            expanded = pd.DataFrame(
                index=X.index[:0],
                columns=self.means.index,
                dtype=float,
            )
        else:
            expanded = self._expand(X)
        # Left-join onto a column-less copy of X so that rows without a list
        # are present as all-NaN rows before imputation.
        aligned = X[[]].join(expanded)
        return aligned.fillna(self.means).assign(hit_flg=~missing)
