import pandas as pd
import copy
from collections import namedtuple
from tqdm import tqdm
from jsonschema import Draft7Validator, validators
from collections import OrderedDict
import json
import os

from sklearn.model_selection import KFold

from yt.wrapper import ypath_join, TablePath
from startrek_client import Startrek

import vh

from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.svnversion import svnversion
from datacloud.key_manager.key_helpers import get_key

from datacloud.ml_utils.benchmark_v2.tables_cache import (
    YtTablesCacher,
    YtAliasesStorage
)
from datacloud.ml_utils.benchmark_v2.exps_storage import YtExpsStorage
from datacloud.ml_utils.benchmark_v2.transformers import TRANSFORMERS
from datacloud.ml_utils.benchmark_v2.metrics import name2metric, default_metrics
from datacloud.ml_utils.benchmark_v2.models import name2model, clf_model, regr_model


def extend_with_default(validator_class):
    """
    Extend a jsonschema validator class so that validation also fills in defaults.

    Adapted from https://python-jsonschema.readthedocs.io/en/stable/faq/

    The returned class behaves like ``validator_class``, except that whenever the
    ``properties`` keyword is applied to an object instance, every property that
    is missing from the instance and whose subschema declares a ``default`` is
    set (in place) to a deep copy of that default. This happens before the regular
    ``properties`` validation runs.

    :param validator_class: a jsonschema validator class, e.g. ``Draft7Validator``.
    :return: a new validator class built with ``validators.extend``.
    """
    validate_properties = validator_class.VALIDATORS['properties']

    def set_defaults(validator, properties, instance, schema):
        # Only objects can receive defaults. For any other type, let the regular
        # validators report the type mismatch instead of crashing here with an
        # AttributeError on a non-dict instance.
        if validator.is_type(instance, 'object'):
            for prop_name, subschema in properties.items():
                if 'default' in subschema and prop_name not in instance:
                    # Deep-copy so mutable defaults ({}, lists such as
                    # default_metrics) are never shared between validated
                    # instances or aliased back into the schema itself.
                    instance[prop_name] = copy.deepcopy(subschema['default'])

        for error in validate_properties(validator, properties, instance, schema):
            yield error

    return validators.extend(validator_class, {'properties': set_defaults})


# One parsed feature: the resolved table path plus an (unfitted) transformer instance.
FeatureParams = namedtuple('FeatureParams', 'table_path transformer')


