import functools
import json
import os
import tempfile

try:
    # Python 2: prefer the C implementation of pickle.
    import cPickle as pickle
except ImportError:
    # cPickle is not available (e.g. on Python 3): fall back to the standard module.
    # Only ImportError is handled so that unrelated failures are not hidden.
    import pickle
import numpy as np

from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import date_helpers
from crypta.profile.lib import vector_helpers
from crypta.profile.lib.socdem_helpers import mobile_socdem as mobile_socdem_utils
from crypta.profile.lib.socdem_helpers import socdem as socdem_utils
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.simple_nn import SimpleNN
from crypta.profile.lib.socdem_helpers.tools.features import download_features_dict_from_sandbox
from crypta.profile.utils import utils
from crypta.profile.utils.config import config


# YQL query template (filled in with str.format) that applies the age, gender and
# income CatBoost models to a table of prepared features.
#
# Placeholders:
#   {age_model}, {gender_model}, {income_model} -- locations of the model files
#       attached to the query through PRAGMA file;
#   {input_table} -- table with FloatFeatures, CatFeatures and PassThrough columns.
# Doubled braces ({{ and }}) are str.format escapes that produce the literal braces
# of YQL lambda bodies.
#
# The template only defines named expressions ($age_processed, $gender_processed,
# $income_processed, the $*_schema label lists and the $get_model_predictions helper);
# it contains no statement that writes results. get_catboost_predictions_for_all_profiles
# appends a results-processing template to it before execution.
apply_catboost_for_socdem_query_template = """
-- specify the URL(s) where to get the models
PRAGMA file(
    'age_model.bin',
    '{age_model}'
);
PRAGMA file(
    'gender_model.bin',
    '{gender_model}'
);
PRAGMA file(
    'income_model.bin',
    '{income_model}'
);

-- Initialize CatBoost FormulaEvaluator with given models:
$age_evaluator = CatBoost::LoadModel(
    FilePath('age_model.bin')
);

$gender_evaluator = CatBoost::LoadModel(
    FilePath('gender_model.bin')
);

$income_evaluator = CatBoost::LoadModel(
    FilePath('income_model.bin')
);

-- Prepare the data:
$data = (
    SELECT
        ListMap(Yson::ConvertToDoubleList(FloatFeatures), ($x) -> {{ RETURN Cast($x AS Float); }}) AS FloatFeatures,
        Yson::ConvertToStringList(CatFeatures) AS CatFeatures,
        PassThrough
    FROM `{input_table}`
);

-- Apply Catboost in batch mode
$age_processed = (
    PROCESS $data
    USING CatBoost::EvaluateBatch(
        $age_evaluator,
        TableRows(),
        6,  -- Number of classes
        4096  -- Batch size
    )
);

$gender_processed = (
    PROCESS $data
    USING CatBoost::EvaluateBatch(
        $gender_evaluator,
        TableRows(),
        2,  -- Number of classes
        4096  -- Batch size
    )
);

$income_processed = (
    PROCESS $data
    USING CatBoost::EvaluateBatch(
        $income_evaluator,
        TableRows(),
        5,  -- Number of classes
        4096  -- Batch size
    )
);

$age_schema = AsList('0_17', '18_24', '25_34', '35_44', '45_54', '55_99');
$gender_schema = AsList('m', 'f');
$income_schema = AsList('A', 'B1', 'B2', 'C1', 'C2');

$exponentiation = ($raw_weights) -> {{RETURN ListMap($raw_weights, ($weight) -> {{ RETURN Math::Exp($weight); }})}};
$normalize = ($weights) -> {{RETURN ListMap($weights, ($weight) -> {{ RETURN $weight / ListSum($weights); }})}};
$get_model_predictions = ($schema, $result) -> {{
    RETURN Yson::Serialize(Yson::FromDoubleDict(ToDict(ListZip($schema, $normalize($exponentiation($result))))));
}};
"""

