import io
import os
import time
import uuid
import zlib
import struct
import logging
import tarfile
import calendar
import requests
import threading as th

import six
import flask
import queue
import pathlib

from sandbox.common import mds as common_mds
from sandbox.common.mds import stream as mds_stream
from sandbox.common import patterns as common_patterns
from sandbox.common.types import misc as ctm
from sandbox.common.types import resource as ctr


class FlaskResponseFromMDS(flask.Response):
    """Flask response class used for bodies streamed from MDS."""

    # Keep werkzeug from computing and setting the Content-Length header on its own.
    automatically_set_content_length = False


class IterableStream:
    """
    Iterable wrapper around a streamed ``requests`` response.

    Iterating yields the body in chunks of ``chunk_size`` bytes (``mds_stream.DEFAULT_CHUNK_SIZE``
    when not given). Subclasses change the data source by overriding ``_iter_impl``.
    When ``is_final_response`` is true, the number of bytes yielded is reported to
    ``flask.g.request_statistics`` once iteration ends, whether normally or not; if the
    statistics object is missing, a warning is logged instead.
    """

    def __init__(
        self,
        response: requests.Response,
        logger: logging.Logger,
        is_final_response: bool = True,
        chunk_size: int = None,
    ) -> None:
        self._response = response
        self._logger = logger
        self._is_final_response = is_final_response
        # Any falsy chunk size (None, 0) falls back to the MDS default.
        self._chunk_size = chunk_size if chunk_size else mds_stream.DEFAULT_CHUNK_SIZE

    def __iter__(self):
        sent = 0
        try:
            for piece in self._iter_impl():
                sent += len(piece)
                yield piece
        finally:
            # Runs on exhaustion, on errors and when the consumer closes the generator early.
            if self._is_final_response:
                stats = flask.g.get("request_statistics")
                if stats:
                    stats.provide_data_size(sent)
                else:
                    self._logger.warning("Impossible to handle request: request statistics aren't initialized")

    def _iter_impl(self):
        """Return an iterator over the raw body chunks; overridden by subclasses."""
        return self._response.iter_content(self._chunk_size)


class ReliableStreamFromMDS(IterableStream):
    """
    Stream the data behind a single S3-MDS link via ``mds_stream.iter_content``.

    Besides yielding the data, it reports the number of bytes read and the elapsed time to
    ``flask.g.request_statistics.on_request_to_s3`` when iteration ends, normally or not.
    If the statistics object is missing, a warning is logged instead.
    """

    def __init__(self, response, timeout, logger, is_final_response=True, chunk_size=None):
        super().__init__(response, logger, is_final_response, chunk_size=chunk_size)
        # Stored for callers/subclasses; not read by this class itself.
        self._timeout = timeout

    def _iter_impl(self):
        started_at = time.time()
        received = 0
        try:
            pieces = mds_stream.iter_content(self._response, logger=self._logger, chunk_size=self._chunk_size)
            for piece in pieces:
                received += len(piece)
                yield piece
        finally:
            stats = flask.g.get("request_statistics")
            if not stats:
                self._logger.warning("Impossible to handle request to S3: request statistics aren't initialized")
            else:
                stats.on_request_to_s3(self._response, received, time.time() - started_at)


