# -*- coding: utf-8 -*-
import tempfile
import numpy as np
import pandas as pd
from sklearn.externals import joblib
from yt.wrapper import ypath_join

from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.yt.yt_utils import get_yt_client, create_folders
from datacloud.ml_utils.crossval_utils.OOF_predictions import make_OOF_prediction
from datacloud.ml_utils.grid_search_wrapper.grid_search_runner import GridSearchRunner
from datacloud.ml_utils.grid_search_wrapper.nirvana_cube.load_table import (
    load_table_from_yt, DEFAULT_FEATURES_COL, DEFAULT_TARGET_COL, DEFAULT_IDS_COL
)
from datacloud.ml_utils.common.constants import RANDOM_SEED


def get_file_name(ticket_name, target_name, extension):
    """Build an artifact file name of the form ``<ticket>-<target>.<extension>``."""
    name_template = '{}-{}.{}'
    return name_template.format(ticket_name, target_name, extension)


def make_test_predictions(estimator, X_test):
    """Score the test rows with a fitted estimator.

    Returns an empty array when ``X_test`` has no rows. Otherwise the second
    column of ``predict_proba`` is preferred, with ``predict`` as the fallback.

    Raises:
        AttributeError: if the estimator has neither ``predict_proba`` nor
            ``predict`` (only checked when there are rows to score).
    """
    # Nothing to score: the estimator is not inspected at all in this case.
    if X_test.shape[0] == 0:
        return np.empty(0)

    if hasattr(estimator, 'predict_proba'):
        return estimator.predict_proba(X_test)[:, 1]

    if hasattr(estimator, 'predict'):
        return estimator.predict(X_test)

    raise AttributeError('Estimator has no either predict_proba or predict method!')


def make_predcitions_df(oof_ids, ids_test, oof_preds, test_pred, oof_flags):
    """Combine OOF and test predictions into a single frame.

    Test rows get the flag ``-1``; OOF rows keep the flags passed in
    ``oof_flags``. The result has columns ``external_id``, ``prediction`` and
    ``flag``, with one row per distinct ``external_id``: rows sharing an id
    are collapsed by taking the column-wise maximum.
    """
    n_test_rows = ids_test.shape[0]
    test_flags = np.full(n_test_rows, -1, dtype=int)

    all_ids = np.concatenate((oof_ids, ids_test))
    all_preds = np.concatenate((oof_preds, test_pred))
    all_flags = np.concatenate((oof_flags, test_flags)).astype(int)

    combined = pd.DataFrame({
        'external_id': all_ids,
        'prediction': all_preds,
        'flag': all_flags,
    })

    # Duplicate ids are reduced independently per column (max prediction, max flag).
    deduplicated = combined.groupby('external_id').max()
    return deduplicated.reset_index()


def dump_and_stream_estim(yt_client, estim, yt_folder, ticket_name, target_name):
    """Serialize the estimator with joblib and upload it to YT.

    The file is written to ``<yt_folder>/<ticket_name>-<target_name>.pkl``
    through a temporary local file that is removed when the upload finishes.
    """
    with tempfile.NamedTemporaryFile() as tmp_file:
        joblib.dump(estim, tmp_file)
        # Rewind so the upload reads the dump from its first byte.
        tmp_file.seek(0)

        remote_path = ypath_join(
            yt_folder,
            get_file_name(ticket_name, target_name, 'pkl'),
        )
        yt_client.write_file(remote_path, tmp_file)


def dump_and_stream_df(yt_client, pred_df, yt_folder, ticket_name, target_name):
    """Write the frame as CSV (without the index) and upload it to YT.

    The file is written to ``<yt_folder>/<ticket_name>-<target_name>.csv``
    through a temporary local file that is removed when the upload finishes.
    """
    with tempfile.NamedTemporaryFile() as tmp_file:
        pred_df.to_csv(tmp_file, index=False)
        # Rewind so the upload reads the CSV from its first byte.
        tmp_file.seek(0)

        csv_name = get_file_name(ticket_name, target_name, 'csv')
        remote_path = ypath_join(yt_folder, csv_name)
        yt_client.write_file(remote_path, tmp_file)