# Columns shared by every socdem classification output table.
socdem_classification_basic_schema = dict(
    gender='any',
    user_age_6s='any',
    income_5_segments='any',
    update_time='uint64',
)

# Device-level output: the shared columns plus the (id, id_type) identifier pair.
socdem_devid_classification_schema = dict(
    socdem_classification_basic_schema,
    id='string',
    id_type='string',
)

# Yandexuid-level output: the shared columns plus the yandexuid identifier.
socdem_yandexuid_classification_schema = dict(
    socdem_classification_basic_schema,
    yandexuid='uint64',
)

# Feature columns of the CatBoost input sample;
# prepare_features_for_catboost_inference adds a 'PassThrough' column on top of these.
catboost_sample_schema = dict(
    FloatFeatures='any',
    CatFeatures='any',
)


def filter_dict_by_keys(d, keys):
    """Build a new dict from the items of ``d`` whose key is contained in ``keys``.

    ``d`` itself is left untouched; keys listed in ``keys`` but absent from ``d``
    are simply not present in the result.
    """
    selected = {}
    for key, value in d.items():
        if key not in keys:
            continue
        selected[key] = value
    return selected


def get_format_function(socdem_type, batch_predictions):
    """Turn raw per-row model outputs into labelled dicts for one socdem type.

    Generator: for every item of ``batch_predictions`` yields a single-key dict
    ``{segment_name: {label: float(value), ...}}``. The segment name is looked up in
    ``socdem_config.socdem_type_to_yet_another_segment_name`` by ``socdem_type``;
    the label of the value at position ``i`` is
    ``socdem_config.yet_another_segment_names_by_label_type[segment_name][i]``.
    """
    column = socdem_config.socdem_type_to_yet_another_segment_name[socdem_type]
    for scores in batch_predictions:
        labelled_scores = {}
        for position, score in enumerate(scores):
            label = socdem_config.yet_another_segment_names_by_label_type[column][position]
            labelled_scores[label] = float(score)
        yield {column: labelled_scores}


def get_simple_nn_model(resource, released, model_path, resource_id=None):
    """Download a pickled SimpleNN model from Sandbox and return the loaded object.

    The file is downloaded into a temporary file and unpickled from there.
    Raises AssertionError if the unpickled object is not a SimpleNN instance.
    """
    with tempfile.NamedTemporaryFile() as download_target:
        local_path = download_target.name
        utils.download_file_from_sandbox(resource, released, model_path, local_path, resource_id)
        # The download goes through the path, so the content is read back through a
        # fresh handle opened by name.
        # NOTE: unpickling executes code embedded in the file; the resource must be trusted.
        with open(local_path, 'rb') as pickled_model:
            model = pickle.load(pickled_model)

    assert isinstance(model, SimpleNN)
    return model


def get_nn_models_from_sandbox(resource, released, resource_id=None):
    """Load one neural network per socdem type from a Sandbox resource.

    For every type in ``socdem_config.SOCDEM_TYPES`` the file
    ``<socdem_type>_nn_model.bin`` is loaded with get_simple_nn_model.

    Returns:
        list of ``(model, format_function)`` tuples, where ``format_function`` is
        get_format_function with ``socdem_type`` already bound.
    """
    return [
        (
            get_simple_nn_model(
                resource=resource,
                released=released,
                model_path='{}_nn_model.bin'.format(socdem_type),
                resource_id=resource_id,
            ),
            functools.partial(get_format_function, socdem_type=socdem_type),
        )
        for socdem_type in socdem_config.SOCDEM_TYPES
    ]