class ReliableTarStream(IterableStream):
    """
    Stream an uncompressed tar archive built on the fly from resource metadata.

    ``metadata`` is a list of item dicts with ``path`` and ``type`` plus, depending on the type,
    ``size``, ``key``, ``offset``, ``symlink`` and ``executable``. Tar headers are generated
    locally; only the contents of regular files (``FileType.FILE``) are downloaded from S3-MDS.

    If the first item is of type TARDIR, the resource is stored as a single S3 object. Its key is
    taken from that item, and the first two items are skipped. Every file is then fetched from that
    object with an HTTP Range request built from the item's ``offset`` and ``size``.

    When ``relpath`` is given, only items at or below it are archived. Their paths are rewritten to
    start with the last component of ``relpath``.
    """

    def __init__(
        self, metadata, namespace, mds_settings, req_headers_func, timeout, relpath, logger, response, chunk_size=None
    ):
        self._metadata = metadata
        self._tar_key = None
        # Guard against empty metadata: it yields an empty stream instead of an IndexError.
        self._tar_dir = bool(self._metadata) and self._metadata[0]["type"] == ctr.FileType.TARDIR
        if self._tar_dir:
            self._tar_key = self._metadata[0]["key"]
            # NOTE(review): the second entry is skipped too — presumably the root directory; confirm
            self._metadata = self._metadata[2:]
        self._namespace = namespace
        self._mds_settings = mds_settings
        self._req_headers_func = req_headers_func
        self._timeout = timeout
        self._relpath = relpath and pathlib.Path(relpath)
        # mtime used for every member: the Last-Modified header of ``response`` as a UTC epoch, or 0 if absent.
        self._last_modified = response.headers.get(ctm.HTTPHeader.LAST_MODIFIED, 0)
        if self._last_modified:
            # noinspection PyTypeChecker
            self._last_modified = int(calendar.timegm(time.strptime(self._last_modified, "%a, %d %b %Y %H:%M:%S %Z")))
        super(ReliableTarStream, self).__init__(response, logger, True, chunk_size=chunk_size)

    def _add_tarinfo(self, buf, item):
        """
        Write the tar header for ``item`` into ``buf``.

        Returns False (and writes nothing) when ``item`` lies outside ``relpath``, True otherwise.
        """
        item_path = pathlib.Path(item["path"])
        if self._relpath:
            try:
                item_path = self._relpath.name / item_path.relative_to(self._relpath)
            except ValueError:
                return False
        ti = tarfile.TarInfo(name=str(item_path))
        ti.mtime = self._last_modified
        item_type = item["type"]
        if item_type == ctr.FileType.DIR:
            ti.type = tarfile.DIRTYPE
            ti.mode = 0o755
        elif item_type == ctr.FileType.SYMLINK:
            ti.type = tarfile.SYMTYPE
            ti.mode = 0o644
            ti.linkname = item["symlink"]
        elif item_type in (ctr.FileType.TOUCH, ctr.FileType.FILE):
            ti.type = tarfile.REGTYPE
            ti.mode = 0o755 if item.get("executable") else 0o644
            ti.size = 0 if item_type == ctr.FileType.TOUCH else item["size"]
        buf.write(ti.tobuf())
        return True

    def _iter_file_content(self, item):
        """Yield the body of regular-file ``item`` from S3-MDS, zero-padded to a tar block boundary."""
        size = item["size"]
        if not size:
            # Nothing to transfer. For a TARDIR this also avoids the invalid range "bytes=X-(X-1)",
            # which a server may ignore and answer with the whole object.
            return
        key = self._tar_key if self._tar_dir else item["key"]
        link = common_mds.s3_link(key, self._namespace, mds_settings=self._mds_settings)
        headers = self._req_headers_func(link)
        if self._tar_dir:
            offset = item["offset"]
            # Build a new dict instead of mutating the one returned by the callback.
            headers = dict(headers, Range="bytes={}-{}".format(offset, offset + size - 1))
        resp = requests.get(
            link,
            stream=True,
            headers=headers,
            timeout=self._timeout
        )
        try:
            stream = iter(ReliableStreamFromMDS(resp, self._timeout, self._logger, False, chunk_size=self._chunk_size))
            try:
                for chunk in stream:
                    yield chunk
            finally:
                # Let the inner stream record its S3 statistics before the response is closed.
                stream.close()
        finally:
            # Release the connection even if the consumer stops early or an error occurs.
            resp.close()
        remainder = size % tarfile.BLOCKSIZE
        if remainder:
            yield b"\0" * (tarfile.BLOCKSIZE - remainder)

    def _iter_impl(self):
        buf = io.BytesIO()  # headers not yet sent to the consumer
        dirs = set()  # directory paths whose headers were already written

        for item in self._metadata:
            path_parts = pathlib.Path(item["path"]).parts
            item_type = item["type"]
            is_dir = item_type == ctr.FileType.DIR
            # Emit headers for the parent directories (and for the item itself if it is a directory).
            dir_path = ""
            for part in path_parts if is_dir else path_parts[:-1]:
                dir_path += "/{}".format(part) if dir_path else part
                if dir_path in dirs:
                    continue
                dirs.add(dir_path)
                self._add_tarinfo(buf, dict(type=ctr.FileType.DIR, path=dir_path))
            # Only regular files that passed the relpath filter have a body to stream.
            if is_dir or not self._add_tarinfo(buf, item) or item_type != ctr.FileType.FILE:
                continue
            # Flush the pending headers (including this file's) before its contents.
            yield buf.getvalue()
            buf = io.BytesIO()
            for chunk in self._iter_file_content(item):
                yield chunk
        # NOTE(review): no end-of-archive marker (two zero-filled 512-byte blocks) is written;
        # consider appending one unless callers depend on the exact stream size.
        if buf.getvalue():
            yield buf.getvalue()


class ReliableTgzStream(ReliableTarStream):
    """The archive produced by ``ReliableTarStream``, gzip-compressed by the per-process ``Gzipper``."""

    def _iter_impl(self):
        pipe = Gzipper().pipe
        flushing = False
        try:
            for data in super(ReliableTgzStream, self)._iter_impl():
                pipe.push(data)
                # Yield whatever the compressor thread has produced so far (non-blocking).
                for chunk in pipe:
                    yield chunk
            # From here on, flush() pushes the end-of-stream marker itself.
            flushing = True
            for chunk in pipe.flush():
                yield chunk
        finally:
            if not flushing:
                # Aborted early (client disconnect or upstream error): send the end-of-stream
                # marker so the shared Gzipper drops this pipe's compressor and output queue
                # instead of keeping them for the lifetime of the process.
                pipe.push(None)


