import itertools
import os

try:
    import cPickle as pickle
except:
    import pickle
import numpy as np
import six

from crypta.lib.python.bigb_catboost_applier.train_sample import create_cd_file
from crypta.lib.python.custom_ml.tools.metrics import calculate_base_metrics
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.mobile_socdem import (
    MakeCatboostTrainingFeatures as MakeMobileCatboostTrainingFeatures,
)
from crypta.profile.lib.socdem_helpers.socdem import MakeCatboostTrainingFeatures
from crypta.profile.lib.socdem_helpers.tools import features as features_utils
from crypta.profile.lib.socdem_helpers.train_utils.models import convert_simple_keras_to_numpy
from crypta.profile.lib.socdem_helpers.train_utils.train_helper import (
    get_socdem_ordered_thresholds,
    SocdemModelTrainHelper,
)
from crypta.profile.lib.socdem_helpers.train_utils.weights_utils import add_weights_to_training_sample
from crypta.profile.utils import utils

# YQL query template, filled in with ``str.format``. It overwrites
# ``{output_table}`` (INSERT ... WITH TRUNCATE) with every row of
# ``{input_table}`` whose ``{socdem_segment_type}`` column is not NULL.
filter_training_sample_by_socdem_type = """
INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT *
FROM `{input_table}`
WHERE {socdem_segment_type} IS NOT NULL;
"""


def train_nn_model(yt_client, socdem_type, train_table_path, output_nn_model_file, logger, transaction=None):
    """Train the neural network model for one socdem type and optionally save it.

    The sample is downloaded from ``train_table_path`` together with the
    ``crypta_id`` column, the target segment column and, when the table has
    one, the weight column for ``socdem_type``.

    Args:
        yt_client: YT client used to open the transaction and read the sample.
        socdem_type: key of ``socdem_config.socdem_type_to_segment_name``.
        train_table_path: table holding the training sample.
        output_nn_model_file: local path for the pickled model; ``None`` skips saving.
        logger: logger for progress messages, also handed to the train helper.
        transaction: optional object whose ``transaction_id`` is passed to
            ``yt_client.Transaction``; with ``None`` the id passed is ``None``.
    """
    parent_transaction_id = None if transaction is None else transaction.transaction_id

    with yt_client.Transaction(transaction_id=parent_transaction_id):
        segment_name = socdem_config.socdem_type_to_segment_name[socdem_type]

        # Extra columns (with their dtypes) to read alongside the vectors.
        columns_with_dtypes = {'crypta_id': np.uint64, segment_name: np.float32}
        table_has_weights = features_utils.check_weight_column_in_table(yt_client, socdem_type, train_table_path)
        if table_has_weights:
            weight_column = features_utils.get_weight_column(socdem_type)
            columns_with_dtypes[weight_column] = np.float32

        helper = SocdemModelTrainHelper(yt_client, logger)
        downloaded_sample = helper.download_sample_with_vectors(
            source_table=train_table_path,
            additional_columns=columns_with_dtypes,
            max_sample_size=socdem_config.MAX_SAMPLE_SIZE_TO_TRAIN_NN_MODEL,
        )

        # Only the first element of each returned pair is used here.
        splits, _ = helper.prepare_sample(downloaded_sample, segment_name)
        logger.info('Train/Test datasets are prepared')

        fitted_model, _ = helper.train_neuro_model(socdem_type, splits['train'], splits['test'])
        logger.info('NN model fitting finished')

        if output_nn_model_file is None:
            return

        with open(output_nn_model_file, 'wb') as model_file:
            numpy_model = convert_simple_keras_to_numpy(fitted_model)
            # Protocol 2 is the newest pickle protocol that Python 2 can read.
            pickle.dump(numpy_model, model_file, protocol=2)


