from abc import (
    ABCMeta,
    abstractmethod,
)
import itertools

import numpy as np
import pandas as pd

from crypta.lib.python.custom_ml.tools import training_utils
from crypta.lib.python.custom_ml.training_config import (
    LAB_SEGMENTS_INFO_TABLE,
    VALIDATION_SAMPLE_PERCENTAGE,
    VALIDATION_SAMPLE_REST,
    VECTOR_SIZE,
)


class BaseModelTrainHelper(object):
    """
    Base class with shared helpers for preparing custom ML training data.

    Subclasses must implement ``download_sample_with_vectors``.
    """
    # NOTE(review): the ``__metaclass__`` attribute is only honoured by Python 2.
    # Under Python 3 it is ignored, so ``@abstractmethod`` is not enforced there.
    __metaclass__ = ABCMeta

    def __init__(self, yt, logger, vector_size=VECTOR_SIZE):
        """
        Parameters:
            yt: client object; ``get_lab_segments_df`` calls ``yt.read_table`` on it.
            logger: stored as ``self.logger`` (not used by the methods of this class itself).
            vector_size: stored as ``self.vector_size``; defaults to ``VECTOR_SIZE``
                from the training config.
        """
        self.yt = yt
        self.logger = logger
        self.vector_size = vector_size

    @abstractmethod
    def download_sample_with_vectors(self, source_table, additional_columns, max_sample_size=None):
        """
        Download sample_with_vectors from ``source_table``.

        additional_columns: dict mapping column name to data type
            (presumably the same mapping accepted by ``init_additional_columns``).
        max_sample_size: optional limit on the sample size; semantics are defined by subclasses.
        """
        pass

    def get_lab_segments_df(self):
        """
        Read ``LAB_SEGMENTS_INFO_TABLE`` into a pandas DataFrame.

        The ``exportKeywordId`` and ``exportSegmentId`` columns are converted to strings
        of integers; missing values become ``'-1'``.
        """
        lab_segments_df = pd.DataFrame(list(self.yt.read_table(LAB_SEGMENTS_INFO_TABLE)))
        # Missing ids are filled with -1 first so the int cast does not fail on NaN.
        for id_column in ('exportKeywordId', 'exportSegmentId'):
            lab_segments_df[id_column] = lab_segments_df[id_column].fillna(-1.0).astype(int).astype(str)
        return lab_segments_df

    @staticmethod
    def init_additional_columns(additional_columns, n_records):
        """
        Allocate one uninitialised numpy array of length ``n_records`` per column.

        Parameters:
            additional_columns: dict {column name: numpy dtype}
            n_records: length of every allocated array

        Returns:
            dict {column name: np.empty(n_records, dtype)}; contents are NOT initialised.
        """
        return {
            column_name: np.empty(n_records, dtype=column_type)
            for column_name, column_type in additional_columns.items()
        }

    @staticmethod
    def get_train_test_datasets(features, columns, target_columns, target_values_range, weight_names=None):
        """
        Function to prepare datasets for training and testing.

        Rows are split by ``columns['crypta_id'] % VALIDATION_SAMPLE_PERCENTAGE``: rows equal to
        ``VALIDATION_SAMPLE_REST`` go to 'test', the rest to 'train'. Only rows where at least
        one target column is not NaN are kept.

        Parameters:
            features: list of lists or numpy array of shape (number of examples, number of features)
            columns: dict
                {column name: array of shape (number of examples)}; must contain 'crypta_id'.
            target_columns: str or tuple of str
                Names of columns that are considered to be targets.
            target_values_range: list of int or list of lists of int
                For each target column it is checked if all values from target_values_range are present
                in computed train and test datasets. ``None`` for a target skips the check and
                the categorical conversion.
            weight_names: str or tuple of str
                Names of columns that are considered to be sample weights for different targets.
                Names missing from ``columns`` are treated as "no weights". The argument itself
                is never modified.

        Returns:
            dict {'train': [features, target_1, weights_1, target_2, weights_2, ...],
                  'test': [...]}
            where each target is converted with ``training_utils.to_categorical`` when its
            values range is given, and each weights entry is the NaN-to-zero weight column
            or ``None``.
        """
        if isinstance(target_columns, str):
            target_columns = [target_columns]
            target_values_range = [target_values_range]
            weight_names = [weight_names]
        elif weight_names is None:
            weight_names = [None] * len(target_columns)

        # Check if given weight names exist in columns. A new list is built so that tuples
        # (as documented) are supported and the caller's list is not mutated.
        weight_names = [name if name in columns else None for name in weight_names]

        # A row is usable if it has a value for at least one of the targets.
        has_target = np.full(len(features), False, dtype=bool)
        for target in target_columns:
            has_target = np.logical_or(
                has_target,
                ~np.isnan(columns[target]),
            )

        # Final per-split row masks, computed once and reused for features, targets and weights.
        split_masks = {
            'train': (columns['crypta_id'] % VALIDATION_SAMPLE_PERCENTAGE != VALIDATION_SAMPLE_REST) & has_target,
            'test': (columns['crypta_id'] % VALIDATION_SAMPLE_PERCENTAGE == VALIDATION_SAMPLE_REST) & has_target,
        }

        if isinstance(features, list):
            # Plain lists do not support boolean-mask indexing.
            datasets = {
                dataset_type: [list(itertools.compress(features, mask))]
                for dataset_type, mask in split_masks.items()
            }
        else:
            datasets = {
                dataset_type: [features[mask]]
                for dataset_type, mask in split_masks.items()
            }

        for target, weight, values in zip(target_columns, weight_names, target_values_range):
            for dataset_type in ('train', 'test'):
                mask = split_masks[dataset_type]
                targets = columns[target][mask]
                if values is not None:
                    # NOTE(review): presumably raises if a value from ``values`` is absent; confirm in training_utils.
                    training_utils.check_array_to_have_all_values(targets, values)
                    datasets[dataset_type].append(training_utils.to_categorical(targets))
                else:
                    datasets[dataset_type].append(targets)
                weights = np.nan_to_num(columns[weight][mask]) if weight is not None else None
                datasets[dataset_type].append(weights)

        return datasets