class BatchModelApplyerMapper(object):
    """Applying bunch of neural networks on the cluster

    Incoming rows are accumulated; once ``batch_size`` rows are collected (and once
    more in ``finish`` for the remainder) all models are applied to the batch and one
    output row per input row is yielded.

    Arguments:
        models_list: List of tuples (model, format_function)
            model: something with .predict method; it is called as
                ``model.predict(features, env=config.environment)``
            format_function: Called as ``format_function(batch_predictions=...)`` with the
                model.predict() output; should yield one dict {'column name': anything} per row
        additional_columns: Keys copied from every input row into its output row.
        update_time: Value written to the 'update_time' column of every output row.
        batch_size: Batch size to use on cluster. Tune this for RAM usage/run time trade-off.
    """

    def __init__(self, models_list, additional_columns, update_time, batch_size=4096):
        self.models_list = models_list
        self.batch_size = batch_size
        self.batch = []
        self.update_time = update_time
        self.additional_columns = additional_columns

    def start(self):
        """Reset the accumulated batch."""
        self.batch = []

    def __call__(self, row):
        """Buffer ``row``; yield output rows whenever a full batch has been collected."""
        self.batch.append(row)
        if len(self.batch) >= self.batch_size:
            for record in self.process_batch():
                yield record

    def finish(self):
        """Yield output rows for whatever is still buffered."""
        for record in self.process_batch():
            yield record

    def process_batch(self):
        """Run all models over the buffered rows and yield one output row per input row.

        The buffer is emptied once the generator has been fully consumed.
        """
        if not self.batch:
            # finish() is called unconditionally, so the buffer is empty when there were
            # no rows or their number is a multiple of batch_size. Nothing would be
            # yielded anyway, so skip model inference on a zero-row feature matrix.
            return

        n = len(self.batch)  # Final batch can be smaller
        output_rows = []
        features = np.zeros((n, socdem_config.VECTOR_SIZE), dtype=np.float32)
        for i, row in enumerate(self.batch):
            output_rows.append(filter_dict_by_keys(row, self.additional_columns))
            features[i] = vector_helpers.vector_row_to_features(row)

        # One generator of per-row result dicts per model.
        predictions = [
            format_func(batch_predictions=model.predict(features, env=config.environment))
            for model, format_func in self.models_list
        ]
        # zip(*predictions) regroups the per-model streams into per-row tuples.
        for i, prediction in enumerate(zip(*predictions)):
            output_rows[i]['update_time'] = self.update_time

            for result in prediction:
                output_rows[i].update(result)

            yield output_rows[i]

        self.batch = []


def get_socdem_nn_models(
    yt_client,
    resource_type=None,
    released='production',
    resource_id=None,
    hahn_folder=None,
):
    """Load the neural networks for all socdem types from YT or from Sandbox.

    ``hahn_folder`` takes precedence: when given, ``<hahn_folder>/<socdem_type>_nn_model.bin``
    is unpickled from YT for every type in ``socdem_config.SOCDEM_TYPES``. Otherwise the
    models come from Sandbox by ``resource_id`` and/or ``resource_type``.

    Returns:
        list of ``(model, format_function)`` tuples.

    Raises:
        ValueError: if none of ``hahn_folder``, ``resource_id``, ``resource_type`` is given.
    """
    if hahn_folder is None:
        if resource_id is None and resource_type is None:
            raise ValueError('Hahn directory, resource type or resource_id must be defined.')
        return get_nn_models_from_sandbox(resource_type, released, resource_id)

    loaded_models = []
    for socdem_type in socdem_config.SOCDEM_TYPES:
        model_stream = yt_client.read_file('{}/{}_nn_model.bin'.format(hahn_folder, socdem_type))
        # NOTE: unpickling executes code embedded in the file; the folder must be trusted.
        model = pickle.load(model_stream)
        formatter = functools.partial(get_format_function, socdem_type=socdem_type)
        loaded_models.append((model, formatter))
    return loaded_models


