from functools import partial

import numpy as np
from yt.wrapper import (
    create_table_switch,
    with_context,
)

from crypta.lib.python.custom_ml.tools.training_utils import to_categorical
from crypta.lib.python.yt import yt_helpers
from crypta.profile.lib import vector_helpers
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.mobile_socdem import weighted_mobile_training_sample_schema
from crypta.profile.lib.socdem_helpers.socdem import weighted_socdem_sample_schema
from crypta.profile.lib.socdem_helpers.train_utils.models import get_custom_neuro_model


@with_context
def make_ipwe_sample_mapper(row, context):
    """
    Label every row with the input table it was read from.

    Rows coming from the first input table (``context.table_index == 0``)
    get ``population = 0``; rows from any other input table get
    ``population = 1``. The row is mutated in place and yielded back.
    """
    from_other_table = context.table_index != 0
    row['population'] = int(from_other_table)
    yield row


class WeightsTrainer(object):
    """
    Train an IPWE (inverse probability weighting) classifier and weight rows with it.

    The classifier is trained to separate rows labelled ``population == 0``
    from rows labelled ``population == 1``. Its predicted probability of
    class 1 is then written to the rows being weighted as
    ``income_segment_weight``.
    """

    def __init__(self, yt_client):
        # YT client used for every table read/write performed by the trainer.
        self.yt_client = yt_client

    @staticmethod
    def train_ipwe_model(features, target_population, random_seed=None):
        """
        Fit a two-class neural model that separates the two populations.

        :param features: 2-D feature matrix, one row per example.
        :param target_population: per-row labels (0 or 1), aligned with ``features``.
        :param random_seed: seed forwarded both to the model factory and to the
            hold-out split, so a fixed seed gives a reproducible split and score.
        :return: tuple ``(model, roc_auc)``; ``roc_auc`` is measured on a
            stratified 10% hold-out that is not passed to ``fit``.
        """
        from sklearn.model_selection import train_test_split
        from sklearn.metrics import roc_auc_score
        from tensorflow.keras.callbacks import EarlyStopping

        # ``random_state`` was previously not set, so the hold-out split (and
        # therefore the reported ROC AUC) changed between runs even when a
        # seed was supplied. ``None`` keeps the old unseeded behaviour.
        x_train, x_test, y_train, y_test = train_test_split(
            features, target_population, test_size=0.1, stratify=target_population,
            random_state=random_seed,
        )

        nn = get_custom_neuro_model(n_classes=2, random_seed=random_seed)
        nn.fit(
            x_train,
            to_categorical(y_train),
            verbose=2,
            epochs=30,
            # Keras carves this validation part out of x_train; EarlyStopping
            # watches its default monitor (``val_loss``) on it.
            validation_split=0.1,
            batch_size=4096,
            callbacks=[EarlyStopping(patience=2)],
        )

        y_pred = nn.predict(x_test)
        ipwe_roc_auc = roc_auc_score(to_categorical(y_test), y_pred)

        return nn, ipwe_roc_auc

    def download_ipwe_sample(self, source_table, n_features=512):
        """
        Download sample for IPWE training or application.

        The whole table is read into memory.

        :param source_table: YT table to read.
        :param n_features: width of the vector returned by
            ``vector_helpers.vector_row_to_features`` for one row (512 by default).
        :return: tuple ``(features, target_population, rows)``:
            ``features`` is a float32 array of shape ``(row_count, n_features)``;
            ``target_population`` is an array of the rows' ``population`` values
            (empty when no row carries one); ``rows`` is the list of rows as read.
        :raises ValueError: if some, but not all, rows carry a ``population`` value.
        """
        rows, target_population = [], []
        n_records = self.yt_client.row_count(source_table)
        # Preallocated from the table's row count and filled in place by index.
        features = np.zeros((n_records, n_features), dtype=np.float32)

        for i, row in enumerate(self.yt_client.read_table(source_table)):
            features[i] = vector_helpers.vector_row_to_features(row)
            population = row.get('population')
            if population is not None:
                target_population.append(population)
            rows.append(row)

        # Labels are all-or-nothing: a partial list would be misaligned with
        # ``features``. Raised explicitly (not ``assert``) so the check is not
        # stripped when Python runs with ``-O``.
        if target_population and len(target_population) != n_records:
            raise ValueError(
                'Target population list is not empty but not equal to number of table records.'
            )

        return features, np.array(target_population), rows

    def compute_income_weights(self, ipwe_sample, from_income_to_income_sample, destination_sample_table,
                               random_seed=None):
        """
        Train the IPWE model and append the weighted income rows to a table.

        :param ipwe_sample: table whose rows carry a ``population`` label (0/1);
            used to train the classifier.
        :param from_income_to_income_sample: table whose rows receive a weight.
        :param destination_sample_table: table the weighted rows are appended to.
        :param random_seed: forwarded to :meth:`train_ipwe_model`.
        :return: ROC AUC of the IPWE model on its hold-out split.
        """
        ipwe_train_vectors, target_population, _ = self.download_ipwe_sample(ipwe_sample)
        nn_ipwe, ipwe_roc_auc = self.train_ipwe_model(ipwe_train_vectors, target_population, random_seed=random_seed)

        income_train_vectors, _, income_train_rows = self.download_ipwe_sample(from_income_to_income_sample)
        # Column 1 matches label 1 (``population == 1``) in the ``to_categorical``
        # one-hot order used for training.
        normality_probabilities = nn_ipwe.predict(income_train_vectors)[:, 1]

        for row, probability in zip(income_train_rows, normality_probabilities):
            row['income_segment_weight'] = float(probability)

        self.yt_client.write_table(
            self.yt_client.TablePath(destination_sample_table, append=True),
            income_train_rows,
        )

        return ipwe_roc_auc


