from collections import defaultdict
from datetime import datetime
from functools import partial
import json
import os

from library.python import resource
import numpy as np
import pandas as pd

from crypta.lib.python import (
    templater,
    yaml_config,
)
from crypta.lib.python.custom_ml.tools.metrics import pandas_to_startrek
from crypta.lib.python.proto_secrets import proto_secrets
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.inference_utils.thresholds import (
    calculate_thresholds,
    compute_socdem_distributions,
)
from crypta.profile.services.validate_socdem_models.proto.config_pb2 import TConfig
from crypta.profile.utils.config import config
from crypta.profile.utils.utils import get_exact_socdem_dict


def get_yaml_config(parameters):
    """Render the bundled '/validation/config.yaml' template with `parameters`.

    If `parameters` has no 'date' key, one is derived from
    `parameters['timestamp']` (seconds since the epoch) and formatted as
    YYYY-MM-DD. `datetime.fromtimestamp` converts using the host's local
    timezone. NOTE: `parameters` is modified in place when 'date' is added.

    Args:
        parameters: dict of template variables; must contain 'timestamp'
            when 'date' is absent.

    Returns:
        The rendered YAML text.
    """
    if 'date' not in parameters:
        run_moment = datetime.fromtimestamp(int(parameters['timestamp']))
        parameters['date'] = run_moment.strftime('%Y-%m-%d')

    raw_template = resource.find('/validation/config.yaml').decode('utf-8')
    return templater.render_template(raw_template, vars=parameters)


def get_proto_config(parameters, logger):
    """Build a TConfig message from the rendered validation YAML template.

    The config is logged with secret fields stripped (via
    `proto_secrets.get_copy_without_secrets`).

    Args:
        parameters: template variables forwarded to `get_yaml_config`
            (may be mutated there).
        logger: logger used to print the resulting config.

    Returns:
        The populated TConfig instance.
    """
    proto_config = TConfig()
    rendered_yaml = get_yaml_config(parameters)
    yaml_config.yaml2proto.yaml2proto(rendered_yaml, proto_config)

    redacted_config = proto_secrets.get_copy_without_secrets(proto_config)
    logger.info('Config:\n{}'.format(redacted_config))
    return proto_config


def get_thresholds_from_input(nv_context):
    """Load the JSON thresholds passed as the 'thresholds' input, if any.

    Args:
        nv_context: context object exposing `get_inputs()`, whose result
            supports `has(name)` and `get(name) -> file path`.

    Returns:
        The parsed JSON object, or None when no 'thresholds' input exists.
    """
    available_inputs = nv_context.get_inputs()
    if not available_inputs.has('thresholds'):
        return None

    with open(available_inputs.get('thresholds'), 'r') as thresholds_file:
        return json.load(thresholds_file)


def get_thresholds_filepath_from_output(nv_context):
    """Return the file path of the 'thresholds' output, or None if not declared.

    Args:
        nv_context: context object exposing `get_outputs()`, whose result
            supports `has(name)` and `get(name)`.
    """
    declared_outputs = nv_context.get_outputs()
    if declared_outputs.has('thresholds'):
        return declared_outputs.get('thresholds')
    return None


def get_raw_profiles_table(yt_client, validation_config, is_mobile=False):
    """Choose the raw classification profiles table to validate.

    The custom table from `validation_config.PathsInfo` is preferred when it
    exists in YT; otherwise the default merged table from the global
    `config` is used (device-id profiles for mobile, yandexuid otherwise).

    Args:
        yt_client: YT client used to check table existence.
        validation_config: TConfig with a PathsInfo section.
        is_mobile: select the mobile (device-id) tables instead of yandexuid.

    Returns:
        The selected table path.
    """
    paths_info = validation_config.PathsInfo
    custom_table_path, default_table_path = (
        (paths_info.MobileRawClassificationProfilesTable, config.MERGED_RAW_DEVID_PROFILES)
        if is_mobile
        else (paths_info.RawClassificationProfilesTable, config.MERGED_RAW_YANDEXUID_PROFILES)
    )
    return custom_table_path if yt_client.exists(custom_table_path) else default_table_path


def convert_thresholds(thresholds):
    """Turn positional per-type thresholds into {type: {segment_name: threshold}}.

    Args:
        thresholds: mapping from socdem segment type name to a sequence of
            thresholds, positionally aligned with
            `socdem_config.yet_another_segment_names_by_label_type[type]`.

    Returns:
        A defaultdict(dict) keyed by segment type name. If no
        'income_segments' entry was produced, a fallback of 0.5 for
        segments 'A', 'B' and 'C' is inserted.
    """
    formatted_thresholds = defaultdict(dict)
    for segment_type in socdem_config.SOCDEM_SEGMENT_TYPE_NAMES:
        segment_names = socdem_config.yet_another_segment_names_by_label_type[segment_type]
        for position, segment_name in enumerate(segment_names):
            formatted_thresholds[segment_type][segment_name] = thresholds[segment_type][position]

    # Plain dict.setdefault does not trigger defaultdict's factory.
    formatted_thresholds.setdefault('income_segments', {'A': 0.5, 'B': 0.5, 'C': 0.5})
    return formatted_thresholds


