from __future__ import absolute_import, unicode_literals

import io
import zlib
import bisect
import ctypes
import struct
import ctypes.util
import collections

from ... import enum
from ... import patterns
from ... import itertools

from . import base

import six

# Pointer to unsigned char: the ctypes spelling of zlib's Bytef* buffers.
LP_c_ubyte = ctypes.POINTER(ctypes.c_ubyte)

# Serialized index layout (little-endian, no padding): a uint64 point count,
# then per point (offset, comp_offset, bits, window length) packed as Q, Q, B, H,
# each record followed by the window bytes themselves.
ST_SIZE, ST_POINT = struct.Struct("<Q"), struct.Struct("<QQBH")


class ZException(Exception):
    """Error raised for a failing zlib return code.

    Codes known to ``Z.Codes`` are rendered through ``Z.Codes.val2str``
    (presumably their symbolic name); anything else is shown as given.
    """

    def __init__(self, code):
        label = Z.Codes.val2str(code) if code in Z.Codes else code
        super(ZException, self).__init__("zlib error: {}".format(label))


class Z(object):
    """Minimal ctypes binding to the system zlib's inflate API."""

    # Flush modes accepted by inflate() (Z_NO_FLUSH, Z_BLOCK in zlib.h).
    NO_FLUSH = 0
    BLOCK = 5

    class Codes(enum.Enum):
        # Return codes of the zlib functions (Z_OK, Z_STREAM_END, ... in zlib.h).
        OK = 0
        STREAM_END = 1
        NEED_DICT = 2

        ERRNO = -1
        STREAM_ERROR = -2
        DATA_ERROR = -3
        MEM_ERROR = -4
        BUF_ERROR = -5
        VERSION_ERROR = -6

    class Stream(ctypes.Structure):
        # Field-for-field mirror of zlib's z_stream; order and C types must match zlib.h.
        _fields_ = [
            ("next_in", LP_c_ubyte),
            ("avail_in", ctypes.c_uint),
            ("total_in", ctypes.c_ulong),
            ("next_out", LP_c_ubyte),
            ("avail_out", ctypes.c_uint),
            ("total_out", ctypes.c_ulong),
            ("msg", ctypes.c_char_p),
            ("state", ctypes.c_void_p),
            ("zalloc", ctypes.c_void_p),
            ("zfree", ctypes.c_void_p),
            ("opaque", ctypes.c_void_p),
            ("data_type", ctypes.c_int),
            ("adler", ctypes.c_ulong),
            ("reserved", ctypes.c_ulong),
        ]

        def __init__(self):
            super(Z.Stream, self).__init__()
            # Null zalloc/zfree/opaque ask zlib to use its default allocator;
            # no input or output buffer is attached yet.
            for field in ("zalloc", "zfree", "opaque", "avail_in", "avail_out"):
                setattr(self, field, 0)
            self.next_in = ctypes.cast(0, LP_c_ubyte)

    # sizeof(z_stream), passed to inflateInit2_ for zlib's ABI check.
    _STREAM_SIZE = ctypes.sizeof(Stream)

    @patterns.singleton_classproperty
    def _libz(self):
        # Declared argtypes let ctypes pass Stream instances by reference automatically.
        lib = ctypes.CDLL(ctypes.util.find_library("z"))
        lib.zlibVersion.restype = ctypes.c_char_p
        stream_p = ctypes.POINTER(self.Stream)
        signatures = (
            ("inflateInit2_", (stream_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int)),
            ("inflate", (stream_p, ctypes.c_int)),
            ("inflatePrime", (stream_p, ctypes.c_int, ctypes.c_int)),
            ("inflateSetDictionary", (stream_p, LP_c_ubyte, ctypes.c_uint)),
            ("inflateEnd", (stream_p,)),
        )
        for func_name, argtypes in signatures:
            getattr(lib, func_name).argtypes = argtypes
        return lib

    @patterns.singleton_classproperty
    def version(self):
        # Version string of the loaded library, as returned by zlibVersion().
        return self._libz.zlibVersion()

    @classmethod
    def inflate_init2(cls, strm, window_bits):
        """Initialise ``strm`` for inflation; return the zlib status code.

        inflateInit2() is a macro in zlib.h, so the underlying inflateInit2_
        is called with the version string and structure size it expects.
        """
        return cls._libz.inflateInit2_(strm, window_bits, cls.version, cls._STREAM_SIZE)

    @classmethod
    def inflate(cls, strm, flush):
        """Run inflate() on ``strm`` with the given flush mode; return its status code."""
        return cls._libz.inflate(strm, flush)

    @classmethod
    def inflate_end(cls, strm):
        """Release the zlib state held by ``strm``; return the status code."""
        return cls._libz.inflateEnd(strm)

    @classmethod
    def inflate_prime(cls, strm, bits, value):
        """Insert ``bits`` bits of ``value`` into the inflate input; return the status code."""
        return cls._libz.inflatePrime(strm, bits, value)

    @classmethod
    def inflate_set_dictionary(cls, strm, dictionary, dict_length):
        """Preset the inflate window from ``dictionary``; return the status code."""
        return cls._libz.inflateSetDictionary(strm, dictionary, dict_length)