class ExperimentConfig(object):
    """
    Validated, path-resolved view of a benchmark experiment config.

    The raw ``config`` dict is handled in three steps:

    1. It is validated against ``SCHEMA``; schema defaults are written into it
       in place.
    2. It is checked for a matching binary revision and known
       model/metric/transformer names.
    3. It is deep-copied, with every table reference replaced by the result of
       ``resolve_path``.

    The copy (without ``root``) is exposed as ``extended_config``; the
    default-filled original is exposed as ``config``.
    """
    # References starting with '//' are treated as absolute table paths.
    ABS_PATH_PREF = '//'
    # References starting with a single '/' are joined onto the config's ``root``.
    # Anything else is looked up as an alias in the aliases storage.
    REL_PATH_PREF = '/'

    SCHEMA = {
        'type': 'object',
        'properties': {
            'root': {'type': 'string'},
            'ticket': {'type': 'string'},
            'revision': {
                'type': 'number',
                'default': svnversion.version()
            },
            'cs_tables': {
                'type': 'array',
                'items': {'type': 'string'},
                'minItems': 1,
                'uniqueItems': True
            },
            'target': {
                'type': 'object',
                'properties': {
                    'table_path': {'type': 'string'},
                    'target_col': {
                        'type': 'string',
                        'default': 'target'
                    },
                },
                'required': ['table_path'],
                'additionalProperties': False
            },
            'features': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'table_path': {'type': 'string'},
                        'transformer': {'type': 'string'},
                        'params': {
                            'type': 'object',
                            'default': {}
                        }
                    },
                    'required': ['table_path', 'transformer'],
                    'additionalProperties': False
                },
                'minItems': 1,
                'uniqueItems': True
            },
            'model': {
                'type': 'object',
                'properties': {
                    'model': {'type': 'string'},
                    'params': {
                        'type': 'object',
                        'default': {}
                    }
                },
                'required': ['model'],
                'additionalProperties': False
            },
            'metrics': {
                'type': 'array',
                'items': {'type': 'string'},
                'minItems': 1,
                'uniqueItems': True,
                'default': default_metrics
            }
        },
        'required': ['ticket', 'cs_tables', 'target', 'features', 'model'],
        'additionalProperties': False
    }

    def __init__(self, config, yt_client):
        """
        :param config: raw experiment config dict; schema defaults are written into it.
        :param yt_client: YT client shared by the tables cacher and the aliases storage.
        """
        self._yt_client = yt_client
        self._tables_cacher = YtTablesCacher(yt_client=self._yt_client)
        self._aliases_storage = YtAliasesStorage(yt_client=self._yt_client)
        self._config = config
        self._extended_config = self.extend_config(config)

    def resolve_path(self, table_path):
        """
        Translate a config table reference via the tables cache.

        :param table_path: a string or ``TablePath``. The reference is handled
            according to its prefix:

            - ``ABS_PATH_PREF``: passed to the cacher as-is.
            - ``REL_PATH_PREF``: joined onto ``self.root`` first.
            - anything else: treated as an alias registered in the aliases storage.
        :return: the value returned by ``YtTablesCacher.cache_and_translate``
            (for paths) or ``YtTablesCacher.read_cache`` (for aliases).
        :raises AssertionError: for a relative path without ``root``, or an unknown alias.
        """
        if isinstance(table_path, TablePath):
            table_path = table_path.to_yson_string()

        # '//' must be tested before '/', since every absolute path also starts with '/'.
        if table_path.startswith(self.ABS_PATH_PREF):
            return self._tables_cacher.cache_and_translate(table_path=table_path)
        if table_path.startswith(self.REL_PATH_PREF):
            assert self.root is not None, 'You should specify root!'
            return self._tables_cacher.cache_and_translate(
                table_path=ypath_join(self.root, table_path)
            )

        aliases_storage_rec = self._aliases_storage.translate(table_path)
        assert aliases_storage_rec is not None, 'Unknown alias {}'.format(table_path)
        return self._tables_cacher.read_cache(aliases_storage_rec.cache_key)

    def resolve_all_paths(self, config):
        """
        Return a deep copy of ``config`` with every table reference resolved.

        Covers ``cs_tables``, ``target.table_path`` and each ``features[].table_path``.
        The input dict is not modified.
        """
        res_config = copy.deepcopy(config)
        res_config['cs_tables'] = [self.resolve_path(t) for t in res_config['cs_tables']]
        res_config['target']['table_path'] = self.resolve_path(res_config['target']['table_path'])
        for feature in res_config['features']:
            feature['table_path'] = self.resolve_path(feature['table_path'])

        return res_config

    def extend_config(self, config):
        """
        Validate ``config`` and build its path-resolved copy.

        Side effects: schema defaults are written into ``config`` in place, and
        ``self.root`` is set from ``config['root']`` (or None).

        :return: the resolved copy, without the ``root`` key.
        :raises jsonschema.ValidationError: if ``config`` does not match ``SCHEMA``.
        :raises AssertionError: on a revision mismatch, or an unknown model,
            metric or transformer name.
        """
        schema_validator = extend_with_default(Draft7Validator)(self.SCHEMA)
        schema_validator.validate(config)

        # Cheap sanity checks come first, so that a typo fails before any
        # table is cached by resolve_all_paths.
        assert config['revision'] == svnversion.version(), \
               'Config revision differs from binary file revision'

        model_name = config['model']['model']
        assert model_name in name2model, 'Unknown model {}'.format(model_name)

        for metric_name in config['metrics']:
            assert metric_name in name2metric, 'Unknown metric {}'.format(metric_name)

        # Transformers are otherwise only looked up inside the first CV fold,
        # after the target and clickstream tables have been loaded.
        for feature in config['features']:
            transformer_name = feature['transformer']
            assert transformer_name in TRANSFORMERS, \
                   'Unknown transformer {}'.format(transformer_name)

        self.root = config.get('root')
        extended_config = self.resolve_all_paths(config)
        extended_config.pop('root', None)

        return extended_config

    def parse_feature(self, feature_params):
        """
        Build a ``FeatureParams`` holding a fresh transformer for one feature entry.

        :param feature_params: one item of ``extended_config['features']``.
        """
        transformer_cls = TRANSFORMERS[feature_params['transformer']]
        transformer_params = feature_params['params']
        # NOTE(review): entries from extended_config already hold resolved paths,
        # so this resolves them a second time. That presumably assumes
        # cache_and_translate / read_cache are idempotent on cached paths; confirm.
        return FeatureParams(
            table_path=self.resolve_path(feature_params['table_path']),
            transformer=transformer_cls(**transformer_params)
        )

    @property
    def parsed_features(self):
        """Freshly built ``FeatureParams`` (new transformer instances) on every access."""
        return [self.parse_feature(fp)
                for fp in self.extended_config['features']]

    def make_model(self):
        """
        Instantiate the configured model.

        :return: ``(model_type, model_instance)`` as registered in ``name2model``.
        """
        model_cfg = self.extended_config['model']
        model_cls, model_type = name2model[model_cfg['model']]
        model_inst = model_cls(**model_cfg['params'])

        return model_type, model_inst

    @property
    def extended_config(self):
        """The validated, path-resolved config copy (without ``root``)."""
        return self._extended_config

    @property
    def config(self):
        """The config dict passed to the constructor, with schema defaults filled in."""
        return self._config

    @property
    def mcalculators(self):
        """New metric calculator instances, one per name in ``extended_config['metrics']``."""
        return [
            name2metric[metric_name]()
            for metric_name in self.extended_config['metrics']
        ]