def calculate_new_thresholds(yt_client, yql_client, validation_config, thresholds_output_filepath):
    """Compute socdem thresholds from the corrected profiles and save them as JSON.

    Thresholds are computed by `calculate_thresholds` against the corrected
    profiles table and histograms directory from `validation_config`, using
    the target recalls from `socdem_config`. They are reshaped with
    `convert_thresholds` before being written.

    Args:
        yt_client: YT client.
        yql_client: YQL client.
        validation_config: TConfig with a PathsInfo section.
        thresholds_output_filepath: local path the JSON is written to
            (overwritten if it exists).
    """
    new_thresholds = calculate_thresholds(
        yt_client=yt_client,
        yql_client=yql_client,
        profiles_table=validation_config.PathsInfo.CorrectedProfilesTable,
        histograms_directory=validation_config.PathsInfo.HistogramsDirectory,
        needed_recalls=socdem_config.needed_recalls,
        needed_total_recalls=socdem_config.needed_total_recalls,
    )

    # Convert and serialize before opening the file so that a failure here
    # does not leave a truncated/empty thresholds file behind.
    serialized_thresholds = json.dumps(convert_thresholds(new_thresholds))

    with open(thresholds_output_filepath, 'w') as dict_output_file:
        dict_output_file.write(serialized_thresholds)


def add_exact_socdem(row, thresholds):
    """YT mapper: attach the 'exact_socdem' field computed from `thresholds`.

    The input row is updated in place and yielded as the single output row.
    """
    exact_socdem = get_exact_socdem_dict(row, thresholds)
    row['exact_socdem'] = exact_socdem
    yield row


def _log_distribution_comparison(socdem_distribution, logger):
    """Log the computed vs. target distribution for one distributions-table row.

    Adds 'total' and 'type' keys to the row's 'distribution' dict in place.
    Target per-segment values are `needed_recalls` scaled by
    `needed_total_recalls` for the matching segment type.
    """
    socdem_type = socdem_distribution['socdem_type']
    segment_type_name = socdem_config.socdem_type_to_yet_another_segment_name[socdem_type]

    computed = socdem_distribution['distribution']
    computed['total'] = np.sum(list(computed.values()))
    computed['type'] = 'computed distribution'
    logger.info(computed)

    segment_names = socdem_config.yet_another_segment_names_by_label_type[segment_type_name]
    total_recall = socdem_config.needed_total_recalls[segment_type_name]
    target = dict(zip(
        segment_names,
        np.array(socdem_config.needed_recalls[segment_type_name]) * total_recall,
    ))
    target['total'] = total_recall
    target['type'] = 'target distribution'
    logger.info(target)

    comparison_table = pandas_to_startrek(pd.DataFrame([target, computed]))
    logger.info('{} distribution:\n{}'.format(socdem_type, comparison_table))


def get_final_predictions(yt_client, yql_client, validation_config, thresholds, logger):
    """Produce final socdem predictions and log their distributions.

    Steps:
      1. Create FinalPredictionsTable with the schema of
         CorrectedProfilesTable plus an 'exact_socdem' column of type 'any'.
      2. Fill it by mapping `add_exact_socdem` (bound to `thresholds`) over
         CorrectedProfilesTable.
      3. Compute socdem distributions of the final predictions into
         SocdemDistributionsTable (with use_last_active_profiles=False).
      4. For every row of that table, log computed vs. target distributions.

    Args:
        yt_client: YT client.
        yql_client: YQL client.
        validation_config: TConfig with a PathsInfo section.
        thresholds: thresholds passed to `add_exact_socdem`.
        logger: logger for progress and distribution reports.
    """
    paths = validation_config.PathsInfo

    output_schema = yt_helpers.get_yt_schema_dict_from_table(yt_client, paths.CorrectedProfilesTable)
    output_schema['exact_socdem'] = 'any'
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=paths.FinalPredictionsTable,
        schema=output_schema,
    )

    mapper = partial(add_exact_socdem, thresholds=thresholds)
    yt_client.run_map(mapper, paths.CorrectedProfilesTable, paths.FinalPredictionsTable)
    logger.info('Final prediction are calculated')

    compute_socdem_distributions(
        yql_client=yql_client,
        profiles_table=paths.FinalPredictionsTable,
        distributions_table=paths.SocdemDistributionsTable,
        use_last_active_profiles=False,
    )
    logger.info('Distributions are calculated')

    for distribution_row in yt_client.read_table(paths.SocdemDistributionsTable):
        _log_distribution_comparison(distribution_row, logger)


def copy_predictions_to_custom_validation(yt_client, validation_config, logger):
    """Copy the final predictions table into the custom-validation input directory.

    One copy is made per socdem type ('gender_age' and 'income'). Each copy
    is named '<last component of ValidationDirectory>_<socdem type>' and
    gets the YT attributes region='all', processed=False,
    date=ValidationSampleDate and a per-type ground-truth 'source'.

    Args:
        yt_client: YT client used to copy tables and set attributes.
        validation_config: TConfig with ValidationSampleDate and a PathsInfo
            section.
        logger: logger for progress messages.
    """
    paths = validation_config.PathsInfo
    common_attributes = {
        'region': 'all',
        'processed': False,
        'date': validation_config.ValidationSampleDate,
    }
    # Strip trailing slashes so a directory given as '.../name/' still yields
    # 'name' instead of an empty prefix.
    table_prefix = os.path.basename(paths.ValidationDirectory.rstrip('/'))

    for socdem_type, source in (('gender_age', 'passport'), ('income', 'delta_credit')):
        destination_table = os.path.join(
            paths.CustomValidationInputDirectory,
            '{}_{}'.format(table_prefix, socdem_type),
        )
        yt_client.copy(paths.FinalPredictionsTable, destination_table, force=True)

        # Fresh dict per table rather than mutating one shared across iterations.
        table_attributes = dict(common_attributes, source=source)
        for parameter, value in table_attributes.items():
            yt_client.set_attribute(destination_table, parameter, value)

        logger.info('Sample to calculate metrics for {} is saved to custom validation directory'.format(socdem_type))