def fix_external_id_in_df(df, eid_col):
    """Return a copy of ``df`` with the last ``_``-suffix cut from ``eid_col``.

    Each value is split once from the right on ``'_'`` and the left part is
    kept; values without an underscore are left as they are. The input frame
    is not modified.
    """
    def _strip_last_suffix(eid):
        return eid.rsplit('_', 1)[0]

    result = df.copy()
    result[eid_col] = result[eid_col].apply(_strip_last_suffix)
    return result


def assert_params(params):
    """Validate the cube parameters before any work is started.

    ``table_path`` is always required. When ``stream_to_yt`` is truthy,
    ``ticket_name`` and ``yt_folder`` are required as well.

    Raises:
        AssertionError: if a required parameter is missing. The exception is
            raised explicitly instead of via ``assert`` statements, so the
            validation is still performed when Python runs with ``-O``.
    """
    if 'table_path' not in params:
        raise AssertionError('table_path should be specified')

    if params.get('stream_to_yt', False):
        should_pass = ('ticket_name', 'yt_folder')
        if not all(param in params for param in should_pass):
            raise AssertionError(
                '{} should be specified when using stream_to_yt'.format(should_pass)
            )


def make_and_stream_preds(yt_client, estimator, load_res, yt_folder, ticket_name,
                          target_name, write_fixed_csv, external_id_column,
                          oof_ids=None, oof_preds=None, oof_flags=None):
    """Predict on the test part and upload the model and predictions to YT.

    Uploads ``<ticket_name>-<target_name>.pkl`` (the estimator) and
    ``<ticket_name>-<target_name>.csv`` (OOF + test predictions) into
    ``yt_folder``, creating the folder first. When ``write_fixed_csv`` is
    truthy, a second CSV named ``<ticket_name>-<target_name>_fixed.csv`` is
    uploaded with the last ``_``-suffix cut from every external id.

    Args:
        yt_client: YT client used for folder creation and uploads.
        estimator: fitted estimator; it is both scored and serialized.
        load_res: loaded table; ``X_test`` and ``ids_test`` are read from it.
        yt_folder: destination folder on YT.
        ticket_name: first part of the uploaded file names.
        target_name: second part of the uploaded file names.
        write_fixed_csv: whether to also upload the ``_fixed`` CSV.
        external_id_column: id column to fix in the prediction frame; if the
            frame has no such column, its ``external_id`` column is used.
        oof_ids, oof_preds, oof_flags: optional OOF data; each defaults to
            an empty list, giving a frame with test rows only.
    """
    oof_ids = [] if oof_ids is None else oof_ids
    oof_preds = [] if oof_preds is None else oof_preds
    oof_flags = [] if oof_flags is None else oof_flags

    test_pred = make_test_predictions(estimator, load_res.X_test)
    pred_df = make_predcitions_df(oof_ids=oof_ids, ids_test=load_res.ids_test,
                                  oof_preds=oof_preds, test_pred=test_pred,
                                  oof_flags=oof_flags)

    create_folders(yt_client=yt_client, folders=[yt_folder])
    dump_and_stream_estim(yt_client, estimator, yt_folder, ticket_name, target_name)
    dump_and_stream_df(yt_client, pred_df, yt_folder, ticket_name, target_name)

    if write_fixed_csv:
        # The prediction frame always stores ids under 'external_id', whatever
        # the id column of the source table is called. Using a custom
        # ``external_id_column`` directly would raise KeyError here.
        eid_col = external_id_column
        if eid_col not in pred_df.columns:
            eid_col = 'external_id'

        fixed_eid_df = fix_external_id_in_df(pred_df, eid_col)
        dump_and_stream_df(yt_client, fixed_eid_df, yt_folder, ticket_name, target_name + '_fixed')


