import logging
import pandas as pd
from tqdm import tqdm
from yt.wrapper import ypath_join, TablePath
from datacloud.ml_utils.junk.benchmark_v2.transformers import (
    DecodeAndFillNanToMeanAndHit,
    OneHotTransformer,
    SelectColumnAndLogitAndFillNanToMeanAndHit,
    YuidDaysTransformer,
    PhoneWatchLogTransformer,
    SelectColumnsAndFillNanToMeanAndHit,
    SelectColumnsAndFillNan,
    ExpandListFeatureAndFillNanToMeanAndHit,
)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from catboost import CatBoostClassifier


def get_logger():
    """Return this module's logger, reset to exactly two handlers.

    Every handler already attached to the module logger and to its parent
    logger is detached first; then a stream handler and a file handler writing
    to ``log`` in the current working directory are attached, both using the
    ``'%(asctime)s %(message)s'`` format. The logger level is set to INFO.

    Returns:
        logging.Logger: the configured logger for ``__name__``.
    """
    logger = logging.getLogger(__name__)

    # Iterate over copies: removeHandler() mutates the very list being walked,
    # so looping over ``logger.handlers`` directly skips every other handler
    # and repeated calls would accumulate handlers (duplicated log lines).
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        # Release the file descriptor held by a previously attached
        # FileHandler; closing a StreamHandler does not close its stream.
        handler.close()
    if logger.parent:
        # Parent handlers are only detached, not closed: they may still be
        # in use by whoever installed them.
        for handler in list(logger.parent.handlers):
            logger.parent.removeHandler(handler)

    syslog = logging.StreamHandler()
    filelog = logging.FileHandler('log')
    formatter = logging.Formatter('%(asctime)s %(message)s')
    syslog.setFormatter(formatter)
    filelog.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(syslog)
    logger.addHandler(filelog)
    return logger


# Registry used by Experiment.from_json: maps the transformer class name given
# in a config (``feature['transformer']['class']``) to the class to instantiate.
TRANSFORMERS = dict(
    DecodeAndFillNanToMeanAndHit=DecodeAndFillNanToMeanAndHit,
    OneHotTransformer=OneHotTransformer,
    SelectColumnAndLogitAndFillNanToMeanAndHit=SelectColumnAndLogitAndFillNanToMeanAndHit,
    YuidDaysTransformer=YuidDaysTransformer,
    PhoneWatchLogTransformer=PhoneWatchLogTransformer,
    SelectColumnsAndFillNanToMeanAndHit=SelectColumnsAndFillNanToMeanAndHit,
    SelectColumnsAndFillNan=SelectColumnsAndFillNan,
    ExpandListFeatureAndFillNanToMeanAndHit=ExpandListFeatureAndFillNanToMeanAndHit,
)


# Registry used by Experiment.from_json: maps the model class name given in a
# config (``config['model']['class']``) to the estimator class to instantiate.
MODELS = dict(
    LogisticRegression=LogisticRegression,
    CatBoostClassifier=CatBoostClassifier,
)


