import logging

import numpy as np
import pandas as pd

from jafar.pipelines.blocks import SingleContextBlock
from jafar.utils.structarrays import DataFrame, get_null_value

# Module-level logger; LeftHashJoinBlock.get_columns reports column overwrites through it.
logger = logging.getLogger(__name__)


class JoinBlock(SingleContextBlock):
    """
    Join two frames from the context using ``pandas.merge``.

    This is the most flexible implementation, but both inputs are converted to
    pandas and back, so it can be slow and has been observed to leak memory.

    :param left_frame: context key of the left array.
    :param right_frame: context key of the right array.
    :param result_frame: context key under which the merged array is stored.
    :param join_columns: columns to join on (passed as ``on`` to ``pandas.merge``).
    :param how: join type passed to ``pandas.merge`` (default 'outer').
    :param left_columns: optional subset of left columns to keep; join columns
        are always kept as well.
    :param right_columns: optional subset of right columns to keep; join columns
        are always kept as well.
    """

    def __init__(self, left_frame, right_frame, result_frame, join_columns, how='outer',
                 left_columns=None, right_columns=None):
        super(JoinBlock, self).__init__(
            input_data=[left_frame, right_frame], output_data=None, destroyed_data=None
        )
        self.left_frame = left_frame
        self.right_frame = right_frame
        self.result_frame = result_frame
        self.join_columns = join_columns
        self.how = how
        self.left_columns = left_columns
        self.right_columns = right_columns

    def merge_arrays(self, left_array, right_array):
        """Merge two arrays through ``pandas.merge`` and convert the result back to a DataFrame."""
        left = left_array.to_pandas()
        right = right_array.to_pandas()
        return DataFrame.from_pandas(pd.merge(left, right, on=self.join_columns, how=self.how))

    def _select_columns(self, array, columns):
        """
        Return ``array`` restricted to the join columns followed by ``columns``.

        Duplicates are dropped while keeping first-seen order, so the resulting
        column order is deterministic (a set union would make it arbitrary).
        """
        selected = []
        for column in list(self.join_columns) + list(columns):
            if column not in selected:
                selected.append(column)
        return array[selected]

    def get_merging_arrays(self, context):
        """
        Fetch the left and right arrays from ``context.data``.

        Each side is optionally projected onto its configured column subset
        (plus the join columns).

        :returns: ``(left_array, right_array)``.
        """
        right_array = context.data[self.right_frame]
        if self.right_columns:
            right_array = self._select_columns(right_array, self.right_columns)
        left_array = context.data[self.left_frame]
        if self.left_columns:
            left_array = self._select_columns(left_array, self.left_columns)
        return left_array, right_array

    def apply(self, context, train):
        """Merge the configured frames and store the result under ``result_frame``."""
        left_array, right_array = self.get_merging_arrays(context)
        context.data[self.result_frame] = self.merge_arrays(left_array, right_array)
        return context


class LeftHashJoinBlock(JoinBlock):
    """
    Left or inner merge of two arrays without going through pandas.

    The right array is assumed to have unique values in the join columns.
    A dict mapping each right-side key tuple to its row index is built, and
    every left row is looked up in it. If the right side contains duplicate
    keys, the last occurrence silently wins.

    :param overwrite: when a column selected from the right array is also among
        the selected left columns, take it from the right array (True) instead
        of raising ValueError (False).
    :raises ValueError: if ``how`` is neither 'left' nor 'inner'.
    """

    def __init__(self, overwrite=False, **kwargs):
        super(LeftHashJoinBlock, self).__init__(**kwargs)
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if self.how not in ('left', 'inner'):
            raise ValueError('LeftHashJoinBlock supports only left and inner joins')
        self.overwrite = overwrite

    def get_merging_arrays(self, context):
        """ Do not make any copy here; column selection happens in ``get_columns``. """
        return context.data[self.left_frame], context.data[self.right_frame]

    def merge_arrays(self, left_array, right_array):
        """
        Join ``left_array`` with ``right_array`` via a hash lookup on the join columns.

        :returns: a new DataFrame with the selected left columns followed by the
            selected right columns. For 'left' joins it has one row per left row,
            and right columns of unmatched rows are filled with ``get_null_value``.
            For 'inner' joins only the matched left rows are kept.
        """
        left_columns, right_columns = self.get_columns(left_array, right_array)
        result_columns = left_columns + right_columns

        hash_map = self.build_hash_map(right_array)
        left_keys = self.get_hash_map_keys(left_array)

        # Row index into right_array for every left row; -1 marks "no match".
        # np.intp (the native index type) so indexes cannot overflow on huge arrays.
        right_indexes = np.fromiter((hash_map.get(key, -1) for key in left_keys), dtype=np.intp)
        available_in_right = right_indexes >= 0
        matched_indexes = right_indexes[available_in_right]

        if self.how == 'left':
            resulting_length = len(left_array)
        else:
            resulting_length = np.count_nonzero(available_in_right)

        # dtype.fields values are (dtype, offset) or (dtype, offset, title), so
        # take element 0. ``.items()`` works on Python 2 and 3, unlike
        # ``.iteritems()``. The right array is visited last, so its types win for
        # columns present on both sides (the overwrite case).
        column_types = {name: field[0] for arr in (left_array, right_array)
                        for name, field in arr.dtype.fields.items()}
        result_dtype = [(name, column_types[name]) for name in result_columns]
        new_array = DataFrame.from_structarray(np.empty(resulting_length, dtype=result_dtype))

        if self.how == 'left':
            for column in left_columns:
                new_array[column] = left_array[column]
            for column in right_columns:
                new_array[column][available_in_right] = right_array[column][matched_indexes]
                new_array[column][~available_in_right] = get_null_value(column_types[column])

        elif self.how == 'inner':
            for column in left_columns:
                new_array[column] = left_array[column][available_in_right]
            for column in right_columns:
                new_array[column] = right_array[column][matched_indexes]

        return new_array

    def get_hash_map_keys(self, array):
        """ Yield, for every row of ``array``, the tuple of its values in ``self.join_columns``. """
        for row in array:
            yield tuple(row[column] for column in self.join_columns)

    def build_hash_map(self, right_array):
        """ Map each join-key tuple of ``right_array`` to its row index (last duplicate wins). """
        # TODO: Good place for Cython optimization here?
        return {key: idx for idx, key in enumerate(self.get_hash_map_keys(right_array))}

    def get_columns(self, left_array, right_array):
        """
        Decide which columns of each array go into the result.

        The left side defaults to every left column. The right side defaults to
        every right column that is not a join column. Columns selected on both
        sides raise ValueError unless ``overwrite`` is set. With ``overwrite``,
        they are dropped from the left side so the right-array values are used.
        Column order is preserved.

        :returns: ``(left_columns, right_columns)`` as lists.
        :raises ValueError: on overlapping columns when ``overwrite`` is False.
        """
        # list(...) so user-supplied tuples can be concatenated in merge_arrays.
        left_columns = list(self.left_columns or left_array.dtype.names)
        right_columns = list(self.right_columns or [column for column in right_array.dtype.names
                                                    if column not in self.join_columns])
        common_columns = set(left_columns) & set(right_columns)
        if common_columns:
            common_names = ','.join(sorted(common_columns))
            if not self.overwrite:
                raise ValueError('Columns %s are represented both in left '
                                 'and right arrays and not in join_columns' % common_names)
            logger.debug('Columns %s are represented both in left and right arrays '
                         'and not in join_columns, left array column will be overwritten', common_names)
            # Keep the caller's column order; a set difference would scramble it.
            left_columns = [column for column in left_columns if column not in common_columns]
        return left_columns, right_columns