class DataLoader(object):
    """
    Load YT tables into pandas DataFrames indexed by ``index_col``.

    Loaded tables are memoized per path (including column selection) for the
    lifetime of the loader.
    """

    def __init__(self, yt_client, logger=None, use_tqdm=True, index_col='external_id'):
        """
        :param yt_client: client used for ``read_table`` / ``row_count``.
        :param logger: logger to use; defaults to ``get_basic_logger(__name__)``.
        :param use_tqdm: wrap table reads in a tqdm progress bar.
        :param index_col: column used as the (unique) DataFrame index.
        """
        self._yt_client = yt_client
        self._use_tqdm = use_tqdm
        self._index_col = index_col
        self._path2df = {}
        self._logger = logger or get_basic_logger(__name__)

    def load_table(self, table_path):
        """
        Read a table into a DataFrame indexed by ``index_col``, memoized by path.

        The returned DataFrame is the cached object itself; callers must not mutate it.

        :raises ValueError: if ``index_col`` has duplicate values (``verify_integrity``).
        """
        table_path = TablePath(table_path).to_yson_string()
        if table_path in self._path2df:
            return self._path2df[table_path]
        rows = self._yt_client.read_table(table_path, enable_read_parallel=True)
        if self._use_tqdm:
            rows = tqdm(rows, total=self._yt_client.row_count(table_path))

        self._logger.info('Loading %s', table_path)
        df = pd.DataFrame(rows).set_index(self._index_col, verify_integrity=True)
        self._path2df[table_path] = df
        return df

    def load_feature(self, table_path, target_df):
        """
        Left-join a feature table onto ``target_df``'s index.

        The result has exactly ``target_df``'s rows, with NaN where the feature
        table has no row.
        """
        df = self.load_table(table_path)
        return target_df[[]].join(df)

    def load_target(self, table_path, target_col_name):
        """
        Load the target column as ``target``, keeping only 0/1 labels, sorted by index.

        NOTE(review): the int cast happens before the 0/1 filter. Non-integral
        labels are therefore truncated rather than dropped, and missing labels
        make the cast raise. Confirm this is intended.
        """
        table_path = TablePath(table_path, columns=[self._index_col, target_col_name])
        return (
            self.load_table(table_path)
            .rename(columns={target_col_name: 'target'})
            .astype({'target': int})
            .query('target in [0, 1]')
            .sort_index()
        )

    def load_cs_index(self, tables_paths, target_df):
        """
        Flag which ``target_df`` rows are present in every clickstream table.

        :param tables_paths: non-empty iterable of clickstream table paths.
        :param target_df: DataFrame whose index defines the output rows.
        :return: bool Series named ``cs``, indexed like ``target_df``. It is True
            where the index value occurs in all ``tables_paths``.
        :raises ValueError: if ``tables_paths`` is empty.
        """
        cs_index = None
        for table_path in tables_paths:
            table_path = TablePath(table_path, columns=[self._index_col])
            table_index = self.load_table(table_path).index
            # Intersection of unique indexes == rows kept by an inner join.
            cs_index = table_index if cs_index is None else cs_index.intersection(table_index)

        if cs_index is None:
            raise ValueError('At least one clickstream table path is required')

        # Build the mask explicitly as bool. The previous assign/fillna(False)
        # route produced an object column and relied on pandas' deprecated
        # implicit downcasting to get bools back.
        return pd.Series(
            target_df.index.isin(cs_index),
            index=target_df.index,
            name='cs',
        )