class Experiment(object):
    """K-fold benchmark of one model on feature tables read from YT.

    ``load_data`` reads the target, the click-stream membership mask and every
    feature table into pandas objects indexed by ``external_id``.
    ``run_experiment`` then fits the model on each fold and stores the ROC AUC
    measured on the whole test fold (``aucs``) and on the part of the test
    fold present in the click stream (``aucs_on_cs``).
    """
    # Seed of the shuffled KFold split.
    RANDOM_SEED = 42
    # Number of cross-validation folds.
    N_SPLIT = 5

    def __init__(self, features, model, click_stream_tables, root, yt_client,
                 target_column='target', nrows=None, logger=None, use_tqdm=False):
        """Store the experiment configuration; no table is read here.

        Args:
            features: list of dicts with keys ``'table'`` (table path) and
                ``'transformer'`` (object with ``fit(X=, y=)`` and
                ``transform(X)``). ``load_data`` adds a ``'data'`` key.
            model: estimator with ``fit(X, y)`` and ``predict_proba(X)``.
            click_stream_tables: paths of tables with an ``external_id``
                column; a row counts as "on click stream" when its id occurs
                in every one of them.
            root: directory that contains ``raw_data/glued`` (target table).
            yt_client: client providing ``read_table``, ``TablePath`` and
                ``row_count``.
            target_column: name of the target column in the target table.
            nrows: if truthy, only the first ``nrows`` rows of each table are
                read.
            logger: logger to use; ``get_logger()`` is called when omitted.
            use_tqdm: wrap row iterators and feature loops in progress bars.
        """
        self.features = features
        self.model = model
        self.click_stream_tables = click_stream_tables
        self.root = root
        self.yt_client = yt_client
        self.target_column = target_column
        self.nrows = nrows

        if self.nrows:
            # The upper bound of a YT row range is exclusive, so
            # end_index == nrows selects exactly rows 0 .. nrows - 1
            # (the former ``nrows - 1`` dropped the last requested row).
            self.tp_params = {'start_index': 0, 'end_index': self.nrows}
        else:
            self.tp_params = {}

        self.use_tqdm = use_tqdm
        self.target_table = TablePath(
            ypath_join(self.root, 'raw_data/glued'),
            columns=['external_id', self.target_column],
            **self.tp_params
        )
        if logger is not None:
            self.logger = logger
        else:
            self.logger = get_logger()

        # Filled by run_experiment(); defined here so the attributes exist
        # before the first run.
        self.aucs = []
        self.aucs_on_cs = []

    @classmethod
    def from_json(cls, config, root, **kwargs):
        """Build an experiment from an already parsed config mapping.

        Args:
            config: mapping with keys ``'features'`` (list of
                ``{'table': <path relative to root>, 'transformer':
                {'class': <key of TRANSFORMERS>, 'params': {...}}}``),
                ``'model'`` (``{'class': <key of MODELS>, 'params': {...}}``)
                and ``'click_stream_tables'`` (paths relative to root).
                ``'params'`` is optional in both places.
            root: base path joined to every table path.
            **kwargs: forwarded to the constructor (e.g. ``yt_client``).

        Returns:
            Experiment: the configured instance.

        Raises:
            KeyError: on a missing config key or an unknown class name.
        """
        features = [
            {
                'table': ypath_join(root, fdata['table']),
                'transformer': TRANSFORMERS[fdata['transformer']['class']](
                    **fdata['transformer'].get('params', {})
                )
            }
            for fdata in config['features']
        ]
        model = MODELS[config['model']['class']](**config['model'].get('params', {}))
        click_stream_tables = [ypath_join(root, table) for table in config['click_stream_tables']]
        return cls(features=features, model=model, click_stream_tables=click_stream_tables, root=root, **kwargs)

    def _read_frame(self, table, count_path):
        """Read ``table`` into a DataFrame indexed by unique ``external_id``.

        ``count_path`` is only used to size the progress bar when ``nrows``
        is not set. Duplicate ids raise (``verify_integrity=True``).
        """
        rows = self.yt_client.read_table(table, enable_read_parallel=True)
        if self.use_tqdm:
            rows = tqdm(rows, total=self.nrows or self.yt_client.row_count(count_path))
        return pd.DataFrame(rows).set_index('external_id', verify_integrity=True)

    def load_data(self):
        """Read target, click-stream mask and feature tables into memory.

        Sets:
            self.target: DataFrame with an int ``target`` column restricted to
                values 0/1, sorted by ``external_id`` index.
            self.click_stream_data: boolean Series named ``cs`` on the target
                index; True where the id occurs in every click-stream table
                (all False when no click-stream table is configured).
            feature['data']: for each feature, its table left-joined onto the
                target index, so rows line up positionally with the target.
        """
        self.logger.info('start load data')
        self.target = (
            self._read_frame(self.target_table, self.target_table)
            .rename(columns={self.target_column: 'target'})
            .astype({'target': int})
            .query('target in [0, 1]')
            .sort_index()
        )
        self.logger.info('finish load target data')

        # Ids present in every click-stream table (intersection of indexes).
        cs_ids = None
        for table_path in self.click_stream_tables:
            ids = self._read_frame(
                self.yt_client.TablePath(table_path, columns=['external_id'], **self.tp_params),
                table_path,
            ).index
            cs_ids = ids if cs_ids is None else cs_ids.intersection(ids)
        if cs_ids is None:
            # No click-stream tables configured: nobody is "on click stream".
            cs_ids = pd.Index([])
        # isin() yields a genuine bool mask; the earlier assign + fillna(False)
        # route produced an object column and depended on pandas downcasting
        # it back to bool.
        self.click_stream_data = pd.Series(
            self.target.index.isin(cs_ids),
            index=self.target.index,
            name='cs',
        )
        self.logger.info('finish load click stream data')

        for feature in self.features:
            frame = self._read_frame(
                self.yt_client.TablePath(feature['table'], **self.tp_params),
                feature['table'],
            )
            # Left join: ids missing from the feature table become NaN rows.
            feature['data'] = self.target[[]].join(frame)
            self.logger.info('finish load feature from %s', feature['table'])

    def prepare_features(self, ind, transformers=None):
        """Build the design matrix for the rows at positions ``ind``.

        Args:
            ind: positional row indexes (as used with ``.iloc``).
            transformers: already fitted transformers, one per feature group
                and in the same order as ``self.features``. When None, each
                feature's own transformer is fitted on the selected rows.

        Returns:
            ``(X, transformers)`` when ``transformers`` is None (fit mode),
            otherwise just ``X``. ``X`` is the column-wise concatenation of
            all transformer outputs with columns renamed to ``0..n-1``.
        """
        fit = transformers is None
        if fit:
            transformers = [feature['transformer'] for feature in self.features]
        else:
            assert len(transformers) == len(self.features),\
                "transformers count should be equel count of features group"

        pairs = zip(self.features, transformers)
        if self.use_tqdm:
            pairs = tqdm(pairs, total=len(self.features))

        parts = []
        for feature, transformer in pairs:
            if self.use_tqdm:
                pairs.set_description("Processing %s" % transformer.__class__.__name__)
            if fit:
                transformer.fit(
                    X=feature['data'].iloc[ind],
                    y=self.target.iloc[ind]['target'],
                )
            # A fresh slice is taken for transform() so it never sees an
            # object that fit() might have modified.
            parts.append(transformer.transform(feature['data'].iloc[ind]))

        X = pd.concat(parts, axis=1)
        # Positional column names avoid clashes between feature groups.
        X.columns = list(range(len(X.columns)))
        if fit:
            return X, transformers
        return X

    def run_experiment(self):
        """Load the data and run the K-fold evaluation.

        For every fold the transformers are fitted on the train part, the
        model is refitted, and ROC AUC of ``predict_proba(...)[:, 1]`` is
        computed on the test part and on its click-stream subset. Results are
        logged and appended to ``self.aucs`` and ``self.aucs_on_cs``; the
        click-stream value is NaN for a fold without click-stream rows.
        """
        self.load_data()
        cv = KFold(n_splits=self.N_SPLIT, shuffle=True, random_state=self.RANDOM_SEED)
        self.aucs = []
        self.aucs_on_cs = []

        self.logger.info('start fit models')
        for i, (train_ind, test_ind) in enumerate(cv.split(self.target)):
            X_train, transformers = self.prepare_features(train_ind, transformers=None)
            X_test = self.prepare_features(test_ind, transformers=transformers)
            self.logger.info('finish prepare features')
            y_train = self.target.iloc[train_ind]['target']
            y_test = self.target.iloc[test_ind]['target']

            self.model.fit(X_train, y_train)
            predict_on_test = self.model.predict_proba(X_test)[:, 1]
            auc = roc_auc_score(y_test, predict_on_test)

            # Plain numpy bool mask: valid for indexing both arrays below.
            cs = self.click_stream_data.iloc[test_ind].values.astype(bool)
            if cs.any():
                auc_on_cs = roc_auc_score(y_test.values[cs], predict_on_test[cs])
            else:
                # Nothing to score in this fold.
                auc_on_cs = float('nan')
            self.logger.info('fit model on fold (%s / %s) with auc %s, on cs auc %s', i + 1, self.N_SPLIT, auc, auc_on_cs)
            self.aucs.append(auc)
            self.aucs_on_cs.append(auc_on_cs)