def calculate_and_send_metrics(
    yt_client,
    socdem_type,
    table_path,
    training=False,
    is_mobile=False,
    send_metrics=True,
    use_thresholds=True,
    transaction=None,
):
    """Compute base metrics for a table of predictions and report or print them.

    Every row of ``table_path`` must provide ``Label`` and one
    ``Probability:Class=<i>`` column per class. A ``Weight`` column is read
    only when it is present in the table schema.

    Args:
        yt_client: YT client used to open the transaction and read the table.
        socdem_type: key of the ``socdem_config`` lookup tables.
        table_path: table with labels and per-class probabilities.
        training: forwarded to ``calculate_base_metrics``.
        is_mobile: forwarded to ``calculate_base_metrics``.
        send_metrics: when true the metrics go to Solomon, otherwise they are printed.
        use_thresholds: when true thresholds are fetched through
            ``utils.get_socdem_thresholds_from_api``; otherwise 0.5 is used
            for every class.
        transaction: optional object whose ``transaction_id`` is passed to
            ``yt_client.Transaction``; with ``None`` the id passed is ``None``.
    """
    parent_transaction_id = None if transaction is None else transaction.transaction_id

    with yt_client.Transaction(transaction_id=parent_transaction_id):
        rows_total = yt_client.row_count(table_path)
        segment_name = socdem_config.socdem_type_to_segment_name[socdem_type]
        classes_count = len(socdem_config.segment_names_by_label_type[segment_name])

        one_hot_labels = np.zeros((rows_total, classes_count))
        class_probabilities = np.zeros((rows_total, classes_count))

        table_schema = features_utils.get_schema_dict_from_table(yt_client, table_path)
        row_weights = None
        if 'Weight' in table_schema:
            row_weights = np.zeros(rows_total)

        # Column names do not depend on the row, so build them once.
        probability_columns = ['Probability:Class={}'.format(class_idx) for class_idx in range(classes_count)]

        for position, record in enumerate(yt_client.read_table(table_path)):
            label_row = one_hot_labels[position]
            label_row[int(record['Label'])] = 1

            probability_row = class_probabilities[position]
            for class_idx, column_name in enumerate(probability_columns):
                probability_row[class_idx] = record[column_name]

            if row_weights is not None:
                row_weights[position] = record['Weight']

        if not use_thresholds:
            thresholds = [0.5] * classes_count
        else:
            ordered_thresholds = get_socdem_ordered_thresholds(utils.get_socdem_thresholds_from_api())
            thresholds_key = socdem_config.socdem_type_to_yet_another_segment_name[socdem_type]
            thresholds = ordered_thresholds[thresholds_key]

        computed_metrics = calculate_base_metrics(
            target_type=socdem_type,
            thresholds=thresholds,
            predictions=class_probabilities,
            labels=one_hot_labels,
            training=training,
            weights=row_weights,
            is_socdem=True,
            is_mobile=is_mobile,
        )

        if not send_metrics:
            six.print_(computed_metrics)
        else:
            utils.report_ml_metrics_to_solomon(
                service=socdem_config.SOLOMON_SERVICE,
                metrics_to_send=computed_metrics,
            )


def get_training_sample_for_socdem_type(
    yt_client,
    yql_client,
    config_for_training,
    logger,
    transaction=None,
):
    """Build the raw training sample for the configured socdem type.

    Rows of ``PathsInfo.CommonTrainSample`` whose segment column is not NULL
    are selected with a YQL query. For every type except ``'income'`` they are
    written straight to ``PathsInfo.RawTrainingSampleBySocdem``. For
    ``'income'`` they go to a temporary table first, and
    ``add_weights_to_training_sample`` then receives that table together with
    the final table path; its return value may be reported to Solomon as
    ``ipwe_roc_auc``.

    Args:
        yt_client: YT client used for the transaction, temp table and attributes.
        yql_client: client used to execute the filtering query.
        config_for_training: config providing ``SocdemType``, ``PathsInfo``,
            ``IsMobile``, ``MetricsInfo.SendMetrics`` and ``Date``.
        logger: logger handed to ``add_weights_to_training_sample``.
        transaction: optional object whose ``transaction_id`` is passed to
            ``yt_client.Transaction``; with ``None`` the id passed is ``None``.
    """
    parent_transaction_id = None if transaction is None else transaction.transaction_id

    # A separate name for the context-manager result keeps the ``transaction``
    # argument from being shadowed.
    with yt_client.Transaction(transaction_id=parent_transaction_id) as yt_transaction, \
            yt_client.TempTable() as unweighted_sample:
        needs_weights = config_for_training.SocdemType == 'income'
        paths = config_for_training.PathsInfo

        if needs_weights:
            filtered_table = unweighted_sample
        else:
            filtered_table = paths.RawTrainingSampleBySocdem

        filter_query = filter_training_sample_by_socdem_type.format(
            input_table=paths.CommonTrainSample,
            output_table=filtered_table,
            socdem_segment_type=socdem_config.socdem_type_to_segment_name[config_for_training.SocdemType],
        )
        yql_client.execute(
            query=filter_query,
            transaction=str(yt_transaction.transaction_id),
            title='YQL get training sample by socdem type',
        )

        if needs_weights:
            ipwe_roc_auc = add_weights_to_training_sample(
                yt_client,
                logger,
                filtered_table,
                paths.RawTrainingSampleBySocdem,
                paths.GeneralPopulation,
                is_mobile=config_for_training.IsMobile,
            )

            if config_for_training.MetricsInfo.SendMetrics:
                metric_labels = {
                    'sample': 'training',
                    'socdem': config_for_training.SocdemType,
                    'vectors': 'mobile' if config_for_training.IsMobile else 'web',
                    'metric': 'ipwe_roc_auc',
                }
                utils.report_ml_metrics_to_solomon(
                    service=socdem_config.SOLOMON_SERVICE,
                    metrics_to_send=[{'labels': metric_labels, 'value': ipwe_roc_auc}],
                )

        # Stamp the resulting table with the date it was generated for.
        yt_client.set_attribute(paths.RawTrainingSampleBySocdem, 'generate_date', config_for_training.Date)