def get_features_dict(
    yt_client,
    resource_type=None,
    released='production',
    resource_id=None,
    hahn_folder=None,
    file_name='cat_features_dict.json',
):
    """Load the categorical features dictionary from YT or from Sandbox.

    ``hahn_folder`` takes precedence: when given, the JSON file ``file_name`` inside that
    folder is read from YT and parsed. Otherwise the dictionary is downloaded from Sandbox
    by ``resource_id`` and/or ``resource_type``.

    Returns:
        The parsed JSON content (YT) or whatever download_features_dict_from_sandbox returns.

    Raises:
        ValueError: if none of ``hahn_folder``, ``resource_id``, ``resource_type`` is given.
    """
    if hahn_folder is not None:
        features_stream = yt_client.read_file(os.path.join(hahn_folder, file_name))
        # Read the whole file. next() on the stream hands back only its first piece, so a
        # multi-line (pretty-printed) or large JSON document would be cut off and fail to parse.
        # Assumes the stream exposes read(), which the pickle.load call in
        # get_socdem_nn_models already relies on.
        return json.loads(features_stream.read())

    if resource_id is not None or resource_type is not None:
        return download_features_dict_from_sandbox(
            resource_type=resource_type,
            released=released,
            file_name=file_name,
            resource_id=resource_id,
        )

    raise ValueError('Hahn directory, resource type or resource_id need to be defined.')


def get_catboost_models_file_paths(
    socdem_type,
    resource_type=None,
    released='production',
    resource_id=None,
    hahn_folder=None,
):
    """Build the location string of the CatBoost model file for ``socdem_type``.

    The source is chosen in this order of precedence:
      1. ``hahn_folder`` -- a ``yt://hahn/...`` path;
      2. ``resource_id`` -- a Sandbox proxy URL of that exact resource;
      3. ``resource_type`` -- a Sandbox proxy URL of the last resource of that type
         with the given ``released`` attribute.

    Raises:
        ValueError: if none of the three sources is given.
    """
    if hahn_folder is not None:
        return 'yt://hahn/{}/{}_catboost_model.bin'.format(hahn_folder, socdem_type)

    if resource_id is not None:
        return 'https://proxy.sandbox.yandex-team.ru/{resource_id}/{socdem_type}_catboost_model.bin'.format(
            resource_id=resource_id,
            socdem_type=socdem_type,
        )

    if resource_type is None:
        raise ValueError('Hahn directory, resource type or resource_id need to be defined.')

    # Doubled braces are str.format escapes for the literal JSON braces of the attrs filter.
    last_released_url = (
        'https://proxy.sandbox.yandex-team.ru/last/{resource_type}/'
        '{socdem_type}_catboost_model.bin?attrs={{"released":"{released}"}}'
    )
    return last_released_url.format(
        resource_type=resource_type,
        socdem_type=socdem_type,
        released=released,
    )


def setup_parameters_for_nn_inference(is_mobile, monthly=False):
    """Pick the output schema, id columns and input vectors table for NN inference.

    Arguments:
        is_mobile: choose the device (id, id_type) settings instead of the yandexuid ones.
        monthly: take the MONTHLY_* vectors table from config instead of the DAILY_* one.

    Returns:
        dict with keys 'schema', 'id_columns' and 'input_table'. 'schema' is the
        module-level schema dict itself, not a copy.
    """
    if is_mobile:
        schema = socdem_devid_classification_schema
        id_columns = ['id', 'id_type']
        if monthly:
            input_table = config.MONTHLY_DEVID2VEC
        else:
            input_table = config.DAILY_DEVID2VEC
    else:
        schema = socdem_yandexuid_classification_schema
        id_columns = ['yandexuid']
        if monthly:
            input_table = config.MONTHLY_YANDEXUID2VEC
        else:
            input_table = config.DAILY_YANDEXUID2VEC

    return {
        'schema': schema,
        'id_columns': id_columns,
        'input_table': input_table,
    }


