import functools
import logging
import sys

from crypta.affinitive_geo.services.org_embeddings.lib import (
    banner_rt_sadovaya_vectors,
    banner_shows_clicks,
    bases,
    dssm_features,
    dssm_vectors,
    org_affinitive_banners,
    org_weights,
    orgvisits_for_description,
    regions_for_description,
    user_data_stats
)
from crypta.affinitive_geo.services.org_embeddings.lib.utils import (
    config,
    utils,
)
from crypta.lib.python.logging import logging_helpers
from crypta.lib.python.yql import yql_helpers
from crypta.lib.python.yt import yt_helpers

import nirvana.job_context as nv


logger = logging.getLogger(__name__)


def main():
    """Run one affinitive geo org embeddings task selected by Nirvana parameters.

    Reads the job parameters from the Nirvana job context, builds YT and YQL
    clients from them, optionally redirects outputs to a custom directory and
    then runs the single task whose key equals the ``job_name`` parameter.

    Parameters read from the job context:
        date: forwarded to the date-dependent tasks.
        custom-output-dir: when present and non-empty, passed to
            ``utils.update_config_for_custom_output_dir`` before any task is
            built.
        job_name: key of the task to run in ``tasks_dict``.

    Exits the process with status 1 when ``job_name`` is not a known task.
    """
    logging_helpers.configure_stderr_logger(logging.getLogger(), level=logging.INFO)
    logger.info('Affinitive geo org embeddings')

    context = nv.context()
    parameters = context.parameters

    yt_client = yt_helpers.get_yt_client_from_nv_parameters(nv_parameters=parameters)
    yql_client = yql_helpers.get_yql_client_from_nv_parameters(nv_parameters=parameters)
    date = parameters.get('date')

    # An empty string means "not set"; treat it the same as a missing value.
    custom_output_dir = parameters.get('custom-output-dir')
    if custom_output_dir == '':
        custom_output_dir = None
    if custom_output_dir is not None:
        # Must run before the partials below are built: their ``input_table`` /
        # ``output_table`` arguments read ``config.*_TABLE`` at construction
        # time. Presumably this helper rewrites those config values — confirm.
        utils.update_config_for_custom_output_dir(custom_output_dir)

    # Task registry: job name -> zero-argument callable with all inputs bound.
    tasks_dict = {
        'calculate_org_weights': functools.partial(
            org_weights.calculate,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
        'get_banner_rt_sadovaya_vectors': functools.partial(
            banner_rt_sadovaya_vectors.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
        'get_banner_shows_clicks': functools.partial(
            banner_shows_clicks.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
        'get_org_affinitive_banners': functools.partial(
            org_affinitive_banners.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
        'get_orgs_dssm_features': functools.partial(
            dssm_features.get,
            yt_client=yt_client,
            input_table=config.ORGS_USER_DATA_STATS_TABLE,
            output_table=config.ORGS_DSSM_FEATURES_TABLE,
        ),
        'get_orgs_dssm_vectors': functools.partial(
            dssm_vectors.get,
            yt_client=yt_client,
            yql_client=yql_client,
            input_table=config.ORGS_DSSM_FEATURES_TABLE,
            output_table=config.ORGS_DSSM_VECTORS_TABLE,
        ),
        'get_orgs_user_data_stats': functools.partial(
            user_data_stats.get,
            yt_client=yt_client,
            input_table=config.ORGVISITS_FOR_DESCRIPTION_TABLE,
            output_table=config.ORGS_USER_DATA_STATS_TABLE,
        ),
        'get_orgvisits_for_description': functools.partial(
            orgvisits_for_description.get,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
        'get_regions_dssm_features': functools.partial(
            dssm_features.get,
            yt_client=yt_client,
            input_table=config.REGIONS_USER_DATA_STATS_TABLE,
            output_table=config.REGIONS_DSSM_FEATURES_TABLE,
        ),
        'get_regions_dssm_vectors': functools.partial(
            dssm_vectors.get,
            yt_client=yt_client,
            yql_client=yql_client,
            input_table=config.REGIONS_DSSM_FEATURES_TABLE,
            output_table=config.REGIONS_DSSM_VECTORS_TABLE,
        ),
        'get_regions_for_description': functools.partial(
            regions_for_description.get,
            yt_client=yt_client,
            yql_client=yql_client,
        ),
        'get_regions_user_data_stats': functools.partial(
            user_data_stats.get,
            yt_client=yt_client,
            input_table=config.REGIONS_FOR_DESCRIPTION_TABLE,
            output_table=config.REGIONS_USER_DATA_STATS_TABLE,
        ),
        'make_bases': functools.partial(
            bases.make,
            yt_client=yt_client,
            yql_client=yql_client,
            date=date,
        ),
    }

    job_name = parameters.get('job_name')
    logger.info('Job name: %s', job_name)

    task = tasks_dict.get(job_name)
    if task is None:
        # Fatal misconfiguration: log at error level and list valid choices.
        logger.error(
            'Unknown job_name "%s", expected one of: %s',
            job_name,
            ', '.join(sorted(tasks_dict)),
        )
        # sys.exit, not the site-injected ``exit`` builtin, which is missing
        # under ``python -S`` and in frozen/embedded interpreters.
        sys.exit(1)

    task()


if __name__ == "__main__":
    # Dispatch the job only when run as a script, never on plain import.
    main()