class Index(base.Index):
    """Random-access index for a zlib/gzip compressed stream.

    Adapted from https://github.com/madler/zlib/blob/master/examples/zran.c

    ``build`` pushes compressed data through zlib and records access points at
    deflate block boundaries: one at uncompressed offset 0 and then whenever
    more than ``span`` uncompressed bytes have passed since the previous one.
    ``extract`` resumes decompression from such a point instead of from the
    start of the stream.
    """

    compression_type = base.CompressionType.TGZ

    WIN_SIZE = 32768  # 32 KiB, the deflate history window
    CHUNK_SIZE = WIN_SIZE >> 1  # WIN_SIZE // 2, 16 KiB of compressed input per step
    DEFAULT_SPAN = 1 << 20  # 1 MiB

    # offset: uncompressed offset of the point; comp_offset: compressed offset to
    # start reading from (one byte early when bits != 0); bits: number of high
    # bits of the byte at comp_offset still to be decoded; window: the WIN_SIZE
    # bytes of output that precede the point.
    AccessPoint = collections.namedtuple("AccessPoint", "offset comp_offset bits window")

    def __init__(self, span=DEFAULT_SPAN):
        # Treated as flushed until inflate is initialised, so __del__ never calls
        # inflateEnd on a stream that was not set up.
        self.__flushed = True
        self.__finished = False  # set once build() reaches the end of the stream
        self.__span = span
        self.__points = []
        self.__totin = 0
        self.__totout = 0
        self.__last = 0
        self.__strm = Z.Stream()
        self.__window = ctypes.create_string_buffer(b"", self.WIN_SIZE)

        ret = Z.inflate_init2(self.__strm, 47)  # 32 + 15: automatic zlib or gzip decoding
        if ret != Z.Codes.OK:
            raise ZException(ret)
        self.__flushed = False

        self.__strm.avail_out = 0

    def __del__(self):
        self.flush()

    def __eq__(self, other):
        if not isinstance(other, Index):
            return NotImplemented
        return self.__points == other.__points

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __len__(self):
        return len(self.__points)

    def dump(self, offsets=None):
        """Serialize the index to bytes (the inverse of ``load``).

        :param offsets: optional iterable of uncompressed offsets; when given,
            only the access points needed to serve those offsets are kept.
        """
        points = self.__points
        if offsets is not None:
            used_offsets = set()
            points = []
            for offset in sorted(offsets):
                point = self.get_point(offset)
                if point.offset not in used_offsets:
                    points.append(point)
                    used_offsets.add(point.offset)
        buf = super(Index, self).dump()
        # Raw deflate (negative wbits) at the fastest level; load() inflates with -15.
        comp = zlib.compressobj(1, zlib.DEFLATED, -15)
        buf.write(comp.compress(ST_SIZE.pack(len(points))))
        for point in points:
            buf.write(comp.compress(ST_POINT.pack(point.offset, point.comp_offset, point.bits, len(point.window))))
            buf.write(comp.compress(point.window))
        buf.write(comp.flush())
        return buf.getvalue()

    @classmethod
    def load(cls, data):
        """Rebuild an index serialized by ``dump``.

        The result holds access points only; calling ``build`` on it raises.
        """
        # data[0] is presumably the header byte written by base.Index.dump.
        buf = io.BytesIO(zlib.decompress(data[1:], -15))
        index = cls()
        # Release the inflate state allocated by cls(). Only setting the flushed
        # flag would leak it, because flush() would then never call inflateEnd.
        index.flush()
        total_points = ST_SIZE.unpack(buf.read(ST_SIZE.size))[0]
        for _ in six.moves.xrange(total_points):
            offset, comp_offset, bits, window_size = ST_POINT.unpack(buf.read(ST_POINT.size))
            window = buf.read(window_size)
            index.__points.append(cls.AccessPoint(offset, comp_offset, bits, window))
        return index

    def _rotate_window(self, window=None, left=None, win_size=WIN_SIZE):
        """Return the window contents in chronological order.

        Output wraps around the window. The last ``left`` bytes (default:
        ``strm.avail_out``) are left over from the previous fill and are older
        than the ``win_size - left`` bytes at the start.
        """
        if left is None:
            left = self.__strm.avail_out
        if window is None:
            window = self.__window
        raw = window.raw  # .raw copies the whole buffer, so take it once
        result = b""
        if left:
            result = raw[win_size - left:win_size]
        if left < win_size:
            result += raw[:win_size - left]
        return result

    def _add_point(self):
        """Record an access point at the current stream position."""
        window = self._rotate_window()
        # Low 3 bits of data_type: unused bits in the last input byte consumed.
        bits = self.__strm.data_type & 7
        self.__points.append(self.AccessPoint(self.__totout, self.__totin - int(bool(bits)), bits, window))

    def build(self, data):
        """Decompress ``data`` while recording access points; return output bytes.

        May be called repeatedly with consecutive pieces of the compressed
        stream. Output is released one full window at a time, so bytes that do
        not yet fill the window come back from a later call or at the end of
        the stream. Only the first zlib/gzip stream is indexed. Once its end
        is reached, remaining input in this call and in later calls is ignored,
        and later calls return b"".

        :raises ZException: if the index was flushed, or on corrupt input.
        """
        if self.__flushed:
            raise ZException(Z.Codes.STREAM_END)
        if self.__finished:
            return b""
        strm = self.__strm
        window_ptr = ctypes.cast(self.__window, LP_c_ubyte)
        uncompressed = []
        for chunk in itertools.chunker(data, self.CHUNK_SIZE):
            strm.avail_in = len(chunk)
            strm.next_in = ctypes.cast(chunk, LP_c_ubyte)
            while strm.avail_in:
                if strm.avail_out == 0:
                    # Window full (or first call): restart output at its beginning.
                    strm.avail_out = self.WIN_SIZE
                    strm.next_out = window_ptr
                # Totals advance by the input consumed / output produced in this call.
                self.__totin += strm.avail_in
                self.__totout += strm.avail_out
                ret = Z.inflate(strm, Z.BLOCK)  # return at end of block
                self.__totin -= strm.avail_in
                self.__totout -= strm.avail_out
                if ret == Z.Codes.NEED_DICT:
                    ret = Z.Codes.DATA_ERROR
                if ret < 0:
                    raise ZException(ret)
                if ret == Z.Codes.STREAM_END:
                    uncompressed.append(self.__window.raw[:self.WIN_SIZE - strm.avail_out])
                    # Further inflate() calls would keep returning Z_STREAM_END
                    # without producing output; stop so the tail is not re-emitted.
                    self.__finished = True
                    return b"".join(uncompressed)
                if strm.avail_out == 0:
                    uncompressed.append(self._rotate_window())
                # data_type 128: stopped at a block boundary (or just after the
                # header); 64: currently inside the last block.
                if (
                    strm.data_type & 128 and not strm.data_type & 64 and
                    (self.__totout == 0 or self.__totout - self.__last > self.__span)
                ):
                    self._add_point()
                    self.__last = self.__totout
        return b"".join(uncompressed)

    def flush(self):
        """Release the zlib inflate state (idempotent); ``build`` raises afterwards."""
        if not self.__flushed:
            Z.inflate_end(self.__strm)
            self.__flushed = True

    def get_point(self, offset):
        """Return the last access point whose uncompressed offset is <= ``offset``.

        :raises IndexError: if there is no such point (empty index or negative offset).
        """
        # A 1-tuple sorts before every AccessPoint that shares its first field,
        # so bisect only ever compares offsets, never window bytes.
        i = bisect.bisect_left(self.__points, (offset + 1,)) - 1
        if i < 0:
            raise IndexError("no access point at or before offset {}".format(offset))
        return self.__points[i]

    def _inflate_fill(self, strm, fileobj, pending):
        """Inflate into ``strm.next_out`` until ``avail_out`` is 0 or the stream ends.

        Compressed input is read from ``fileobj`` in CHUNK_SIZE pieces as needed.
        ``pending[0]`` holds the current piece because zlib keeps only a raw
        pointer into it. Returns True if the end of the deflate stream was reached.

        :raises ZException: on corrupt data or if ``fileobj`` runs out first.
        """
        while strm.avail_out:
            if strm.avail_in == 0:
                chunk = fileobj.read(self.CHUNK_SIZE)
                if not chunk:
                    # Input exhausted before the end of the stream: truncated data.
                    raise ZException(Z.Codes.DATA_ERROR)
                pending[0] = chunk
                strm.avail_in = len(chunk)
                strm.next_in = ctypes.cast(chunk, LP_c_ubyte)
            ret = Z.inflate(strm, Z.NO_FLUSH)  # normal inflate
            if ret == Z.Codes.NEED_DICT:
                ret = Z.Codes.DATA_ERROR
            if ret < 0:
                raise ZException(ret)
            if ret == Z.Codes.STREAM_END:
                return True
        return False

    def extract(self, fileobj, point, offset, length):
        """Yield up to ``length`` uncompressed bytes starting at ``offset``.

        :param fileobj: compressed input positioned at ``point.comp_offset``;
            this generator reads from the current position and never seeks.
        :param point: access point at or before ``offset`` (e.g. from ``get_point``).
        :param offset: absolute uncompressed offset of the first wanted byte.
        :param length: number of uncompressed bytes wanted.

        Yields ``bytes`` pieces of at most WIN_SIZE bytes. Fewer than ``length``
        bytes are produced in total if the stream ends first, and none if
        ``offset`` lies at or beyond its end.

        :raises ValueError: if ``offset`` precedes ``point`` or ``length`` is negative.
        :raises ZException: on corrupt or truncated compressed data.
        """
        skip = offset - point.offset
        # Negative values would wrap around in the unsigned avail_out field and let
        # zlib write far past the 32 KiB buffers below.
        if skip < 0:
            raise ValueError("offset {} precedes access point at {}".format(offset, point.offset))
        if length < 0:
            raise ValueError("length must be non-negative, got {}".format(length))
        result = ctypes.create_string_buffer(b"", self.WIN_SIZE)
        discard = ctypes.create_string_buffer(b"", self.WIN_SIZE)
        result_ptr = ctypes.cast(result, LP_c_ubyte)
        discard_ptr = ctypes.cast(discard, LP_c_ubyte)
        pending = [None]  # keeps the input chunk behind strm.next_in alive
        strm = Z.Stream()
        ret = Z.inflate_init2(strm, -15)  # raw inflate
        try:
            if ret != Z.Codes.OK:
                raise ZException(ret)
            if point.bits:
                # The point starts mid-byte: feed zlib the remaining high bits of it.
                byte = fileobj.read(1)
                if not byte:
                    raise ZException(Z.Codes.DATA_ERROR)
                Z.inflate_prime(strm, point.bits, ord(byte) >> (8 - point.bits))
            window = point.window
            # Use the actual window length so a short (e.g. corrupt) window is never over-read.
            Z.inflate_set_dictionary(strm, ctypes.cast(window, LP_c_ubyte), len(window))

            strm.avail_in = 0
            # Decompress and drop the bytes between the access point and offset.
            while skip:
                wanted = min(skip, self.WIN_SIZE)
                strm.avail_out = wanted
                strm.next_out = discard_ptr
                ended = self._inflate_fill(strm, fileobj, pending)
                skip -= wanted - strm.avail_out
                if ended:
                    return  # offset lies beyond the end of the stream: nothing to yield

            # Decompress the requested range one window-sized piece at a time,
            # never inflating more than is still needed.
            while length:
                wanted = min(length, self.WIN_SIZE)
                strm.avail_out = wanted
                strm.next_out = result_ptr
                ended = self._inflate_fill(strm, fileobj, pending)
                produced = wanted - strm.avail_out
                if produced:
                    yield result.raw[:produced]
                length -= produced
                if ended:
                    return
        finally:
            Z.inflate_end(strm)
