import functools
from multiprocessing.shared_memory import SharedMemory
import types

import numpy


class SupportedType:
    """Base wrapper describing how values of one Python type are stored.

    Subclasses are expected to override ``encode`` and ``decode``; the
    versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, python_type):
        # The Python type this wrapper is responsible for.
        self.python_type = python_type

    def encode(self, value, context):
        """Encode ``value`` using ``context``; must be overridden."""
        raise NotImplementedError()

    def decode(self, metadata, context):
        """Decode ``metadata`` using ``context``; must be overridden."""
        raise NotImplementedError()


class FixedSizeType(SupportedType):
    """Wrapper for values kept in a per-type array with a fixed-width dtype.

    The metadata produced for a value is its slot index inside the
    collection that belongs to ``python_type``.
    """

    def __init__(self, python_type, dtype):
        super().__init__(python_type)
        # NumPy dtype associated with this Python type.
        self.dtype = dtype

    def encode(self, value, context):
        """Append ``value`` to the context's list for this type.

        Returns the index at which the value was stored.
        """
        bucket = context.py_data[self.python_type]
        bucket.append(value)
        return len(bucket) - 1

    def decode(self, metadata, context):
        """Return the element stored at slot ``metadata`` for this type.

        The result is whatever indexing ``context.numpy_data[python_type]``
        yields; no conversion is applied here.
        """
        typed_values = context.numpy_data[self.python_type]
        return typed_values[metadata]


class InplaceType(SupportedType):
    """Wrapper whose metadata is computed directly from the value.

    Both methods ignore the context argument and simply delegate to the
    functions supplied at construction time.
    """

    def __init__(self, python_type, encoding_func, decoding_func):
        super().__init__(python_type)
        self.encoding_func, self.decoding_func = encoding_func, decoding_func

    def encode(self, value, _):
        """Return ``encoding_func(value)``; the second argument is unused."""
        to_metadata = self.encoding_func
        return to_metadata(value)

    def decode(self, metadata, _):
        """Return ``decoding_func(metadata)``; the second argument is unused."""
        from_metadata = self.decoding_func
        return from_metadata(metadata)


class VariableSizeType(SupportedType):
    """Wrapper for values stored as variable-length entries in a byte buffer.

    The metadata produced for a value is its position in the context's
    list of buffer items.
    """

    def __init__(self, python_type, encoding_func, decoding_func):
        super().__init__(python_type)
        self.encoding_func, self.decoding_func = encoding_func, decoding_func

    def encode(self, value, context):
        """Append the encoded ``value`` to ``context.buffer_items``.

        Returns the index of the newly appended item.
        """
        pending = context.buffer_items
        pending.append(self.encoding_func(value))
        return len(pending) - 1

    def decode(self, metadata, context):
        """Decode the buffer entry whose item index is ``metadata``.

        ``context.buffer_index[i]`` is used as the start offset of item ``i``
        and the following entry as its end; the bytes in between are copied
        out of ``context.buffer`` and passed to ``decoding_func``.
        """
        slot = int(metadata)
        offsets = context.buffer_index
        start = offsets[slot]
        length = offsets[slot + 1] - start
        stop = start + length
        raw = context.buffer[start:stop].tobytes()
        return self.decoding_func(raw)


class TypesInfo:
    """Registry of the supported Python types and their wrappers.

    The position of a wrapper inside ``wrapper_types`` is its numeric type
    code; the two dictionaries map a Python type to that code and to the
    wrapper object respectively.
    """

    wrapper_types = (
        FixedSizeType(int, dtype=numpy.uint64),
        FixedSizeType(float, dtype=numpy.float64),
        # bytes are stored as-is.
        VariableSizeType(bytes, encoding_func=lambda raw: raw, decoding_func=lambda raw: raw),
        # str values are stored using the default str.encode()/bytes.decode() codec.
        VariableSizeType(str, encoding_func=str.encode, decoding_func=bytes.decode),
        InplaceType(bool, encoding_func=int, decoding_func=bool),
        InplaceType(type(None), encoding_func=lambda _: 0, decoding_func=lambda _: None),
    )
    python_type_to_wrapper_type_code = {
        wrapper.python_type: numpy.uint64(code) for code, wrapper in enumerate(wrapper_types)
    }
    python_type_to_wrapper_type = {wrapper.python_type: wrapper for wrapper in wrapper_types}

    @classmethod
    def get_wrapper_type(cls, python_type):
        """Return the wrapper registered for ``python_type`` (KeyError if none)."""
        registry = cls.python_type_to_wrapper_type
        return registry[python_type]

    @classmethod
    def get_wrapper_type_code(cls, python_type):
        """Return the uint64 type code of ``python_type`` (KeyError if none)."""
        codes = cls.python_type_to_wrapper_type_code
        return codes[python_type]

    @classmethod
    def get_number_of_types(cls):
        """Return how many wrappers are registered."""
        return len(cls.wrapper_types)

    @classmethod
    def get_fixed_size_types(cls):
        """Return a generator over the ``FixedSizeType`` wrappers, in registry order."""
        return (candidate for candidate in cls.wrapper_types if isinstance(candidate, FixedSizeType))


class MetaEncoder:
    """Packs a wrapper type code and a per-item value into one integer.

    The type code occupies the bits at and above ``wrapper_type_code_shift``;
    the value occupies the bits below it (selected by ``value_mask``).
    """

    wrapper_type_code_shift = numpy.uint64(48)
    value_mask = (numpy.uint64(1) << wrapper_type_code_shift) - numpy.uint64(1)

    @classmethod
    def encode(cls, value, python_type):
        """Combine ``value`` with the type code registered for ``python_type``."""
        type_code = TypesInfo.get_wrapper_type_code(python_type)
        shifted_code = int(type_code << cls.wrapper_type_code_shift)
        return value | shifted_code

    @classmethod
    def decode(cls, meta):
        """Split ``meta`` into ``(value, wrapper)``.

        ``value`` is the masked low part of ``meta``; ``wrapper`` is the entry
        of ``TypesInfo.wrapper_types`` selected by the high part.
        """
        payload = meta & cls.value_mask
        type_code = meta >> cls.wrapper_type_code_shift
        return payload, TypesInfo.wrapper_types[type_code]


# Number of uint64 slots in the shared-memory header: two leading slots
# (list length and buffer-index length) plus one slot per registered type.
HEADER_SIZE = 2 + TypesInfo.get_number_of_types()


class EncodingContext:
    """Collects the encoded form of a sequence before it is written out.

    Attributes:
        py_data: maps each registered Python type to a list that the
            type's wrapper may append values to while encoding.
        buffer_items: encoded variable-size payloads, in encoding order.
        meta: uint64 array with one packed (type code, value) entry per
            item of the input sequence.
    """

    def __init__(self, sequence):
        """Encode every item of ``sequence`` (which must support ``len()``).

        Raises:
            KeyError: if an item's exact type has no registered wrapper.
        """
        self.py_data = {wrapper_type.python_type: [] for wrapper_type in TypesInfo.wrapper_types}
        self.buffer_items = []
        # Every slot is assigned by the loop below, so no zero-fill is needed.
        self.meta = numpy.empty((len(sequence),), dtype=numpy.uint64)

        for i, item in enumerate(sequence):
            self.encode_item(i, item)

    def encode_item(self, index, item):
        """Encode ``item`` and store its packed metadata at ``meta[index]``."""
        # Lookup is by exact type, so e.g. bool is not treated as int.
        python_type = type(item)
        wrapper_type = TypesInfo.get_wrapper_type(python_type)
        metadata = wrapper_type.encode(item, self)
        self.meta[index] = MetaEncoder.encode(metadata, python_type)

    def get_fixed_size_data(self):
        """Return ``{python_type: ndarray}`` for every fixed-size type.

        Each array is built from the values collected for that type, using
        the dtype declared by its wrapper.
        """
        return {
            wrapper_type.python_type: numpy.array(self.py_data[wrapper_type.python_type], dtype=wrapper_type.dtype)
            for wrapper_type in TypesInfo.get_fixed_size_types()
        }

    def get_buffer_index(self):
        """Return the cumulative offsets of the buffer items as a uint64 array.

        The array has ``len(buffer_items) + 1`` entries: entry ``i`` is the
        start offset of item ``i`` and the last entry is the total size.
        """
        lengths = numpy.array([len(item) for item in self.buffer_items], dtype=numpy.uint64)
        buffer_index = numpy.zeros((lengths.size + 1,), dtype=numpy.uint64)
        # Accumulate in uint64 so no intermediate float conversion can occur.
        buffer_index[1:] = numpy.cumsum(lengths, dtype=numpy.uint64)
        return buffer_index

    def get_buffer_items(self):
        """Return the list of encoded variable-size payloads."""
        return self.buffer_items

    def get_header(self, buffer_index, fixed_size_data):
        """Build the uint64 header array of ``HEADER_SIZE`` slots.

        Slot 0 holds the number of items, slot 1 the length of
        ``buffer_index``, and the following slots the element count of each
        fixed-size array in registry order. Remaining slots are zero.
        """
        # zeros (not an uninitialised array): slots that are never assigned
        # below must not carry arbitrary memory contents into the header.
        header = numpy.zeros((HEADER_SIZE,), dtype=numpy.uint64)
        header[0] = len(self.meta)
        header[1] = len(buffer_index)

        for i, wrapper_type in enumerate(TypesInfo.get_fixed_size_types()):
            header[i + 2] = fixed_size_data[wrapper_type.python_type].size

        return header


class DecodingContext:
    """Holds the arrays and buffer that wrappers read from when decoding.

    Attributes:
        meta: indexable collection of packed per-item metadata.
        numpy_data: mapping from Python type to that type's value array.
        buffer_index: offsets delimiting the entries inside ``buffer``.
        buffer: storage holding the variable-size payloads.
    """

    def __init__(self, meta, numpy_data, buffer_index, buffer):
        self.meta, self.numpy_data = meta, numpy_data
        self.buffer_index, self.buffer = buffer_index, buffer

    def decode(self, position):
        """Return the decoded item stored at ``position``."""
        packed = self.meta[position]
        payload, wrapper = MetaEncoder.decode(packed)
        return wrapper.decode(payload, self)


class ImmutableShareableList:
    """
    Read-only list whose items live in a ``SharedMemory`` block.

    Passing ``sequence`` (or omitting both arguments) creates a new block and
    encodes the items into it; passing only ``name`` attaches to an existing
    block. Pickling an instance transfers only the block name.

    Shared memory layout: HEADER|META|FIXED_SIZE_DATA|BUFFER_INDEX|BUFFER
    HEADER: array<uint64>[2 + number of data types]
    META: array<uint64>[list size]
    FIXED_SIZE_DATA: array<fixed_size_type>[number of elements with this type] * number of fixed size types
    BUFFER_INDEX: array<uint64>[number of elements with variable size types (all of them) + 1];
        entry i is the start offset of element i inside BUFFER, the last entry is the total BUFFER length
    BUFFER: sequence of encoded elements with variable size types
    """
    def __init__(self, sequence=None, name=None):
        """Create a new shared list from ``sequence`` or attach to ``name``.

        Args:
            sequence: items to store; ``None`` together with ``name=None``
                creates an empty list.
            name: shared memory block name. Used as the name of the new
                block when creating, or as the block to attach to when
                ``sequence`` is ``None``.
        """
        if name is None or sequence is not None:
            encoding_context = EncodingContext(sequence or ())
            self.write_to_shared_memory(name, encoding_context)
        else:
            self.read_from_shared_memory(name)

    def write_to_shared_memory(self, name, encoding_context):
        """Allocate a shared memory block and copy the encoded data into it."""
        fixed_size_data = encoding_context.get_fixed_size_data()
        buffer_index = encoding_context.get_buffer_index()
        buffer_items = encoding_context.get_buffer_items()
        header = encoding_context.get_header(buffer_index, fixed_size_data)

        # The last buffer_index entry is the total number of payload bytes.
        requested_size = (
            header.nbytes
            + encoding_context.meta.nbytes
            + sum(arr.nbytes for arr in fixed_size_data.values())
            + buffer_index.nbytes
            + int(buffer_index[-1])
        )

        self.shm = SharedMemory(name, create=True, size=requested_size)

        def add_to_shared(array, offset):
            # Create a view of the block at `offset`, copy `array` into it and
            # return the view together with the offset just past it.
            shared_array = numpy.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf, offset=offset)
            shared_array[:] = array[:]
            return shared_array, offset + shared_array.nbytes

        _, offset = add_to_shared(header, 0)
        meta, offset = add_to_shared(encoding_context.meta, offset)

        shared_data = {}
        for wrapper_type in TypesInfo.get_fixed_size_types():
            shared_data[wrapper_type.python_type], offset = add_to_shared(fixed_size_data[wrapper_type.python_type], offset)

        shared_buffer_index, offset = add_to_shared(buffer_index, offset)
        # Everything after the index is the payload area; decoding offsets are
        # relative to the start of this view.
        shared_buffer = self.shm.buf[offset:]
        for item in buffer_items:
            self.shm.buf[offset:offset+len(item)] = item
            offset += len(item)

        self._decoding_context = DecodingContext(meta, shared_data, shared_buffer_index, shared_buffer)

    def read_from_shared_memory(self, name):
        """Attach to the existing block ``name`` and map its sections."""
        self.shm = SharedMemory(name)

        header = numpy.ndarray((HEADER_SIZE,), dtype=numpy.uint64, buffer=self.shm.buf)
        meta = numpy.ndarray((int(header[0]),), dtype=numpy.uint64, buffer=self.shm.buf, offset=header.nbytes)
        offset = header.nbytes + meta.nbytes
        shared_data = {}

        # Header slots 2.. hold the element count of each fixed-size array,
        # in the same order in which the arrays were written.
        for i, wrapper_type in enumerate(TypesInfo.get_fixed_size_types()):
            shared_data[wrapper_type.python_type] = numpy.ndarray((int(header[2 + i]),), dtype=wrapper_type.dtype, buffer=self.shm.buf, offset=offset)
            offset += shared_data[wrapper_type.python_type].nbytes

        shared_buffer_index = numpy.ndarray((int(header[1]),), dtype=numpy.uint64, buffer=self.shm.buf, offset=offset)
        offset += shared_buffer_index.nbytes
        shared_buffer = self.shm.buf[offset:]
        self._decoding_context = DecodingContext(meta, shared_data, shared_buffer_index, shared_buffer)

    # Backward-compatible alias for the original (misspelled) method name.
    read_from_schared_memory = read_from_shared_memory

    def __getitem__(self, position):
        """Return the decoded item at integer index ``position``."""
        return self._decoding_context.decode(position)

    def __reduce__(self):
        """Pickle as a call that re-attaches to the same block by name."""
        return functools.partial(self.__class__, name=self.shm.name), ()

    def __len__(self):
        """Return the number of items in the list."""
        return self._decoding_context.meta.size

    def __repr__(self):
        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'

    def close(self, owning=False):
        """Release this instance's access to the shared memory block.

        Args:
            owning: when true, also unlink (destroy) the block.

        Calling ``close()`` again after the views have already been dropped
        no longer raises ``AttributeError``.
        """
        # Drop the arrays and memoryview that reference shm.buf before the
        # block itself is closed; pop() tolerates a repeated close().
        self.__dict__.pop('_decoding_context', None)
        self.shm.close()
        if owning:
            self.shm.unlink()

    __class_getitem__ = classmethod(types.GenericAlias)
