import functools
import logging
import sys

import nirvana.job_context as nv

from crypta.lib.python.logging import logging_helpers
from crypta.lib.python.yql import yql_helpers
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib.socdem_helpers.inference_utils import (
    inference,
    voting,
)
from crypta.profile.services.validate_socdem_models.lib import utils
from crypta.profile.utils.config import config
from crypta.profile.utils.utils import get_socdem_thresholds_from_api


# Module logger. Records propagate to the root logger, which main() configures
# through logging_helpers.configure_stderr_logger at INFO level.
logger = logging.getLogger(__name__)


def main():
    """Run a single step of the socdem models validation pipeline.

    Reads thresholds, the thresholds output path and the job parameters from the
    Nirvana job context (``nv.context()``). Parses the validation proto config
    from those parameters and builds YT and YQL clients from it. Then runs the
    task selected by the ``job_name`` parameter. Known task names:

    * ``get_nn_predictions``
    * ``get_catboost_predictions``
    * ``run_voting``
    * ``calculate_new_thresholds``
    * ``get_final_predictions``
    * ``copy_to_validation_folder``

    Exits the process with status 1 if ``job_name`` is missing or unknown.
    """
    logging_helpers.configure_stderr_logger(logging.getLogger(), level=logging.INFO)
    logger.info('Socdem models validation')

    context = nv.context()
    thresholds_dict = utils.get_thresholds_from_input(context)
    thresholds_output_filepath = utils.get_thresholds_filepath_from_output(context)
    parameters = context.get_parameters()
    validation_config = utils.get_proto_config(parameters, logger)

    yt_client = yt_helpers.get_yt_client(
        yt_proxy=validation_config.Yt.Proxy,
        yt_pool=validation_config.Yt.Pool,
        yt_token=validation_config.Yt.Token,
        remote_temp_tables_directory=config.PROFILES_TMP_YT_DIRECTORY,
    )
    yql_client = yql_helpers.create_yql_client(
        yt_proxy=validation_config.Yt.Proxy,
        token=validation_config.Yql.Token,
        pool=validation_config.Yt.Pool,
        tmp_folder=config.PROFILES_YQL_TMP_YT_DIRECTORY,
    )

    def run_voting():
        # The raw-profile table lookups (which use yt_client) and the thresholds
        # API fallback are only needed by this job. Resolving them here, rather
        # than when tasks_dict is built, keeps other jobs from making these calls
        # and from failing on them. The argument evaluation order is unchanged.
        voting.run_voting(
            yt_client=yt_client,
            yql_client=yql_client,
            yandexuid_14days_raw_classification=utils.get_raw_profiles_table(
                yt_client=yt_client,
                validation_config=validation_config,
                is_mobile=False,
            ),
            devid_35days_raw_classification_table=utils.get_raw_profiles_table(
                yt_client=yt_client,
                validation_config=validation_config,
                is_mobile=True,
            ),
            # Thresholds from the job input take precedence. The API is queried
            # only when the input yielded nothing (falsy).
            thresholds=thresholds_dict or get_socdem_thresholds_from_api(),
            date=validation_config.Date,  # to separate active profiles (all profiles will be considered active)
            results_directory=validation_config.PathsInfo.VotingResultsDirectory,
        )

    tasks_dict = {
        'get_nn_predictions': functools.partial(
            inference.get_nn_predictions_for_all_profiles,
            yt_client=yt_client,
            is_mobile=validation_config.IsMobile,
            neuro_raw_profiles=validation_config.PathsInfo.NeuroRawProfilesTable,
            logger=logger,
            resource_id=validation_config.ResourceId,
            date=validation_config.Date,  # to set update_time
            monthly=validation_config.UseMonthlyProfiles,
        ),
        'get_catboost_predictions': functools.partial(
            inference.get_catboost_predictions_for_all_profiles,
            yt_client=yt_client,
            yql_client=yql_client,
            is_mobile=validation_config.IsMobile,
            neuro_raw_profiles=validation_config.PathsInfo.NeuroRawProfilesTable,
            raw_profiles=(validation_config.PathsInfo.MobileRawClassificationProfilesTable if validation_config.IsMobile
                          else validation_config.PathsInfo.RawClassificationProfilesTable),
            logger=logger,
            resource_id=validation_config.ResourceId,
            date=validation_config.Date,  # to set update_time
            additional_features_table=(validation_config.PathsInfo.AdditionalFeaturesTable
                                       if validation_config.UseAdditionalFeatures else None),
            additional_features_number=validation_config.AdditionalFeaturesNumber,
        ),
        'run_voting': run_voting,
        'calculate_new_thresholds': functools.partial(
            utils.calculate_new_thresholds,
            yt_client=yt_client,
            yql_client=yql_client,
            validation_config=validation_config,
            thresholds_output_filepath=thresholds_output_filepath,
        ),
        'get_final_predictions': functools.partial(
            utils.get_final_predictions,
            yt_client=yt_client,
            yql_client=yql_client,
            validation_config=validation_config,
            # Only the thresholds from the job input are passed here (no API fallback).
            thresholds=thresholds_dict,
            logger=logger,
        ),
        'copy_to_validation_folder': functools.partial(
            utils.copy_predictions_to_custom_validation,
            yt_client=yt_client,
            validation_config=validation_config,
            logger=logger,
        ),
    }

    job_name = parameters.get('job_name')
    logger.info('Job name: %s', job_name)

    task = tasks_dict.get(job_name)
    if task is None:
        logger.error('Unknown job_name "%s"', job_name)
        # sys.exit rather than the site-injected exit(), which is not guaranteed
        # to exist in non-interactive or embedded interpreters.
        sys.exit(1)
    task()


# Run the validation step only when this module is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
