import numpy as np

from crypta.lib.python.custom_ml.train_helpers.base_train_helper import BaseModelTrainHelper
from crypta.lib.python.custom_ml.tools.metrics import (
    classes_accuracy,
    sure_classes_ratio,
)
from crypta.profile.lib import (
    profile_helpers,
    vector_helpers,
)
from crypta.profile.lib.socdem_helpers import socdem_config
from crypta.profile.lib.socdem_helpers.tools import features as features_utils
from crypta.profile.lib.socdem_helpers.train_utils import models as models_utils
from crypta.profile.utils import utils


def get_socdem_ordered_thresholds(classification_thresholds):
    """Reorder per-class thresholds into lists following the configured class order.

    Only target types present in
    ``socdem_config.yet_another_segment_names_by_label_type`` are kept; for each
    of them the result holds ``classification_thresholds[target_type][label]``
    for every label, in the order listed by that config mapping.

    :param classification_thresholds: mapping ``target_type -> {class_label: threshold}``.
    :return: dict ``target_type -> [threshold, ...]`` in configured label order.
    :raises KeyError: if a configured label is missing for a kept target type.
    """
    label_order = socdem_config.yet_another_segment_names_by_label_type
    return {
        target_type: [
            classification_thresholds[target_type][label]
            for label in label_order[target_type]
        ]
        for target_type in classification_thresholds
        if target_type in label_order
    }


def get_flat_cat_features_dict(nested_dict, cat_feature_keys):
    """Flatten per-type categorical index mappings into one ``feature_name -> index`` dict.

    Segment types are processed in the order of ``cat_feature_keys``. Every
    ``segment_id -> value`` pair of ``nested_dict[segment_type]`` becomes
    ``feature_name -> value + offset``. Here ``offset`` is the total number of
    entries in all previously processed segment types, which moves each type's
    indices past those of the earlier types. This yields disjoint ranges
    presumably only when each inner mapping's values are 0-based and dense
    (NOTE(review): confirm against callers).

    The feature name is built by ``features_utils.get_feature_name``. Its feature
    type is ``features_utils.categorical_feature_name_to_keyword[segment_type]``
    when that mapping contains the segment type; otherwise it is the segment type
    itself.

    :param nested_dict: mapping ``segment_type -> {segment_id: index}``.
    :param cat_feature_keys: ordered iterable of segment types to include.
    :return: flat dict ``feature_name -> shifted index``.
    """
    keyword_by_name = features_utils.categorical_feature_name_to_keyword
    flat_dict = {}
    offset = 0
    for segment_type in cat_feature_keys:
        segment_values = nested_dict[segment_type]
        # The feature type depends only on the segment type, so resolve it once
        # per type instead of once per segment id.
        if segment_type in keyword_by_name:
            feature_type = keyword_by_name[segment_type]
        else:
            feature_type = segment_type

        for segment_id, value in segment_values.items():
            key = features_utils.get_feature_name(feature_type=feature_type, value=segment_id)
            flat_dict[key] = value + offset
        offset += len(segment_values)
    return flat_dict