def get_nn_predictions_for_all_profiles(
    yt_client,
    is_mobile,
    neuro_raw_profiles,
    logger,
    resource_type=None,
    released=None,
    resource_id=None,
    hahn_folder=None,
    date=None,
    monthly=False,
):
    """Apply the socdem neural networks to the vectors table and store the predictions.

    Creates ``neuro_raw_profiles`` as an empty table, fills it with a map operation running
    BatchModelApplyerMapper over the input vectors table, logs the job statistics as JSON
    and finally sorts the table by the id columns.

    Defaults:
        released: 'stable' when ``config.environment`` is 'production', otherwise 'testing'.
        date: today's date string from date_helpers.
    """
    if released is None:
        if config.environment == 'production':
            released = 'stable'
        else:
            released = 'testing'
    if date is None:
        date = date_helpers.get_today_date_string()

    parameters = setup_parameters_for_nn_inference(is_mobile, monthly)
    models = get_socdem_nn_models(yt_client, resource_type, released, resource_id, hahn_folder)

    yt_helpers.create_empty_table(yt_client, neuro_raw_profiles, schema=parameters['schema'])

    mapper = BatchModelApplyerMapper(
        models,
        parameters['id_columns'],
        update_time=date_helpers.from_utc_date_string_to_noon_timestamp(date),
    )
    map_operation = yt_client.run_map(
        mapper,
        parameters['input_table'],
        neuro_raw_profiles,
        spec={'title': 'Vector classification'},
    )

    job_statistics = map_operation.get_job_statistics()
    logger.info(json.dumps(job_statistics))
    yt_client.run_sort(neuro_raw_profiles, sort_by=parameters['id_columns'])


def prepare_features_for_catboost_inference(
    yt_client,
    is_mobile,
    cat_features_dict,
    neuro_raw_profiles,
    sample_for_catboost_classification,
    raw_profiles=None,
    additional_features_table=None,
    additional_features_number=0,
):
    """Build the CatBoost input sample table with a reduce operation.

    ``sample_for_catboost_classification`` is created empty with the FloatFeatures /
    CatFeatures columns plus a 'PassThrough' column ('any' for mobile, 'uint64' otherwise)
    and is then filled by a MakeCatboostFeatures reducer.

    Mobile: reduces ``neuro_raw_profiles`` and ``config.APP_BY_DEVID_DAILY_TABLE`` by
    (id, id_type) into the sample table.
    Otherwise: additionally creates ``raw_profiles`` as an empty table and reduces
    ``neuro_raw_profiles``, ``config.SEGMENTS_STORAGE_BY_YANDEXUID_TABLE`` and
    ``config.YANDEXUID_METRICS_MERGED_HITS_TABLE`` by yandexuid into
    ``[raw_profiles, sample_for_catboost_classification]``.

    ``additional_features_table``, when given, is appended to the reduce inputs;
    ``additional_features_number`` is passed through to the reducer.
    ``raw_profiles`` is used only in the non-mobile branch.
    """
    # Build a per-call schema. Writing 'PassThrough' into the module-level
    # catboost_sample_schema would leak the type chosen here into every later user of it.
    sample_schema = dict(catboost_sample_schema, PassThrough='any' if is_mobile else 'uint64')
    yt_helpers.create_empty_table(
        yt_client,
        sample_for_catboost_classification,
        schema=sample_schema,
    )

    additional_features_tables = [additional_features_table] if additional_features_table is not None else []
    reduce_spec = {
        'data_size_per_job': 128 * 1024 * 1024,
    }

    if is_mobile:
        yt_client.run_reduce(
            mobile_socdem_utils.MakeCatboostFeatures(
                cat_features_dict=cat_features_dict,
                additional_features_number=additional_features_number,
            ),
            [
                neuro_raw_profiles,
                config.APP_BY_DEVID_DAILY_TABLE,
            ] + additional_features_tables,
            sample_for_catboost_classification,
            spec=reduce_spec,
            reduce_by=['id', 'id_type'],
        )
    else:
        yt_helpers.create_empty_table(
            yt_client,
            raw_profiles,
            schema=utils.yandexuid_classification_schema,
        )

        yt_client.run_reduce(
            socdem_utils.MakeCatboostFeatures(
                cat_features_dict=cat_features_dict,
                additional_features_number=additional_features_number,
            ),
            [
                neuro_raw_profiles,
                config.SEGMENTS_STORAGE_BY_YANDEXUID_TABLE,
                config.YANDEXUID_METRICS_MERGED_HITS_TABLE,
            ] + additional_features_tables,
            [
                raw_profiles,
                sample_for_catboost_classification,
            ],
            spec=reduce_spec,
            reduce_by=['yandexuid'],
        )


