"""
Static (doesn't support key modifications after creation) numpy-array
based hashtable with an interface of Python's `dict`.
Uses CPMH (http://cmph.sourceforge.net/) perfect hash functions under the hood,
thus keeping it's structure simple (number of buckets is equal to number of keys).
"""

import cPickle as pickle
import json
import logging
import ujson

import numpy as np
from jafar.utils.cmph_wrapper import generate_hash

# Module-level logger; StaticDict reports value-serialization fallbacks through it.
logger = logging.getLogger(__name__)


class DictLike(object):
    """
    Mixin deriving the read-only ``dict`` convenience methods.

    Subclasses provide ``__getitem__`` (raising KeyError for absent keys) plus
    ``keys()`` and ``values()`` returned in matching order; every method here is
    expressed in terms of those three.
    """

    def __iter__(self):
        # iterating a dict-like object walks its keys
        return self.iterkeys()

    def get(self, item, default=None):
        """Return ``self[item]``, or ``default`` when the lookup raises KeyError."""
        try:
            found = self[item]
        except KeyError:
            found = default
        return found

    def iterkeys(self):
        """Iterator over ``keys()``."""
        key_list = self.keys()
        return iter(key_list)

    def itervalues(self):
        """Iterator over ``values()``."""
        value_list = self.values()
        return iter(value_list)

    def items(self):
        """Pair ``keys()`` with ``values()`` positionally."""
        return zip(self.keys(), self.values())

    def iteritems(self):
        """Iterator over ``items()``."""
        pairs = self.items()
        return iter(pairs)


class StaticMapping(DictLike):
    """
    Immutable mapping from a fixed set of string keys to integer ids.

    ``self.array[i]`` holds the key whose hash is ``i`` (``__getitem__`` relies on
    this: it hashes the key, then compares ``self.array[hash]`` with it to reject
    strings outside the key set).
    Keys should be np.ndarray type with dtype=str.
    """

    def __init__(self, array, hash_function):
        # array: keys ordered by hash value (see ``create``)
        # hash_function: callable on a single key, with ``.map`` for a batch of keys
        self.array = array
        self.hash_function = hash_function

    @property
    def keys_dtype(self):
        """NumPy dtype of the stored keys."""
        return self.array.dtype

    @property
    def values_dtype(self):
        """Nominal dtype of the ids this mapping produces (``np.int32``)."""
        return np.int32

    @classmethod
    def create(cls, array, return_index=False):
        """
        Build a mapping over the unique values of ``array``.

        :param array: 1-D np.ndarray of keys; duplicates are dropped via ``np.unique``.
        :param return_index: when true, also return an index array into ``array``.
        :returns: the mapping, or ``(mapping, index)`` when ``return_index`` is true.
            ``index`` holds positions in the input ``array`` (first occurrence of each
            unique key) ordered so that ``array[index]`` equals ``mapping.array``; hence
            ``mapping.map(array[index], default)`` yields ``0, 1, ..., len(np.unique(array)) - 1``
            (assuming the hash is minimal perfect, as the module docstring states).
        :raises AssertionError: if ``array`` is not a 1-D np.ndarray.
        """
        assert isinstance(array, np.ndarray) and len(array.shape) == 1, \
            "Only 1D np.array is supported in StaticMapping"
        unique_keys, first_index = np.unique(array, return_index=True)

        # Generating the hash. This must be a real list, not a lazy ``map`` object: it is
        # consumed twice (generate_hash, then hash_function.map), which only worked
        # because Python 2's ``map`` happens to return a list.
        str_keys = [str(key) for key in unique_keys]
        hash_function = generate_hash(str_keys)
        mapped = hash_function.map(str_keys)
        # reorder keys by hash value so that position i holds the key hashed to i
        permutation = np.argsort(mapped)
        static_mapping = cls(unique_keys[permutation], hash_function)
        if return_index:
            return static_mapping, first_index[permutation]
        return static_mapping

    def _key_is_valid(self, key):
        """
        True if ``key`` is an id in ``0 .. len(self.array) - 1``.

        Combined with ``&`` rather than ``and`` so it also evaluates elementwise on arrays.
        """
        return (0 <= key) & (key < len(self.array))

    def _is_collision(self, key, value):
        """True if bucket ``key`` holds a key other than ``value``."""
        return self.array[key] != value

    def _get_key(self, value):
        """Raw hash of a single key (not reduced modulo the table size)."""
        return self.hash_function(value)

    @property
    def reverse(self):
        """ReverseMapping view of this mapping (id -> key)."""
        return ReverseMapping(self)

    def map(self, values, default):
        """
        Vectorized lookup of the id of every element in ``values``.

        :param values: sequence of str keys to look up.
        :param default: id written for elements that are not keys of this mapping.
        :returns: the array from ``hash_function.map`` reduced modulo ``len(self)``,
            with ``default`` in place of every non-key.
        """
        keys = self.hash_function.map(values) % len(self.array)
        # strings outside the key set still hash to some bucket; compare to reject them
        collisions = values != self.array[keys]
        keys[collisions] = default
        return keys

    def __contains__(self, item):
        key = self._get_key(item)
        return self._key_is_valid(key) and not self._is_collision(key, item)

    def __getitem__(self, item):
        """Return the id of ``item``; raise KeyError if it is not one of the keys."""
        key = self._get_key(item)
        if not self._key_is_valid(key) or self._is_collision(key, item):
            raise KeyError(item)
        return key

    def __len__(self):
        return len(self.array)

    def keys(self):
        """Keys in id order."""
        return list(self.array)

    def values(self):
        """Ids ``0 .. len(self) - 1``, matching ``keys()`` positionally."""
        return range(len(self.array))