class PerProcessMeta(common_patterns.ThreadSafeSingletonMeta):
    """
    ``ThreadSafeSingletonMeta`` variant whose ``tag`` is prefixed with the current process id.

    NOTE(review): presumably the parent metaclass keys its singleton by ``tag``, so a forked
    child gets its own instance rather than the parent's — confirm in common_patterns.
    """

    @property
    def tag(cls):
        base_tag = super().tag
        return "_{}_{}".format(os.getpid(), base_tag)


@six.add_metaclass(PerProcessMeta)
class Gzipper(object):
    """
    Gzip compression service: one instance (via ``PerProcessMeta``) with one daemon worker thread.

    Every ``pipe`` gets its own raw-deflate ``Compressor`` and output queue. Chunks pushed into any
    pipe go through a shared input queue to the worker thread. The worker compresses them and
    places the results into that pipe's output queue.

    The output queue starts with the gzip header. Pushing ``None`` flushes the compressor and
    appends the gzip trailer followed by a ``None`` sentinel. It then forgets the pipe. If the
    worker fails on a chunk, the exception is delivered through the pipe's output queue and
    re-raised by ``Pipe``. The shared thread keeps serving the other pipes.
    """

    # gzip member header (RFC 1952): magic 1f 8b, CM=8 (deflate), FLG=0, MTIME=0, XFL=2, OS=255 (unknown)
    _gzip_header = b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x02\xff"

    class Compressor(object):
        """Raw-deflate compressor that tracks CRC32 and input length for the gzip trailer."""

        def __init__(self):
            # Negative wbits -> raw deflate stream without zlib framing; gzip framing is added by hand.
            # noinspection PyArgumentList
            self.__compressor = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS, zlib.DEF_MEM_LEVEL, 0)
            self.__size = 0
            self.__crc = zlib.crc32(b"") & 0xffffffff

        def push(self, data):
            """Feed ``data``; return the compressed bytes that are ready (possibly empty)."""
            self.__size += len(data)
            self.__crc = zlib.crc32(data, self.__crc) & 0xffffffff
            return self.__compressor.compress(data)

        def flush(self):
            """Return the remaining compressed data followed by the trailer: CRC32 and input size mod 2**32."""
            return self.__compressor.flush() + struct.pack("<2L", self.__crc, self.__size & 0xffffffff)

    class Pipe(object):
        """Consumer handle for one compression stream."""

        def __init__(self, pipe_id, input_queue, output_queue):
            self.__pipe_id = pipe_id
            self.__input_queue = input_queue
            self.__output_queue = output_queue

        def __iter__(self):
            """Yield the output produced so far without blocking; re-raise a worker failure."""
            while not self.__output_queue.empty():
                chunk = self.__output_queue.get()
                if isinstance(chunk, Exception):
                    raise chunk
                yield chunk

        def push(self, chunk):
            """Queue ``chunk`` for compression; ``None`` ends the stream."""
            self.__input_queue.put((chunk, self.__pipe_id))

        def flush(self):
            """End the stream and yield all remaining output (blocking); re-raise a worker failure."""
            self.push(None)
            while True:
                chunk = self.__output_queue.get()
                if chunk is None:
                    break
                if isinstance(chunk, Exception):
                    raise chunk
                yield chunk

    def __init__(self):
        self.__input_queue = queue.Queue()
        self.__compressors = {}
        self.__output_queues = {}
        self.__thread = th.Thread(target=self.__loop)
        self.__thread.daemon = True
        self.__thread.start()

    @property
    def pipe(self):
        """Register and return a new ``Pipe`` whose output starts with the gzip header."""
        pipe_id = uuid.uuid4().hex
        self.__compressors[pipe_id] = self.Compressor()
        output_queue = self.__output_queues[pipe_id] = queue.Queue()
        output_queue.put(self._gzip_header)
        return self.Pipe(pipe_id, self.__input_queue, output_queue)

    def __loop(self):
        while True:
            data, pipe_id = self.__input_queue.get()
            compressor = self.__compressors.get(pipe_id)
            output_queue = self.__output_queues.get(pipe_id)
            if compressor is None or output_queue is None:
                # Unknown or already finished pipe (e.g. a repeated end-of-stream marker).
                continue
            try:
                if data is None:
                    output_queue.put(compressor.flush())
                    output_queue.put(None)
                    self.__compressors.pop(pipe_id)
                    self.__output_queues.pop(pipe_id)
                    continue
                chunk = compressor.push(data)
            except Exception as ex:
                # This thread is shared by every pipe in the process. Dying here would make every
                # Pipe.flush() block forever, so drop only the failing pipe and hand the error to its consumer.
                logging.getLogger(__name__).exception("Gzipper: failed to process data of pipe %s", pipe_id)
                self.__compressors.pop(pipe_id, None)
                self.__output_queues.pop(pipe_id, None)
                output_queue.put(ex)
                output_queue.put(None)
                continue
            if chunk:
                output_queue.put(chunk)
