#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools
import logging

from crypta.lib.python.logging import logging_helpers
from crypta.lookalike.lib.python.utils import utils
from crypta.lookalike.services.training.lib import (
    ads_goals_segments,
    data_for_prism_scores,
    dssm_features_to_val_segments,
    dssm_lal_scores,
    dssm_to_segments,
    dssm_to_users,
    dssm_tp_fp_tn_fn,
    experiment_metrics,
    features_for_evaluation,
    lal_metrics,
    negatives,
    pr_auc,
    random_tp_fp_tn_fn,
    segments,
    segments_dssm_features,
    segments_user_data_stats,
    train_sample,
    train_validation_split,
    user_data_yandexuids,
    users_dssm_features,
    validation_labels_to_users,
)

import nirvana.job_context as nv


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


def main():
    """Entry point for the LaL training Nirvana job.

    Configures INFO-level stderr logging on the root logger and reads the
    Nirvana job context (parameters, inputs, outputs). If the parameters
    contain ``working-dir``, the experiment config is updated from them.
    The job then dispatches to the task named by the ``job_name`` parameter.

    Exits the process with status 1 when ``job_name`` is missing, when the
    parameters are absent, or when the name is not a known task.
    """
    # Local import: ``sys`` is needed only for the error exit below.
    import sys

    logging_helpers.configure_stderr_logger(logging.getLogger(), level=logging.INFO)
    logger.info('LaL training')

    context = nv.context()
    parameters = context.parameters
    inputs = context.get_inputs()
    outputs = context.get_outputs()
    result_output = outputs.get('result')

    if parameters is not None and 'working-dir' in parameters.keys():
        utils.update_config_for_experiment(parameters)

    # job_name -> zero-argument callable that runs that task. Each partial
    # binds the arguments the task needs from the job context.
    tasks_dict = {
        'add_dssm_features_to_val_segments': functools.partial(dssm_features_to_val_segments.add, nv_params=parameters),
        'add_val_labels_to_users': functools.partial(validation_labels_to_users.add, nv_params=parameters),
        'add_features_for_evaluation': functools.partial(features_for_evaluation.add,
                                                         nv_params=parameters, output=result_output),
        'apply_dssm_to_segments': functools.partial(dssm_to_segments.apply,
                                                    nv_params=parameters, inputs=inputs),
        'apply_dssm_to_users': functools.partial(dssm_to_users.apply,
                                                 nv_params=parameters, inputs=inputs),
        'calculate_dssm_lal_scores': functools.partial(dssm_lal_scores.calculate, nv_params=parameters),
        'calculate_dssm_tp_fp_tn_fn': functools.partial(dssm_tp_fp_tn_fn.calculate, nv_params=parameters),
        'calculate_dssm_pr_auc': functools.partial(pr_auc.calculate_dssm, nv_params=parameters),
        'calculate_experiment_metrics_ci': functools.partial(experiment_metrics.calculate_ci, nv_params=parameters),
        'calculate_lal_metrics': functools.partial(lal_metrics.calculate, nv_params=parameters, inputs=inputs),
        'calculate_prism_scores': functools.partial(experiment_metrics.calculate_prism_scores, nv_params=parameters),
        'calculate_random_tp_fp_tn_fn': functools.partial(random_tp_fp_tn_fn.calculate, nv_params=parameters),
        'calculate_random_pr_auc': functools.partial(pr_auc.calculate_random, nv_params=parameters),
        'check_correlation': functools.partial(lal_metrics.check_correlation, nv_params=parameters),
        'get_segments_dssm_features': functools.partial(segments_dssm_features.get,
                                                        nv_params=parameters, output=result_output),
        'get_segments_user_data_stats': functools.partial(segments_user_data_stats.get, nv_params=parameters),
        'get_users_dssm_features': functools.partial(users_dssm_features.get,
                                                     nv_params=parameters, output=result_output),
        'make_train_sample': functools.partial(train_sample.make, nv_params=parameters, inputs=inputs),
        'make_negatives': functools.partial(negatives.make, nv_params=parameters),
        'prepare_data_for_prism_scores': functools.partial(data_for_prism_scores.prepare, nv_params=parameters),
        'select_segments': functools.partial(segments.select, nv_params=parameters),
        'shuffle_user_data_yandexuids': functools.partial(user_data_yandexuids.shuffle, nv_params=parameters),
        'train_validation_split': functools.partial(train_validation_split.split, nv_params=parameters),
        'update_ads_goals_segments': functools.partial(ads_goals_segments.update, nv_params=parameters),
    }

    # ``parameters`` may be None (see the guard above). Avoid an
    # AttributeError and let a missing job_name fall through to the
    # "unknown job" path.
    job_name = parameters.get('job_name') if parameters is not None else None
    logger.info('Job name: %s', job_name)

    task = tasks_dict.get(job_name)
    if task is None:
        logger.warning('Unknown job_name "%s"', job_name)
        # sys.exit rather than the site-provided ``exit`` builtin, which is
        # intended for interactive use and may be absent (e.g. ``python -S``).
        sys.exit(1)

    task()


# Run the training job only when executed as a script, not on import.
if __name__ == "__main__":
    main()