# noinspection PyProtectedMember
class ReverseMapping(DictLike):
    """
    Read-only view of a StaticMapping in the opposite direction (id -> key).

    Supports ``mapping.reverse[key]`` and ``mapping.reverse.map(array)`` syntax.
    """

    def __init__(self, mapping):
        self.mapping = mapping

    def map(self, keys, default):
        """
        Vectorized lookup of the key stored under each integer id in ``keys``.

        :param keys: sequence (or np.ndarray) of integer ids.
        :param default: fill value for ids outside ``0 .. len(mapping) - 1``.
        :returns: 1-D np.ndarray with the wrapped mapping's key dtype.
        """
        keys = np.asarray(keys)
        result = np.full(len(keys), default, dtype=self.mapping.array.dtype)
        # StaticMapping._key_is_valid combines its checks with ``&``, so it evaluates
        # elementwise on arrays; no per-element Python call or np.bool alias
        # (removed in NumPy 1.24) is needed.
        valid_keys = self.mapping._key_is_valid(keys)
        # skip the fancy-indexing step when nothing is valid (e.g. empty float input)
        if valid_keys.any():
            result[valid_keys] = self.mapping.array[keys[valid_keys]]
        return result

    @property
    def keys_dtype(self):
        """Nominal dtype of the ids (``np.int32``)."""
        return np.int32

    @property
    def values_dtype(self):
        """dtype of the wrapped mapping's keys."""
        return self.mapping.array.dtype

    @property
    def reverse(self):
        """The wrapped mapping (key -> id)."""
        return self.mapping

    def __getitem__(self, item):
        """Return the key stored under id ``item``; raise KeyError if out of range."""
        if not self.mapping._key_is_valid(item):
            raise KeyError(item)
        return self.mapping.array[item]

    def __contains__(self, item):
        # the range check lives on the wrapped mapping (this class defines none)
        return self.mapping._key_is_valid(item)

    def __len__(self):
        return len(self.mapping)

    def keys(self):
        """Ids, i.e. the wrapped mapping's values."""
        return self.mapping.values()

    def values(self):
        """Keys of the wrapped mapping, in id order."""
        return self.mapping.keys()