class SocdemModelTrainHelper(BaseModelTrainHelper):
    """
    Class to train neural socdem models.

    On construction it fetches socdem classification thresholds via
    ``utils.get_socdem_thresholds_from_api()``. It stores them, reordered into
    per-label-type lists, in ``self.ordered_thresholds``; these are used for the
    sure-ratio and accuracy metrics in ``train_neuro_model``.
    """
    def __init__(self, yt, logger, vector_size=socdem_config.VECTOR_SIZE):
        super(SocdemModelTrainHelper, self).__init__(yt=yt, logger=logger, vector_size=vector_size)
        socdem_thresholds = utils.get_socdem_thresholds_from_api()
        self.ordered_thresholds = get_socdem_ordered_thresholds(socdem_thresholds)

    def download_sample_with_vectors(self, source_table, additional_columns, max_sample_size=None):
        """Download sample with features.

        Reads the ``vector`` column plus every column named in
        ``additional_columns`` from ``source_table``. At most
        ``max_sample_size`` rows are read when that argument is given.

        The ``gender``, ``age_segment`` and ``income_segment`` columns are
        encoded to integer ids through the ``profile_helpers`` name-to-id
        mappings. ``None`` values in these columns are stored as ``None``.
        Every other column is stored as read.

        :param source_table: table path to read from.
        :param additional_columns: dict with name and data type of the extra columns.
        :param max_sample_size: optional upper bound on the number of rows read.
        :return: tuple ``(features, additional_columns)``:
            - ``features`` is a float32 array of shape ``(n_records, self.vector_size)``.
            - ``additional_columns`` is the container returned by
              ``self.init_additional_columns``, filled row by row.
        """
        columns_to_read = {'vector'}
        columns_to_read.update(additional_columns.keys())

        source_table = self.yt.TablePath(source_table, end_index=max_sample_size, columns=list(columns_to_read))
        n_records = self.yt.row_count(source_table)
        if max_sample_size is not None:
            n_records = min(n_records, max_sample_size)

        self.logger.info('Columns to read: {}'.format(columns_to_read))
        self.logger.info('Downloading {} records from {}'.format(n_records, source_table))

        features = np.zeros((n_records, self.vector_size), dtype=np.float32)
        additional_columns = self.init_additional_columns(additional_columns=additional_columns, n_records=n_records)

        # Label columns read as names that must be encoded to integer ids.
        label_encoders = {
            'gender': profile_helpers.gender_name_to_id,
            'age_segment': profile_helpers.six_segment_age_name_to_id,
            'income_segment': profile_helpers.five_segment_income_name_to_id,
        }
        # Log progress about every 5% of rows (every row for very small samples).
        log_every = max(round(float(n_records) * 0.05), 1)

        for i, row in enumerate(self.yt.read_table(source_table)):
            if not (i % log_every):
                self.logger.info('%.1f%%' % (100 * (i + 1) / float(n_records)))

            features[i] = vector_helpers.vector_row_to_features(row)

            for column in additional_columns:
                value = row[column]
                encoder = label_encoders.get(column)
                # Missing labels are kept as None rather than encoded.
                if encoder is not None and value is not None:
                    value = encoder[value]
                additional_columns[column][i] = value

        self.logger.info('Downloading finished!')

        return features, additional_columns

    def log_metrics(self, socdem_type, metrics):
        """Log each metric as ``'<socdem_type> <metric>:<value>'`` with three decimals."""
        for metric, value in metrics.items():
            self.logger.info('{} {}:{:.3f}'.format(socdem_type, metric, value))

    def prepare_sample(self, sample, socdem_type):
        """Normalize weights, compute class-balance metrics and split train/test.

        :param sample: tuple ``(features, columns)``, as returned by
            ``download_sample_with_vectors``.
        :param socdem_type: name of the target column; ``'<socdem_type>_weight'``
            is used as the weight column.
        :return: tuple ``(datasets, datasets_metrics)``:
            - ``datasets`` comes from ``self.get_train_test_datasets``.
            - ``datasets_metrics`` holds the number of finite targets and the
              share of each class.
        """
        features, columns = sample

        # Normalize weights so that the mean of the non-NaN weights becomes 1.
        weight_column = '{}_weight'.format(socdem_type)
        if weight_column in columns:
            weights = columns[weight_column][~np.isnan(columns[weight_column])]
            weights_sum = weights.sum()
            if weights_sum > 0:
                columns[weight_column] *= len(weights) / weights_sum
            else:
                # Dividing by a zero/negative sum would turn every weight into inf/nan.
                self.logger.warning(
                    'Weights in {} sum to {}; skipping normalization'.format(weight_column, weights_sum),
                )

        class_names = socdem_config.segment_names_by_label_type[socdem_type]
        targets = columns[socdem_type]
        sample_size = np.isfinite(targets).sum()
        # NOTE: the two key prefixes differ ("with_vectors" vs "with.vectors");
        # kept as-is because downstream consumers may rely on them.
        datasets_metrics = {
            'sample.with_vectors.has_{}'.format(socdem_type): sample_size,
        }
        for target, class_name in enumerate(class_names):
            datasets_metrics['sample.with.vectors.{}_{}'.format(socdem_type, class_name)] = \
                1. * (targets == target).sum() / sample_size

        datasets = self.get_train_test_datasets(
            features=features,
            columns=columns,
            target_columns=socdem_type,
            target_values_range=list(range(len(class_names))),
            weight_names=weight_column,
        )
        self.log_metrics(socdem_type, datasets_metrics)

        return datasets, datasets_metrics

    def train_neuro_model(self, socdem_type,
                          train_dataset, test_dataset,
                          neuro_offset=socdem_config.VECTOR_SIZE):
        """Train a Keras classifier on the first ``neuro_offset`` feature columns.

        After training, the model is converted with
        ``models_utils.convert_simple_keras_to_numpy``. Its predictions on 100
        test rows are compared with the Keras predictions.

        :param socdem_type: socdem target name (used for config lookups and metric names).
        :param train_dataset: tuple ``(x, y, weight)`` for training.
        :param test_dataset: tuple ``(x, y, weight)`` for evaluation.
        :param neuro_offset: number of leading feature columns fed to the network.
        :return: tuple ``(nn, nn_metrics)``:
            - ``nn`` is the trained Keras model.
            - ``nn_metrics`` holds the train/test log loss, sure ratio and accuracy.
        :raises AssertionError: if the numpy conversion does not reproduce the
            Keras predictions within ``atol=1e-4``.
        """
        from tensorflow.keras.callbacks import EarlyStopping
        from sklearn.metrics import log_loss

        socdem_segment = socdem_config.socdem_type_to_segment_name[socdem_type]
        nn = models_utils.get_custom_neuro_model(
            n_classes=len(socdem_config.segment_names_by_label_type[socdem_segment]),
            neuro_offset=neuro_offset,
        )

        x_train, y_train, weight_train = train_dataset
        x_test, y_test, weight_test = test_dataset

        nn.fit(
            x_train[:, :neuro_offset],
            y_train,
            sample_weight=weight_train,
            verbose=2,
            epochs=30,
            validation_split=0.1,
            batch_size=4096,
            callbacks=[EarlyStopping(patience=2)],
        )

        # Check numpy transformation correctness. Raise explicitly (not `assert`)
        # so the check is not stripped when Python runs with -O.
        numpy_nn = models_utils.convert_simple_keras_to_numpy(nn)
        predicted_keras = nn.predict(x_test[:100, :neuro_offset])
        predicted_numpy = numpy_nn.predict(x_test[:100, :neuro_offset])
        if not np.allclose(predicted_keras, predicted_numpy, atol=1e-4):
            raise AssertionError(
                '{} is wrong after conversion to numpy {} {}'.format(
                    socdem_type, predicted_keras, predicted_numpy,
                ),
            )

        # Reporting metrics
        predicted_train = nn.predict(x_train[:, :neuro_offset], batch_size=10000)
        predicted_test = nn.predict(x_test[:, :neuro_offset], batch_size=10000)

        thresholds = self.ordered_thresholds[socdem_config.socdem_type_to_yet_another_segment_name[socdem_type]]
        nn_metrics = {
            '{}_train_logloss'.format(socdem_type): log_loss(
                y_train, predicted_train, sample_weight=weight_train,
            ),
            '{}_test_logloss'.format(socdem_type): log_loss(
                y_test, predicted_test, sample_weight=weight_test,
            ),
            '{}_sure_ratio'.format(socdem_type): sure_classes_ratio(
                predicted_test,
                thresholds,
                sample_weight=weight_test,
            ),
            '{}_accuracy'.format(socdem_type): classes_accuracy(
                y_test,
                predicted_test,
                thresholds,
                sample_weight=weight_test,
            ),
        }

        self.log_metrics(socdem_type, nn_metrics)
        return nn, nn_metrics