def run_cube_with_params(params, logger):
    """Run the grid-search cube described by ``params`` and return its metrics.

    Steps: validate ``params``, load the table from YT, fit the grid search,
    build OOF predictions with the best estimator and, when ``stream_to_yt``
    is truthy, refit the best estimator on the whole train part and upload
    the model and predictions to YT.

    Args:
        params: dict of cube parameters; ``table_path`` is required, the
            remaining keys fall back to defaults.
        logger: logger that receives progress and score messages.

    Returns:
        dict with the target name, the per-fold results, the best grid-search
        score and params, the grid-search train/test score statistics and the
        mean/std of the per-fold train and validation AUC.
    """
    logger.info(params)
    assert_params(params)
    yt_client = get_yt_client()

    target_name = params.get('target_name', DEFAULT_TARGET_COL)
    external_id_column = params.get('external_id', DEFAULT_IDS_COL)

    logger.info('Loading table from yt...')
    dataset = load_table_from_yt(
        yt_client=yt_client,
        table_path=params['table_path'],
        target_name=target_name,
        features_column=params.get('features_column', DEFAULT_FEATURES_COL),
        external_id_column=external_id_column,
        verbose=params.get('load_verbose', 100000),
        logger=logger,
        max_rows=params.get('max_rows')
    )
    logger.info('Table loaded!')
    shapes_message = 'X.shape:{}, X_test.shape{}'.format(dataset.X.shape, dataset.X_test.shape)
    logger.info(shapes_message)

    search_params = params.get('grid_search_params', {})
    runner = GridSearchRunner.from_params(search_params)
    logger.info(runner)
    runner.fit(dataset.X, dataset.y)

    best_model = runner.best_estimator_
    logger.info('Best estimator: {}'.format(best_model))

    logger.info('Preparing OOF prediction...')
    oof_result = make_OOF_prediction(
        clf=best_model,
        X=dataset.X,
        y=dataset.y,
        ids=dataset.ids,
        random_state=RANDOM_SEED
    )
    oof_preds, oof_ids, oof_flags, fold_results = oof_result

    if params.get('stream_to_yt', False):
        logger.info('Fitting best estimator on whole dataset...')
        best_model.fit(dataset.X, dataset.y)

        make_and_stream_preds(
            yt_client=yt_client,
            estimator=best_model,
            load_res=dataset,
            yt_folder=params['yt_folder'],
            ticket_name=params['ticket_name'],
            target_name=target_name,
            write_fixed_csv=params.get('write_fixed_csv', True),
            external_id_column=external_id_column,
            oof_ids=oof_ids,
            oof_preds=oof_preds,
            oof_flags=oof_flags
        )

    logger.info('Best score is {}'.format(runner.best_score_))
    logger.info('Best params are {}'.format(runner.best_params_))

    # Grid-search statistics, in the order they are logged and returned.
    cv_stat_names = (
        'mean_train_score',
        'std_train_score',
        'mean_test_score',
        'std_test_score',
    )
    cv_results = runner.cv_results_
    for stat_name in cv_stat_names:
        logger.info('{} {}'.format(stat_name, cv_results[stat_name]))

    summary = {
        'target_name': target_name,
        'fold_results': fold_results,

        'best_score': runner.best_score_,
        'best_params': runner.best_params_,
    }
    for stat_name in cv_stat_names:
        summary[stat_name] = list(cv_results[stat_name])

    train_auc = fold_results['train_auc']
    summary['mean_train'] = np.mean(train_auc)
    summary['std_train'] = np.std(train_auc)

    val_auc = fold_results['val_auc']
    summary['mean_val'] = np.mean(val_auc)
    summary['std_val'] = np.std(val_auc)

    return summary


def main():
    """Run the cube once with a hard-coded development configuration and log the result."""
    log = get_basic_logger(__name__)

    # Development settings: a small sample (max_rows) streamed to a dev folder.
    cube_params = {
        'table_path': '//projects/scoring/homecredit/XPROD-1091/features_prod',
        'target_name': 'target_DEF_6_60',
        'stream_to_yt': True,
        'yt_folder': '//projects/scoring/dev/penguin-diver',
        'ticket_name': 'XPROD-000',
        'max_rows': 1000,
    }

    cube_result = run_cube_with_params(cube_params, log)
    log.info(cube_result)


# Run the development configuration only when executed as a script, not on import.
if __name__ == '__main__':
    main()