def get_catboost_predictions_for_all_profiles(
    yt_client,
    yql_client,
    is_mobile,
    neuro_raw_profiles,
    raw_profiles,
    logger,
    resource_type=None,
    released=None,
    resource_id=None,
    hahn_folder=None,
    date=None,
    additional_features_table=None,
    additional_features_number=0,
):
    """Apply the CatBoost socdem models to all profiles and write ``raw_profiles``.

    Steps, all inside one YT transaction with a temporary sample table:
      1. resolve the model file location for every type in ``socdem_config.SOCDEM_TYPES``;
      2. prepare the CatBoost features sample;
      3. run the YQL query (model application template followed by the mobile or
         yandexuid results-processing template) in the same transaction;
      4. sort ``raw_profiles`` by the id columns and set its 'generate_date' attribute.

    Defaults:
        released: 'stable' when ``config.environment`` is 'production', otherwise 'testing'.
        date: today's date string from date_helpers.
    """
    if released is None:
        if config.environment == 'production':
            released = 'stable'
        else:
            released = 'testing'
    if date is None:
        date = date_helpers.get_today_date_string()

    logger.info('Define catboost models paths')
    catboost_model_files = {
        '{}_model'.format(socdem_type): get_catboost_models_file_paths(
            socdem_type=socdem_type,
            resource_type=resource_type,
            released=released,
            resource_id=resource_id,
            hahn_folder=hahn_folder,
        )
        for socdem_type in socdem_config.SOCDEM_TYPES
    }

    with yt_client.Transaction() as transaction, yt_client.TempTable() as sample_for_catboost_classification:
        logger.info('Prepare catboost features')
        cat_features_dict = get_features_dict(
            yt_client=yt_client,
            resource_type=resource_type,
            released=released,
            resource_id=resource_id,
            hahn_folder=hahn_folder,
            file_name='cat_features_dict.json',
        )
        prepare_features_for_catboost_inference(
            yt_client=yt_client,
            is_mobile=is_mobile,
            cat_features_dict=cat_features_dict,
            neuro_raw_profiles=neuro_raw_profiles,
            sample_for_catboost_classification=sample_for_catboost_classification,
            raw_profiles=raw_profiles,
            additional_features_table=additional_features_table,
            additional_features_number=additional_features_number,
        )

        logger.info('Apply catboost models to profiles')

        if is_mobile:
            results_processing_query_template = \
                mobile_socdem_utils.catboost_application_result_processing_query_template
            sort_columns = ['id', 'id_type']
        else:
            results_processing_query_template = \
                socdem_utils.catboost_application_results_processing_query_template
            sort_columns = ['yandexuid']

        model_application_query = apply_catboost_for_socdem_query_template.format(
            input_table=sample_for_catboost_classification,
            gender_model=catboost_model_files['gender_model'],
            age_model=catboost_model_files['age_model'],
            income_model=catboost_model_files['income_model'],
        )
        results_processing_query = results_processing_query_template.format(
            update_time=date_helpers.from_utc_date_string_to_noon_timestamp(date),
            output_table=raw_profiles,
        )
        apply_catboost_models_query = '{}\n{}'.format(model_application_query, results_processing_query)

        # The transaction id is passed so that the YQL query runs inside the same YT transaction.
        yql_client.execute(
            apply_catboost_models_query,
            title='YQL Apply catboost models to profiles',
            transaction=str(transaction.transaction_id),
        )

        yt_client.run_sort(raw_profiles, sort_by=sort_columns)
        yt_client.set_attribute(raw_profiles, 'generate_date', date)