def normalize_dict(dictionary):
    """
    Rescale the values of ``dictionary`` into shares of their total.

    Returns a new dict mapping each key to ``float(value) / total``. When the
    total is falsy (zero, or the dict is empty) the input dict itself is
    returned untouched.
    """
    total = sum(dictionary.values())
    if not total:
        return dictionary

    shares = {}
    for key, value in dictionary.items():
        shares[key] = float(value) / total
    return shares


def filter_without_income(row):
    """
    Route a row by whether it has an income segment.

    Rows whose ``income_segment`` is None are sent to output table 0; every
    other row is sent to output table 1.
    """
    has_income = row['income_segment'] is not None
    target_table = 1 if has_income else 0
    yield create_table_switch(target_table)
    yield row


def divide_source_table(row, test_size):
    """
    Randomly route a row to one of two output tables.

    A fresh ``random.random()`` draw is made per row. The row goes to table 0
    when the draw exceeds ``test_size``, and to table 1 otherwise. For
    ``test_size`` in [0, 1] this sends about a ``test_size`` share of rows to
    table 1.
    """
    import random

    keep_in_first = random.random() > test_size
    yield create_table_switch(0 if keep_in_first else 1)
    yield row


def add_weights_to_training_sample(
    yt_client,
    logger,
    sample_without_weights,
    sample_with_weights,
    general_population_table=None,
    is_mobile=False,
    random_seed=None,
):
    """
    Create training sample with weights for income.
    Note: rows without income will be saved in the table without income weights.

    Pipeline:
      1. ``sample_with_weights`` is created with the weighted (mobile or
         desktop) schema, and ``sample_without_weights`` is split. Rows whose
         ``income_segment`` is None go to ``sample_with_weights`` unweighted;
         the rest become the income training sample.
      2. The income sample is subsampled to
         ``socdem_config.INCOME_SAMPLING_SIZE`` rows, then split again. A share
         (3% when it has more than 10,000 rows, 50% otherwise) is reserved for
         IPWE training; the remainder is the part that gets weighted.
      3. The general population is sampled down to the size of the IPWE
         income share, keeping only the id columns and ``vector``. It comes
         from ``general_population_table`` or, by default, the rows without
         income.
      4. An IPWE model is trained to separate income rows (population 0) from
         general-population rows (population 1). The weighted income rows are
         appended to ``sample_with_weights`` with ``income_segment_weight``
         set from that model.
      5. ``sample_with_weights`` is sorted by the id columns.

    :param yt_client: YT client used for all operations.
    :param logger: logger receiving progress messages.
    :param sample_without_weights: source training sample table.
    :param sample_with_weights: destination table, created via
        ``yt_helpers.create_empty_table``.
    :param general_population_table: optional table to draw the general
        population from. Defaults to the rows of ``sample_without_weights``
        that have no income segment.
    :param is_mobile: use the mobile schema with ``('id', 'id_type')`` ids
        instead of the desktop schema with ``('yandexuid',)``.
    :param random_seed: forwarded to IPWE model training only. The row split
        done by ``divide_source_table`` is not seeded.
    :return: ROC AUC of the IPWE model on its hold-out split.
    """
    with yt_client.TempTable() as income_training_sample, \
            yt_client.TempTable() as from_income_to_income_sample, \
            yt_client.TempTable() as from_income_to_ipwe_sample, \
            yt_client.TempTable() as ipwe_population_table, \
            yt_client.TempTable() as ipwe_sample:

        schema_with_weights = weighted_mobile_training_sample_schema if is_mobile else weighted_socdem_sample_schema
        ids_columns = ('id', 'id_type') if is_mobile else ('yandexuid',)

        logger.info('Initial sample size: {}'.format(yt_client.row_count(sample_without_weights)))
        yt_helpers.create_empty_table(
            yt_client=yt_client,
            path=sample_with_weights,
            schema=schema_with_weights,
        )

        # Output 0: rows without income (stored unweighted), output 1: rows with income.
        yt_client.run_map(
            filter_without_income,
            source_table=sample_without_weights,
            destination_table=[
                sample_with_weights,
                income_training_sample,
            ],
        )

        logger.info('Training sample size: {}'.format(yt_client.row_count(income_training_sample)))

        yt_helpers.make_sample_with_size(
            yt_client=yt_client,
            source_table=income_training_sample,
            destination_table=income_training_sample,
            size=socdem_config.INCOME_SAMPLING_SIZE,
        )

        # Share of income rows diverted to IPWE training:
        # 3% above 10,000 rows, otherwise half.
        if yt_client.row_count(income_training_sample) > 1e4:
            ipwe_sampling_rate = 0.03
        else:
            ipwe_sampling_rate = 0.5
        # Output 0: rows to be weighted, output 1: income rows for IPWE training.
        yt_client.run_map(
            partial(divide_source_table, test_size=ipwe_sampling_rate),
            source_table=income_training_sample,
            destination_table=[
                from_income_to_income_sample,
                from_income_to_ipwe_sample,
            ],
        )
        logger.info('Examples without income: {}'.format(yt_client.row_count(sample_with_weights)))
        logger.info('Examples in final income training sample: {}'.format(
            yt_client.row_count(from_income_to_income_sample)),
        )

        columns_to_filter = ids_columns + ('vector',)
        if general_population_table is None:
            # At this point ``sample_with_weights`` holds only the rows without
            # income, so those rows are used as the general population. The
            # old message wrongly named ``sample_without_weights`` as the source.
            general_population_table = sample_with_weights
            logger.info('Samples without income from table {} will be used as general population'.format(
                sample_with_weights,
            ))

        # Balance the classes: as many general-population rows as IPWE income rows.
        yt_helpers.make_sample_with_size(
            yt_client=yt_client,
            source_table=yt_client.TablePath(general_population_table, columns=columns_to_filter),
            destination_table=ipwe_population_table,
            size=yt_client.row_count(from_income_to_ipwe_sample),
        )

        logger.info('Examples of general population: {}'.format(
            yt_client.row_count(ipwe_population_table))
        )

        # Input 0 (income rows) -> population 0, input 1 (general population) -> population 1.
        yt_client.run_map(
            make_ipwe_sample_mapper,
            [
                yt_client.TablePath(from_income_to_ipwe_sample, columns=columns_to_filter),
                ipwe_population_table,
            ],
            ipwe_sample,
        )
        yt_client.run_sort(ipwe_sample, sort_by=ids_columns)

        trainer = WeightsTrainer(yt_client=yt_client)
        ipwe_roc_auc = trainer.compute_income_weights(
            ipwe_sample,
            from_income_to_income_sample,
            sample_with_weights,
            random_seed=random_seed,
        )
        logger.info('IPWE roc auc on validation: {}'.format(ipwe_roc_auc))

        yt_client.run_sort(sample_with_weights, sort_by=ids_columns)

        logger.info('Sample building process has finished successfully')

        return ipwe_roc_auc
