import json
import logging

from flask import current_app as app
from flask_script import Command, Option

from jafar.pipelines import ids
from jafar.pipelines.pipeline import PipelineConfig
from jafar.pipelines.predefined import all_predefined_pipelines
from jafar.storages.memmap.storage import MultiprocessMemmapStorage

# Module-level logger named after this module; used by the training helpers below.
logger = logging.getLogger(__name__)


def encode_string_values(params):
    """Convert unicode strings in a JSON-derived structure to UTF-8 byte strings, in place.

    On Python 2, ``json.load`` returns every string as ``unicode``. This walks
    nested dicts and lists and replaces each ``unicode`` value with its
    UTF-8-encoded native ``str``. Dict keys are left untouched. On Python 3,
    native ``str`` is already text, so the call does nothing.

    :param params: dict (or list) loaded from JSON; it is modified in place.
    :returns: None.
    """
    if str is not bytes:
        # Python 3: there is no ``unicode`` builtin and nothing needs encoding.
        return
    text_type = unicode  # noqa: F821 -- only reachable on Python 2

    def _convert(value):
        # Returns the (possibly replaced) value; containers are updated in place.
        if isinstance(value, dict):
            for key in value:
                value[key] = _convert(value[key])
        elif isinstance(value, list):
            for index, item in enumerate(value):
                value[index] = _convert(item)
        elif isinstance(value, text_type):
            # Explicit codec: the implicit default (ASCII) fails on non-ASCII text.
            return value.encode('utf-8')
        return value

    _convert(params)


def train_pipeline(storage, pipeline, country, recommendation_mode='score',
                   pipeline_config=None, rename=None, top_n=None, cv=False, cv_output=None):
    """Build one predefined pipeline, apply optional JSON parameters, and train it.

    :param storage: storage object handed to the pipeline factory.
    :param pipeline: key into ``all_predefined_pipelines``; an unknown key raises ``KeyError``.
    :param country: country passed to ``pipeline.train``.
    :param recommendation_mode: forwarded to ``PipelineConfig``.
    :param pipeline_config: optional path to a JSON file whose contents are passed
        to ``pipeline.set_params``. A ``cross_validation`` key in that file is
        discarded; cross-validation is controlled only by ``cv``.
    :param rename: optional new name assigned to the pipeline before training.
    :param top_n: passed to the pipeline factory; any falsy value (None, 0)
        falls back to ``app.config['TOP_N_COUNT']``.
    :param cv: enables cross-validation in the ``PipelineConfig``.
    :param cv_output: when ``cv`` is also set, path to which the first entry of the
        cross-validation results frame is written as JSON (file is overwritten).
    """
    logger.info('Training pipeline %s', pipeline)
    pipeline_creator = all_predefined_pipelines[pipeline]

    params = {}
    if pipeline_config:
        # Context manager so the config file handle is closed deterministically.
        with open(pipeline_config) as config_fp:
            params = json.load(config_fp)
    encode_string_values(params)

    config = PipelineConfig(recommendation_mode=recommendation_mode, online=False, cross_validation=cv)
    top_n = top_n or app.config['TOP_N_COUNT']
    pipeline = pipeline_creator(config, storage, top_n)

    if rename:
        logger.info("Pipeline %s renamed to %s", pipeline.name, rename)
        pipeline.name = rename

    # cross_validation is already fixed in PipelineConfig above; never let the
    # JSON file override it through set_params.
    params.pop('cross_validation', 0)
    pipeline.set_params(params)
    context = pipeline.train(country)
    if cv and cv_output:
        with open(cv_output, 'w') as fp:
            json.dump(context.data[ids.FRAME_KEY_CV_RESULTS].to_list_of_dicts()[0], fp)


class TrainRecommenders(Command):
    """
    Trains recommenders scheduled for daily re-training.

    Trains one predefined pipeline for a single country, or for every country
    listed in ``app.config['COUNTRIES']`` when the country value is empty.
    """

    option_list = (
        Option('--pipeline'),
        # NOTE: with this default the COUNTRIES fallback in run() is reached only
        # when an empty value is given (--country '') or when run() is called
        # directly without a country.
        Option('--country', default='RU'),
        Option('--pipeline-config', dest='pipeline_config'),
        Option('--recommendation-mode', default='score',
               choices=PipelineConfig.recommendation_mode_choices),
        Option('--rename'),
        Option('--cv', action='store_true'),
        Option('--cv-output', dest='cv_output')
    )

    def run(self, pipeline, country=None, pipeline_config=None,
            recommendation_mode='score', rename=None, cv=False, cv_output=None):
        """Validate the pipeline name, then train it for each target country.

        :raises ValueError: if ``pipeline`` is not a key of ``all_predefined_pipelines``.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``; done
        # before building the storage so no storage object is created for bad input.
        if pipeline not in all_predefined_pipelines:
            raise ValueError("Unknown pipeline: %s" % pipeline)
        storage = MultiprocessMemmapStorage()

        countries = [country] if country else app.config['COUNTRIES']
        for target_country in countries:
            logger.info('Training pipeline: %s for country %s', pipeline, target_country)
            # NOTE(review): train_pipeline opens cv_output in 'w' mode, so with
            # several countries only the last country's CV results are kept.
            train_pipeline(
                storage=storage,
                pipeline=pipeline,
                country=target_country,
                recommendation_mode=recommendation_mode,
                pipeline_config=pipeline_config,
                rename=rename,
                cv=cv,
                cv_output=cv_output
            )

        logger.info('Done')
