import functools
import logging
import os
import sys

import nirvana.job_context as nv

from crypta.lib.python.logging import logging_helpers
from crypta.lib.python.proto_secrets import proto_secrets
from crypta.lib.python.yql import yql_helpers
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib.socdem_helpers.tools import features
from crypta.profile.lib.socdem_helpers.train_utils import training
from crypta.profile.services.train_socdem_models.lib.common import utils
from crypta.profile.services.train_socdem_models.lib.socdem_model import (
    categorical_features_matching,
    common_sample,
)
from crypta.profile.services.train_socdem_models.lib.mobile_socdem_model import (
    categorical_features_matching as mobile_categorical_features_matching,
    common_sample as mobile_common_sample,
)
from crypta.profile.utils.config import config


logger = logging.getLogger(__name__)


def main():
    """Run the single socdem-training step selected by the ``job_name`` parameter.

    Two execution modes are supported, chosen by the ``environment`` env var:

    * Nirvana (default): parameters come from the Nirvana job context, and the
      job's inputs (``nn_models``, ``dict_input``) and outputs
      (``nn_model_output_file``, ``dict_output_file``) are obtained via
      ``utils.get_inputs`` / ``utils.get_outputs``. Metrics are calculated
      with ``use_thresholds=True``.
    * Local testing (``environment=local_testing``): parameters are a copy of
      the process environment, all inputs/outputs are ``None`` and metrics
      are calculated with ``use_thresholds=False``.

    The training config is parsed from the parameters, YT and YQL clients are
    built from it, and the task whose name equals the ``job_name`` parameter
    is executed. Exits the process with status 1 if ``job_name`` is missing or
    does not name a known task.
    """
    logging_helpers.configure_stderr_logger(logging.getLogger(), level=logging.INFO)
    logger.info('Socdem models training')

    if os.environ.get('environment') != 'local_testing':
        context = nv.context()
        parameters = context.get_parameters()
        nn_models, dict_input = utils.get_inputs(context)
        nn_model_output_file, dict_output_file = utils.get_outputs(context)
        use_thresholds = True
    else:
        # Local run: there is no Nirvana context, so job parameters are taken
        # from environment variables and there are no input/output files.
        parameters = os.environ.copy()
        dict_output_file, nn_model_output_file = None, None
        nn_models, dict_input = None, None
        use_thresholds = False

    config_for_training = utils.get_proto_config(parameters)
    # Log a copy of the config with secrets removed (presumably the Yt/Yql
    # tokens read below) so they do not end up in job logs.
    logger.info('Config:\n{}'.format(proto_secrets.get_copy_without_secrets(config_for_training)))

    yt_client = yt_helpers.get_yt_client(
        yt_proxy=config_for_training.Yt.Proxy,
        yt_pool=config_for_training.Yt.Pool,
        yt_token=config_for_training.Yt.Token,
        remote_temp_tables_directory=config.PROFILES_TMP_YT_DIRECTORY,
    )
    yql_client = yql_helpers.create_yql_client(
        yt_proxy=config_for_training.Yt.Proxy,
        token=config_for_training.Yql.Token,
        pool=config_for_training.Yt.Pool,
        tmp_folder=config.PROFILES_YQL_TMP_YT_DIRECTORY,
    )

    # Task table: job name -> zero-argument callable with all arguments bound.
    # Only the entry selected by ``job_name`` is invoked. Most tasks are wrapped
    # with ``utils.wrap_with_nirvana_transaction``.
    # NOTE(review): the two categorical-features/common-sample tasks are not
    # wrapped; confirm that is intentional.
    # The mobile/desktop variants of those two tasks are chosen by
    # ``config_for_training.IsMobile``.
    tasks_dict = {
        'calculate_and_send_metrics': functools.partial(
            utils.wrap_with_nirvana_transaction(yt_client, training.calculate_and_send_metrics),
            yt_client=yt_client,
            socdem_type=config_for_training.SocdemType,
            table_path=config_for_training.PathsInfo.PredictionsBySocdem,
            is_mobile=config_for_training.IsMobile,
            send_metrics=config_for_training.MetricsInfo.SendMetrics,
            use_thresholds=use_thresholds,
        ),
        'copy_features_dicts': functools.partial(
            utils.wrap_with_nirvana_transaction(yt_client, features.copy_features_dicts),
            yt_client=yt_client,
            is_mobile=config_for_training.IsMobile,
            source_dir=config.PRESTABLE_CATEGORICAL_FEATURES_MATCHING_DIR,
            destination_dir=config.CATEGORICAL_FEATURES_MATCHING_DIR,
        ),
        'get_catboost_sample': functools.partial(
            utils.wrap_with_nirvana_transaction(yt_client, training.get_catboost_tables_for_training),
            yt_client=yt_client,
            nn_models=nn_models,
            features_dict=dict_input,
            config_for_training=config_for_training,
        ),
        'get_categorical_features_matching': functools.partial(
            mobile_categorical_features_matching.get if config_for_training.IsMobile else categorical_features_matching.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=config_for_training.Date,
            output_dict=dict_output_file,
        ),
        'get_common_sample': functools.partial(
            mobile_common_sample.get if config_for_training.IsMobile else common_sample.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=config_for_training.Date,
            common_train_sample=config_for_training.PathsInfo.CommonTrainSample,
            general_population=config_for_training.PathsInfo.GeneralPopulation,
        ),
        'get_training_sample_for_socdem_type': functools.partial(
            utils.wrap_with_nirvana_transaction(yt_client, training.get_training_sample_for_socdem_type),
            yt_client=yt_client,
            yql_client=yql_client,
            config_for_training=config_for_training,
            logger=logger,
        ),
        'train_nn_model': functools.partial(
            utils.wrap_with_nirvana_transaction(yt_client, training.train_nn_model),
            yt_client=yt_client,
            socdem_type=config_for_training.SocdemType,
            train_table_path=config_for_training.PathsInfo.RawTrainingSampleBySocdem,
            output_nn_model_file=nn_model_output_file,
            logger=logger,
        ),
    }

    job_name = parameters.get('job_name')
    logger.info('Job name: {}'.format(job_name))

    task = tasks_dict.get(job_name)
    if task is None:
        # An unknown job is a hard failure, so log at error level
        # (``logger.warn`` is a deprecated alias anyway).
        # Use ``sys.exit`` rather than the ``exit`` builtin: the builtin is
        # injected by the ``site`` module and is missing under ``python -S``
        # and in some frozen/embedded interpreters.
        logger.error(
            'Unknown job_name "{}", expected one of: {}'.format(job_name, ', '.join(sorted(tasks_dict))),
        )
        sys.exit(1)

    task()


if __name__ == '__main__':
    # Run the job dispatcher only when executed as a script, not when imported.
    main()
