import logging

import numpy as np

from jafar.pipelines.blocks.composite import CompositeBlock
from jafar.pipelines.context import EmptyPipelineContext
from jafar.utils.structarrays import DataFrame

# Integer passed as the default to the forward (value -> ID) mapping in
# MappingBlock.apply_init; rows whose mapped column equals it are treated as unmappable.
MISSING_ID = -1
# Default passed to the reverse (ID -> value) user mapping in
# MappingBlock.apply_complete; rows carrying it are kept.
MISSING_USER = '00000000-0000-0000-0000-000000000000'
# Default passed to the reverse (ID -> value) item mapping in
# MappingBlock.apply_complete; rows carrying it are filtered out.
MISSING_ITEM = ''

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class MappingBlock(CompositeBlock):
    """
    Mapping block converts "user" and "item" field values
    to integers and back (essentially the same as sklearn's LabelEncoder).

    ``apply_init`` maps the configured frames forward (values -> integer IDs)
    and returns the contexts paired with the blocks to run on them;
    ``apply_complete`` maps the frames back (integer IDs -> values), except
    at training time where the result is returned unmapped.
    """

    def __init__(self, nested_blocks, frames,
                 user_columns=('user',), item_columns=('item',),
                 blocks_for_missing_items=None):
        """
        :param nested_blocks: blocks paired with the main (mapped) context
        :param frames: list of the frame keys should be mapped
        :param user_columns: names of the columns holding user values
        :param item_columns: names of the columns holding item values
        :param blocks_for_missing_items: blocks paired with a parallel context
            that holds only the rows whose items could not be mapped; used
            only when not training. ``None`` disables the parallel context.
        """
        super(MappingBlock, self).__init__()
        self.nested_blocks = nested_blocks
        self.frames = frames
        self.user_columns = user_columns
        self.item_columns = item_columns
        self.blocks_for_missing_items = blocks_for_missing_items

    def get_blocks(self):
        """Return the nested blocks given to the constructor."""
        return self.nested_blocks

    def get_mapping_frame_keys(self, context):
        """Return the keys of ``context.data`` that are listed in ``self.frames``."""
        return [key for key in context.data if key in self.frames]

    @staticmethod
    def map_columns(array, columns, storage, mapping_key, reverse, missing, filter_missing, return_missing):
        """
        Replace the values of ``columns`` in ``array`` with their mapped counterparts.

        :param array: frame to map; must expose ``dtype.names``, boolean-mask
            indexing and ``replace_column``
        :param columns: candidate column names; the ones absent from ``array`` are skipped
        :param storage: object providing ``map_values(mapping_key, values, default, reverse)``
        :param mapping_key: key of the mapping inside ``storage``
        :param reverse: passed through to ``storage.map_values``; also selects the
            dtype of the mapped column (``object`` when true, ``int32`` otherwise)
        :param missing: default passed to ``storage.map_values``; a row is treated
            as unmappable when any of its mapped columns equals this value
        :param filter_missing: drop the unmappable rows from the returned array
        :param return_missing: additionally return the unmappable rows, taken
            from ``array`` before its columns are replaced
        :return: ``(array, missing_array)`` if ``return_missing`` else ``array``.
            When none of ``columns`` is present, ``array`` is returned untouched
            and ``missing_array`` is an empty frame with the same dtype.
        """
        present = [column for column in columns if column in array.dtype.names]
        if not present:
            if not return_missing:
                return array
            return array, DataFrame.from_structarray(np.array([], dtype=array.dtype))

        # Builtin ``bool``/``object`` are used instead of the ``np.bool``/``np.object``
        # aliases, which are identical objects but were removed in NumPy 1.24.
        unmapped_mask = np.zeros(len(array), dtype=bool)
        mapped_by_column = {}
        for column in present:
            mapped = storage.map_values(mapping_key=mapping_key, values=array[column],
                                        default=missing, reverse=reverse)
            mapped = np.array(mapped, dtype=object if reverse else np.int32)
            unmapped_mask |= (mapped == missing)
            # the columns are replaced later, once the unmappable rows are extracted
            mapped_by_column[column] = mapped

        # get missing array from initial array before replacing ID columns
        missing_array = array[unmapped_mask] if return_missing else None

        for column in present:
            array = array.replace_column(mapped_by_column[column], column)
        if filter_missing:
            array = array[~unmapped_mask]

        return (array, missing_array) if return_missing else array

    @staticmethod
    def iterate_columns(column_set, frames):
        """
        For each frame and for each column name, yields
        corresponding column values, if column is present.
        """
        for frame in frames:
            for column in column_set:
                if column in frame.dtype.names:
                    yield frame[column]

    def apply_init(self, context, train):
        """
        Map user and item columns of the configured frames to integer IDs.

        At training time the user and item mappings are first built from the
        union of the values found in the configured frames, and the sizes of
        both mappings are stored under the ``n_users`` / ``n_items`` keys.

        Rows with unmappable items are removed from the frames; rows with
        unmappable users are kept and get ``MISSING_ID``.

        :param context: pipeline context providing ``data`` and ``storage``
        :param train: whether the pipeline is being trained
        :return: list of ``(context, blocks)`` pairs: the mapped context with the
            nested blocks and, when ``blocks_for_missing_items`` is set and
            ``train`` is false, a parallel context with the unmappable rows
        :raises TypeError: if a configured frame is not a ``DataFrame``
        """
        keep_missing = self.blocks_for_missing_items is not None and not train
        item_map_key = self.key_for(context, 'item_map')
        user_map_key = self.key_for(context, 'user_map')
        if train:
            frame_keys = self.get_mapping_frame_keys(context)
            for column_set, map_key in [(self.user_columns, user_map_key), (self.item_columns, item_map_key)]:
                frames = (context.data[key] for key in frame_keys)
                # collect unique union of each column from each frame
                # to make a comprehensive mapping
                column_values = [np.ravel(values) for values in self.iterate_columns(column_set, frames)]
                if not column_values:
                    # no frame carries these columns, so there is nothing to map
                    continue
                values = np.unique(np.concatenate(column_values))
                if len(values) > 0:
                    context.storage.make_mapping(map_key, values.astype(str))

        user_map = context.storage.get_proxy(user_map_key)
        item_map = context.storage.get_proxy(item_map_key)
        if train:
            context.storage.store(self.key_for(context, 'n_users'), np.array([len(user_map)]))
            context.storage.store(self.key_for(context, 'n_items'), np.array([len(item_map)]))

        missing_arrays = {}
        for key in self.get_mapping_frame_keys(context):
            if not isinstance(context.data[key], DataFrame):
                raise TypeError("%s is not dataframe" % key)
            context.data[key], missing_arrays[key] = self.map_columns(
                context.data[key], columns=self.item_columns, storage=context.storage,
                mapping_key=item_map_key, reverse=False, missing=MISSING_ID,
                filter_missing=True, return_missing=True)

            # Missing users are not filtered because we need to make recommendations for unknown user
            context.data[key] = self.map_columns(
                context.data[key], columns=self.user_columns, storage=context.storage,
                mapping_key=user_map_key, reverse=False, missing=MISSING_ID,
                filter_missing=False, return_missing=False)

        contexts_with_blocks = [(context, self.nested_blocks)]

        # create parallel context with missing items only
        if keep_missing:
            required_item_columns = set(self.item_columns)
            missing_context = context.copy(with_data=False)
            for key, array in context.data.items():
                missing_array = missing_arrays.get(key)
                if missing_array is None or not required_item_columns.issubset(missing_array.dtype.names):
                    # the frame was not mapped or lacks some of the item columns:
                    # duplicate it for the 'missing context'
                    missing_context.data[key] = array.copy()
                else:
                    missing_context.data[key] = missing_array

            contexts_with_blocks.append((missing_context, self.blocks_for_missing_items))
        return contexts_with_blocks

    def apply_complete(self, contexts, train):
        """
        Map integer IDs in the configured frames back to user and item values.

        At training time nothing is mapped and ``contexts[0]`` is returned.
        Otherwise rows whose item ID has no reverse mapping are dropped, users
        without a reverse mapping get ``MISSING_USER``, and, when a parallel
        'missing' context exists, its frame is appended to the mapped frame.

        :param contexts: contexts produced from the pairs returned by
            ``apply_init``; the main context comes first
        :param train: whether the pipeline is being trained
        :return: the main context
        """
        if train:
            logger.info('Do not map users and items from indices at training')
            return contexts[0]

        keep_missing = self.blocks_for_missing_items is not None
        # Parenthesised on purpose: without the parentheses the conditional
        # expression swallows the comparison and the check is skipped
        # whenever keep_missing is false.
        expected_contexts = 2 if keep_missing else 1
        assert len(contexts) == expected_contexts, \
            'expected %d contexts, got %d' % (expected_contexts, len(contexts))
        context = contexts[0]
        # NOTE(review): identity check - presumably EmptyPipelineContext is a sentinel object; confirm
        if context is EmptyPipelineContext:
            return context

        item_map_key = self.key_for(context, 'item_map')
        user_map_key = self.key_for(context, 'user_map')
        for key in self.get_mapping_frame_keys(context):
            context.data[key] = self.map_columns(
                context.data[key], columns=self.item_columns, storage=context.storage,
                mapping_key=item_map_key, reverse=True, missing=MISSING_ITEM,
                filter_missing=True, return_missing=False)

            context.data[key] = self.map_columns(
                context.data[key], columns=self.user_columns, storage=context.storage,
                mapping_key=user_map_key, reverse=True, missing=MISSING_USER,
                filter_missing=False, return_missing=False)

            if keep_missing:
                missing_context = contexts[1]
                missing_array = missing_context.data.get(key)
                if missing_array is not None:
                    context.data[key] = DataFrame.concatenate([context.data[key], missing_array])

        return context