class StaticDict(DictLike):
    """
    Read-only dict built once from parallel ``keys`` / ``values`` sequences.

    Keys are located through a StaticMapping over their ``str()`` form; values are
    stored in a NumPy array in bucket order. Values NumPy cannot hold natively
    (mixed types, arbitrary objects, nested sequences) are serialized element-wise
    to JSON, or to pickle when JSON fails, and deserialized on access.
    """
    valid_serializations = (None, 'json', 'pickle')

    @property
    def keys_dtype(self):
        """NumPy dtype of the stored (original, non-stringified) keys."""
        return self.keys_array.dtype

    @property
    def values_dtype(self):
        """``object`` when values are serialized, otherwise the storage dtype."""
        # builtin ``object`` is what ``np.object`` aliased; the alias was removed in NumPy 1.24
        if self.serialization:
            return object
        return self.values_array.dtype

    @staticmethod
    def check_type_consistency(array):
        """
        Return True if every element is an instance of the first element's type.

        Empty input is consistent. Uses ``isinstance``, so subclasses pass too
        (``[1, True]`` is consistent, ``[True, 1]`` is not).
        """
        if len(array) == 0:
            return True
        element_type = type(array[0])
        return all(isinstance(element, element_type) for element in array)

    @staticmethod
    def _object_array(items):
        """
        Build a 1-D ``object`` array holding each element of ``items`` unchanged.

        ``np.array(items, dtype=object)`` would instead turn equal-length nested
        sequences into a 2-D array, splitting each value into cells.
        """
        result = np.empty(len(items), dtype=object)
        for position, item in enumerate(items):
            result[position] = item
        return result

    @staticmethod
    def serialize_array(array):
        """
        Serialize every element of ``array``: JSON if all succeed, otherwise pickle for all.

        :returns: ``(serialization_name, np.ndarray of serialized strings)``.

        NOTE: the JSON round trip is lossy for some values (tuples come back as lists,
        non-string dict keys as strings); only an outright JSON failure triggers pickle.
        """
        try:
            return 'json', np.array([json.dumps(element) for element in array])
        except (ValueError, TypeError):
            logger.warning("Couldn't serialize json, gonna use pickle.")
            return 'pickle', np.array([pickle.dumps(element) for element in array])

    def deserialize(self, value):
        """Decode a single stored value according to ``self.serialization``."""
        if self.serialization is None:
            return value
        if self.serialization == 'json':
            return ujson.loads(value)
        elif self.serialization == 'pickle':
            # NOTE(review): pickle.loads can execute arbitrary code; only load
            # StaticDicts whose data comes from a trusted source.
            return pickle.loads(value)

    def deserialize_array(self, array):
        """Decode every stored value; returns a 1-D object array when serialized."""
        if self.serialization is None:
            return array
        return self._object_array([self.deserialize(value) for value in array])

    @classmethod
    def preprocess_values(cls, values):
        """
        Convert ``values`` to a 1-D NumPy array suitable for storage.

        :returns: ``(serialization, array)``; ``serialization`` is None when NumPy
            holds the values natively, else the name returned by ``serialize_array``.
        """
        if cls.check_type_consistency(values):
            try:
                # try to cast objects to a native numpy type
                native = np.array(values)
            except ValueError:
                # ragged nested sequences: newer NumPy refuses to build them implicitly
                native = None
            # numpy may try to make a 2d array from list of lists
            if native is not None and native.dtype != object and native.ndim == 1:
                return None, native
        logger.debug("Couldn't adapt dictionary values to native NumPy dtype: will try to serialize")
        # serialize the original elements, not rows of a NumPy-coerced 2-D array
        return cls.serialize_array(cls._object_array(values))

    @classmethod
    def create(cls, keys, values):
        """
        Build a StaticDict from parallel ``keys`` and ``values`` sequences.

        :raises AssertionError: if lengths differ or keys are not all one type.
        :raises TypeError: if keys cannot be stored in a native NumPy dtype.

        NOTE: StaticMapping.create drops duplicate keys via ``np.unique(return_index=True)``,
        which keeps the first occurrence, so for a repeated key the FIRST value wins
        (a Python dict would keep the last).
        """
        assert len(keys) == len(values), 'Keys and values must have the same length'
        assert cls.check_type_consistency(keys), 'Keys must be the same type'
        # hash exactly the ``str()`` form that __getitem__/__contains__ look up with;
        # NumPy's own string conversion may format some values differently
        str_keys = np.array([str(key) for key in keys], dtype=str)
        keys = np.array(keys)
        if keys.dtype == object:
            raise TypeError(
                "Couldn't adapt dictionary keys to native NumPy dtype: "
                "only number/string keys are supported"
            )
        serialization, values = cls.preprocess_values(values)
        key_map, permutation = StaticMapping.create(str_keys, return_index=True)
        values = values[permutation]
        keys = keys[permutation]
        return cls(key_map, keys, values, serialization)

    def __init__(self, key_map, keys, values, serialization):
        assert serialization in self.valid_serializations, \
            'Bad serialization %s, valid choices are: %s' % (serialization, self.valid_serializations)
        self.key_map = key_map
        self.keys_array = keys
        self.values_array = values
        self.serialization = serialization

    def __getitem__(self, item):
        """Return the (deserialized) value for ``item``; KeyError if absent."""
        try:
            bucket = self.key_map[str(item)]
        except KeyError:
            # report the caller's key rather than its str() form, like dict does
            raise KeyError(item)
        return self.deserialize(self.values_array[bucket])

    def __contains__(self, item):
        return str(item) in self.key_map

    def keys(self):
        """Original keys as a NumPy array, in bucket order."""
        return self.keys_array

    def values(self):
        """Values in bucket order, deserialized if needed."""
        return self.deserialize_array(self.values_array)

    def __len__(self):
        return len(self.values_array)