@vh.lazy(
    object,
    config=vh.mkoption(str),
    yt_token=vh.mkoption(vh.Secret),
    st_token=vh.mkoption(vh.Secret)
)
def run_experimnet_vh(config, yt_token, st_token):
    """
    vh operation entry point: export secrets to the environment, then run the experiment.

    :param config: JSON-serialized experiment config.
    :param yt_token: secret whose ``value`` is exported as ``YT_TOKEN``.
    :param st_token: secret whose ``value`` is exported as ``ST_TOKEN``.
    :return: whatever ``Experiment.run_experimnet`` returns.
    """
    # _write_startrek reads ST_TOKEN from the environment; presumably the YT
    # client factory picks up YT_TOKEN the same way (confirm in yt_utils).
    for env_var, secret in (('YT_TOKEN', yt_token), ('ST_TOKEN', st_token)):
        os.environ[env_var] = secret.value

    experiment_config = json.loads(config)
    return Experiment(experiment_config).run_experimnet()


class Experiment(object):
    """
    Cross-validated benchmark of one experiment config.

    ``run_experimnet`` does the work in-process. ``run`` wraps it in a vh graph
    (Nirvana or local backend) and starts it asynchronously; the graph state is
    then available through ``graph_keeper``, ``is_done``, ``wrokflow_info`` and
    ``results``.
    """
    WORKFLOW_GUID = '12042491-fb32-4a9e-8ff0-d03e87ec9c82'

    def __init__(self, config):
        """:param config: raw experiment config dict (see ``ExperimentConfig.SCHEMA``)."""
        self._yt_client = yt_utils.get_yt_client()
        self._logger = get_basic_logger(__name__)
        self._config = ExperimentConfig(config=config, yt_client=self._yt_client)
        self._data_loader = DataLoader(yt_client=self._yt_client, logger=self._logger)
        self._exps_storage = YtExpsStorage(yt_client=self._yt_client)

    @property
    def config(self):
        """The validated, path-resolved experiment config."""
        return self._config.extended_config

    def __getitem__(self, name):
        return self.config[name]

    def _generate_folds(self, y):
        """Yield ``(train_ind, test_ind)`` positional index arrays for 5 shuffled folds."""
        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        return kf.split(y)

    def _prepare_features(self, train_ind, test_ind, target_df):
        """
        Fit each feature's transformer on the train rows and transform all rows.

        :return: ``(X_train, X_test)``, the transformed features concatenated
            column-wise and renumbered 0..n-1.
        """
        def concat_features(dfs):
            X = pd.concat(dfs, axis=1)
            X.columns = list(range(len(X.columns)))
            return X

        X_train, X_test = [], []
        # parsed_features builds fresh transformers, so nothing fitted leaks between folds.
        for table_path, transformer in self._config.parsed_features:
            feature_raw = self._data_loader.load_feature(
                table_path=table_path,
                target_df=target_df
            )
            transformer.fit(feature_raw.iloc[train_ind], target_df.iloc[train_ind]['target'])
            features_transformed = transformer.transform(feature_raw)
            X_train.append(features_transformed.iloc[train_ind])
            X_test.append(features_transformed.iloc[test_ind])

        return concat_features(X_train), concat_features(X_test)

    def _write_startrek(self, exp_rec):
        """
        Post the experiment key, metrics and config as a comment on the config's ticket.

        Skipped, with a warning, when the ``ST_TOKEN`` environment variable is unset or empty.
        """
        st_token = os.environ.get('ST_TOKEN')
        if not st_token:
            # Logger.warn is a deprecated alias of Logger.warning.
            self._logger.warning('No ST_TOKEN provided, skipping write ST message step')
            return

        client = Startrek(
            useragent='robot-xprod',
            base_url='https://st-api.yandex-team.ru/v2',
            token=st_token
        )
        issue = client.issues[self.config['ticket']]

        header = 'Experiment **{}**'.format(exp_rec.key)
        metric_strings = ['||{}|{}||'.format(k, round(v, 5))
                          for k, v in exp_rec.metrics.items()]
        metrics = '#|\n{}\n|#'.format('\n'.join(metric_strings))
        config = '<{{config\n%%(json)\n{}\n%%\n}}>'.format(
            json.dumps(exp_rec.config, indent=4, sort_keys=True)
        )

        issue.comments.create(text='\n'.join((header, metrics, config)))

    def _make_metrics(self, y_true_folds, y_score_folds, is_cs):
        """Return ``[(name, value), ...]`` from every configured metric calculator."""
        metrics = [mcalculator.calculate(
            y_true_folds=y_true_folds,
            y_score_folds=y_score_folds,
            is_cs=is_cs
        ) for mcalculator in self._config.mcalculators]

        return [(m.name, m.value) for m in metrics]

    def run_experimnet(self):
        """
        Run 5-fold CV for the configured model and features, then store the result.

        If the storage already has a record for this config, that record is
        returned without recomputation. Otherwise metrics are computed twice:

        - on all test rows (``is_cs=False``);
        - on the clickstream-covered test rows only (``is_cs=True``).

        The new record is then saved, logged and posted to Startrek.

        :return: the experiments-storage record.
        :raises NameError: if the model type is neither ``clf_model`` nor ``regr_model``.
        """
        self._logger.info('Started experiment')

        exps_storage_rec = self._exps_storage.read_by_config(self.config)
        if exps_storage_rec is not None:
            self._logger.info('Config was found. Results fast forwarded')
            return exps_storage_rec

        self._logger.info('Loading target...')
        target_df = self._data_loader.load_target(
            table_path=self['target']['table_path'],
            target_col_name=self['target']['target_col']
        )

        self._logger.info('Loading clickstream tables')
        cs_df = self._data_loader.load_cs_index(
            tables_paths=self['cs_tables'],
            target_df=target_df
        )

        test_targets, test_preds = [], []
        test_targets_cs, test_preds_cs = [], []
        for i, (train_ind, test_ind) in enumerate(self._generate_folds(target_df.target), 1):
            X_train, X_test = self._prepare_features(
                train_ind=train_ind,
                test_ind=test_ind,
                target_df=target_df
            )
            y_train = target_df.iloc[train_ind]['target']
            y_test = target_df.iloc[test_ind]['target']

            model_type, model_inst = self._config.make_model()
            self._logger.info('Fitting model at fold %d', i)
            model_inst.fit(X_train, y_train)
            self._logger.info('Making predictions at fold %d', i)

            if model_type == clf_model:
                pred = model_inst.predict_proba(X_test)[:, 1]
            elif model_type == regr_model:
                pred = model_inst.predict(X_test)
            else:
                raise NameError('Unknown model type {}'.format(model_type))

            test_targets.append(y_test)
            test_preds.append(pred)

            # Use a plain bool ndarray: a numpy prediction array cannot be masked
            # by a pandas Series whose dtype is object rather than bool.
            cs_mask = cs_df.iloc[test_ind].values.astype(bool)
            test_targets_cs.append(y_test[cs_mask])
            test_preds_cs.append(pred[cs_mask])

        self._logger.info('Calculating metrics')
        metrics = OrderedDict(
            self._make_metrics(
                y_true_folds=test_targets,
                y_score_folds=test_preds,
                is_cs=False
            ) +
            self._make_metrics(
                y_true_folds=test_targets_cs,
                y_score_folds=test_preds_cs,
                is_cs=True
            )
        )

        exp_rec = self._exps_storage.add_experiment(
            config=self.config,
            metrics=metrics
        )
        self._logger.info('Cached experimnet with key %s', str(exp_rec.key))
        self._logger.info('%s', exp_rec.metrics)
        self._write_startrek(exp_rec)

        return exp_rec

    def _was_started(self):
        """True once ``run`` has created a graph keeper."""
        return hasattr(self, '_graph_keeper')

    @property
    def graph_keeper(self):
        """The keeper returned by ``vh.run_async``; raises RuntimeError before ``run``."""
        if not self._was_started():
            raise RuntimeError('Run graph first!')
        return self._graph_keeper

    def is_done(self):
        """Whether the started graph's total-completion future is done."""
        return self.graph_keeper.get_total_completion_future().done()

    @property
    def wrokflow_info(self):
        """Workflow info from the graph keeper, fetched once per ``run``."""
        if not hasattr(self, '_wrokflow_info'):
            self._wrokflow_info = self.graph_keeper.get_workflow_info()
        return self._wrokflow_info

    @property
    def results(self):
        """The downloaded output of the graph's experiment operation, fetched once per ``run``."""
        if not hasattr(self, '_results'):
            self._results = self.graph_keeper.download(self._result_file)
        return self._results

    def run(self, backend='nirvana', hardware_params=None):
        """
        Build a vh graph around ``run_experimnet_vh`` and start it asynchronously.

        :param backend: ``'nirvana'`` or ``'local'``.
        :param hardware_params: ``vh.HardwareParams`` for the operation. Defaults
            to ``max_ram=64 * 1024`` and ``cpu_guarantee=800``.
        :raises AssertionError: on an unknown backend name.
        """
        assert backend in ['nirvana', 'local'], 'Unknown backend {}'.format(backend)
        backend = vh.NirvanaBackend() if backend == 'nirvana' else vh.LocalBackend()
        # Drop values cached by wrokflow_info / results during a previous run(),
        # otherwise they would keep returning the old graph's data.
        for cached_attr in ('_results', '_wrokflow_info'):
            self.__dict__.pop(cached_attr, None)
        with vh.Graph() as graph, backend:
            self._result_file = run_experimnet_vh(
                config=json.dumps(self.config),
                yt_token=vh.get_yt_token_secret(),
                st_token=vh.Secret('robot_xprod_st_token')
            )
            hardware_params = hardware_params or vh.HardwareParams(
                max_ram=64 * 1024,
                cpu_guarantee=800
            )
            self._result_file.set_hardware(hardware_params)
            self._graph_keeper = vh.run_async(
                graph,
                oauth_token=get_key('pipeline_secrets', 'NIRVANA_TOKEN'),
                quota='datacloud',
                workflow_guid=self.WORKFLOW_GUID,
                label=self._exps_storage.get_key(self.config),
                yt_token_secret='robot_xprod_yt_token',
                start=True,
                secrets={
                    'robot_xprod_st_token': get_key('pipeline_secrets', 'ST_TOKEN')
                },
                backend=backend
            )
            self._logger.info('Building your graph...')