def get_catboost_tables_for_training(
    yt_client,
    nn_models,
    features_dict,
    config_for_training,
    transaction=None,
):
    """Prepare catboost train/test tables and the features description file.

    A map operation over ``PathsInfo.RawTrainingSampleBySocdem`` writes
    ``PathsInfo.CatboostTrainSampleBySocdem`` and
    ``PathsInfo.CatboostTestSampleBySocdem``. The column description is then
    created at ``PathsInfo.CatboostFeaturesBySocdem`` and all three outputs
    receive the ``generate_date`` attribute.

    Args:
        yt_client: YT client used for the transaction, map operation and attributes.
        nn_models: passed to the mapper as ``models_list``.
        features_dict: passed to the mapper as ``flat_features_dict`` and used
            to build the categorical features description.
        config_for_training: config providing ``SocdemType``, ``IsMobile``,
            ``UseAdditionalFeatures``, ``PathsInfo`` and ``Date``.
        transaction: optional object whose ``transaction_id`` is passed to
            ``yt_client.Transaction``; with ``None`` the id passed is ``None``.
    """
    parent_transaction_id = None if transaction is None else transaction.transaction_id

    with yt_client.Transaction(transaction_id=parent_transaction_id):
        paths = config_for_training.PathsInfo

        additional_features_description = []
        if config_for_training.UseAdditionalFeatures:
            additional_features_description = features_utils.get_additional_features_description(
                yt_client,
                os.path.dirname(paths.RawTrainingSampleBySocdem),
            )

        has_weights = features_utils.check_weight_column_in_table(
            yt_client,
            config_for_training.SocdemType,
            paths.RawTrainingSampleBySocdem,
        )

        # Mobile and web samples use different mapper classes.
        if config_for_training.IsMobile:
            mapper_class = MakeMobileCatboostTrainingFeatures
        else:
            mapper_class = MakeCatboostTrainingFeatures

        features_mapper = mapper_class(
            socdem_type=config_for_training.SocdemType,
            models_list=nn_models,
            flat_features_dict=features_dict,
            has_weights=has_weights,
            additional_features_number=len(additional_features_description),
            batch_size=1024,
        )
        yt_client.run_map(
            features_mapper,
            paths.RawTrainingSampleBySocdem,
            [
                paths.CatboostTrainSampleBySocdem,
                paths.CatboostTestSampleBySocdem,
            ],
            spec={
                'title': 'Prepare {} samples for catboost training'.format(config_for_training.SocdemType),
            },
        )

        nn_features_description = features_utils.get_nn_output_features_description()
        describe_cat_features = (
            features_utils.get_mobile_features_description
            if config_for_training.IsMobile
            else features_utils.get_features_description
        )
        cat_features_description = describe_cat_features(yt_client, features_dict)

        # Description order: NN outputs, then categorical, then additional features.
        float_features_description = list(itertools.chain(
            nn_features_description,
            cat_features_description,
            additional_features_description,
        ))
        create_cd_file(
            yt_client=yt_client,
            path=paths.CatboostFeaturesBySocdem,
            float_features_description=float_features_description,
            has_weight=has_weights,
        )

        generated_paths = (
            paths.CatboostTrainSampleBySocdem,
            paths.CatboostTestSampleBySocdem,
            paths.CatboostFeaturesBySocdem,
        )
        for generated_path in generated_paths:
            yt_client.set_attribute(generated_path, 'generate_date', config_for_training.Date)
