from __future__ import absolute_import

import abc
import time
import types
import Queue
import struct
import socket
import logging
import tarfile
import hashlib
import collections
import threading as th

import requests

from . import rest
from . import proxy
from . import utils
from .types import misc as ctm
from .types import user as ctu
from .types import task as ctt
from .types import database as ctd


class Proto(object):
    """
    This class represents pretty simple data upstream protocol between proxy host and the upload task.
    The whole stream is separated to logical packets. Each packet starts with a header, which contains
    packet type identifier and packet size. All known packet types are listed below in form of inner classes.

    Calling an instance with a packet (see :meth:`__call__`) returns the packet serialized together
    with its header.
    """
    #: Packet header layout: unsigned byte with packet type ID followed by unsigned 32-bit payload size
    #: (network byte order).
    FMT = struct.Struct("!BI")

    class PacketIface(object):
        """
        Base class for all packets. Fields are declared via `__slots__` of a subclass and can be
        initialized either positionally (in `__slots__` order) or by keyword; omitted fields are `None`.
        """
        __slots__ = []
        __metaclass__ = abc.ABCMeta

        def __init__(self, *args, **kwargs):
            super(Proto.PacketIface, self).__init__()
            for i, attr in enumerate(self.__slots__):
                val = None
                # Positional values take precedence over keyword ones
                if i < len(args):
                    val = args[i]
                elif attr in kwargs:
                    val = kwargs[attr]
                setattr(self, attr, val)

        @abc.abstractmethod
        def __setstate__(self, data):
            """ Fill the packet's fields from the payload `data`. Implementations return `self`. """
            pass

        @abc.abstractmethod
        def __getstate__(self):
            """ Return the packet's payload (without the header). """
            pass

        def __str__(self):
            return "{}({})".format(
                self.__class__.__name__,
                ", ".join(["=".join([_, str(getattr(self, _))]) for _ in self.__slots__])
            )

    class Greetings(PacketIface):
        """ Payload: packed `cnonce` followed by "<node_id>:<request_id>". The `addr` slot is not serialized. """
        __slots__ = ("addr", "cnonce", "node_id", "request_id")
        FMT = struct.Struct("!I")
        ID = 0

        def __setstate__(self, data):
            self.cnonce = self.FMT.unpack(data[:self.FMT.size])[0]
            self.node_id, self.request_id = data[self.FMT.size:].split(":")
            return self

        def __getstate__(self):
            return "".join((self.FMT.pack(self.cnonce), self.node_id, ":", self.request_id))

    class DataChunk(PacketIface):
        """ Payload is the raw chunk of data itself. """
        __slots__ = ("data",)
        ID = 1

        def __setstate__(self, data):
            self.data = data
            return self

        def __getstate__(self):
            return self.data

        def __str__(self):
            # Print only the size - the content can be huge
            return "{}({} bytes)".format(self.__class__.__name__, len(self.data))

    class Break(PacketIface):
        """ Packet without any payload. """
        __slots__ = tuple()
        ID = 2

        def __setstate__(self, data):
            return self

        def __getstate__(self):
            return ""

    class StreamSummary(PacketIface):
        """ Payload: packed unsigned 64-bit `size` followed by optional `sha1` (`None` if absent). """
        __slots__ = ("size", "sha1")
        FMT = struct.Struct("!Q")
        ID = 3

        def __setstate__(self, data):
            self.size = self.FMT.unpack(data[:self.FMT.size])[0]
            self.sha1 = data[self.FMT.size:] or None
            return self

        def __getstate__(self):
            # `__setstate__` turns an empty checksum into `None`, so it has to be turned back
            # into an empty suffix here - otherwise such packet cannot be serialized again.
            return self.FMT.pack(self.size) + (self.sha1 or b"")

    @classmethod
    def __call__(cls, pkg):
        """
        Serialize the packet.

        :param pkg: packet instance (one of the inner packet classes) to be serialized
        :return:    header (packet type ID and payload size) followed by the packet's payload
        """
        data = pkg.__getstate__()
        return "".join((cls.FMT.pack(pkg.ID, len(data)), data))


class HTTPHandle(object):
    """
    Data uploading via HTTP.

    The instance is a callable generator: it runs the stages one by one
    (check, prepare, upload, share) and yields the state object of the current stage.
    """

    #: Amount in seconds to wait for single operation on uploading the data.
    OPERATION_TIMEOUT = 90
    #: Maximum allowed amount of files which can be uploaded as single resource
    MAX_FILES = 10000

    #: Uploading resource metadata:
    #: resource type, architecture, owner, description and `dict` with resource's additional attributes
    ResourceMeta = collections.namedtuple("ResourceMeta", ('type', 'arch', 'owner', 'description', 'attributes'))
    #: Uploading file metadata - file handler or real file path, size and name.
    FileMeta = collections.namedtuple("FileMeta", ('handle', 'size', 'name'))
    #: Uploading stream metadata - file handler, stream size limit and resulting filename.
    StreamMeta = collections.namedtuple("StreamMeta", ('handle', 'size', 'name'))
    #: Namespace with the stage state classes - `Check`, `Prepare`, `DataTransfer` and `Share` are used below.
    Stages = ctm.Upload

    class ReporterStreamer(object):
        """
        Uploads the data with two service threads - the producer, which builds the stream,
        and the conductor, which sends it via HTTP - and reports the progress to the iterating caller.
        """
        # 0x3FFFFFF is 64MiB minus one byte.
        # NOTE(review): this constant is not referenced in the visible code - confirm it is still needed.
        WRITE_BUFFER = 0x3FFFFFF  # 64Mb of write buffer

        class DataStream(object):
            """
            Pipe between the data producer and the data sender.
            Calling the instance puts a chunk into the pipe, iterating over it fetches chunks back.
            An empty chunk marks a normal end of the stream, `None` marks an aborted one.
            """
            Chunk = collections.namedtuple("DataStreamChunk", ("data", "offset"))

            class Cache(object):
                """ Window of the most recently fetched chunks, bounded by their total size. """
                __slots__ = ("size", "queue")

                # Keep the trailing 100MiB of the stream to be able to resume the upload
                MAX_SIZE = 100 << 20

                def __init__(self):
                    self.queue = collections.deque()
                    self.size = 0

                def append(self, chunk):
                    """ Remember the chunk, dropping the oldest ones while the size limit is exceeded. """
                    self.queue.append(chunk)
                    self.size += len(chunk.data)
                    while self.size > self.MAX_SIZE:
                        self.size -= len(self.queue.popleft().data)

                def rewind(self, offset):
                    """
                    Collect cached chunks, newest first, down to the first one which starts
                    at or before `offset`. Returns `None` if there is no such chunk in the cache.
                    """
                    tail = []
                    for item in reversed(self.queue):
                        tail.append(item)
                        if item.offset <= offset:
                            break
                    else:
                        return None
                    return tail

            class WriteBuffer(object):
                """ Blocking chunks queue guarded by a condition variable and bounded by total data size. """
                __slots__ = ("size", "cv", "queue")

                # Hold no more than 64MiB of the data which is waiting to be uploaded
                MAX_SIZE = 64 << 20

                def __init__(self):
                    self.queue = Queue.Queue()
                    self.cv = th.Condition()
                    self.size = 0

                def append(self, chunk):
                    """ Enqueue the chunk, waiting first while the buffer is full. """
                    with self.cv:
                        while self.size >= self.MAX_SIZE:
                            # The buffer is over the limit - wait until a reader frees some space
                            self.cv.wait()
                        self.queue.put(chunk)
                        self.cv.notify_all()
                        self.size += len(chunk.data) if chunk.data else 0

                def fetch(self):
                    """ Dequeue the next chunk, waiting first while the buffer is empty. """
                    with self.cv:
                        while not (self.size or self.queue.qsize()):
                            # Nothing to read yet - wait until a writer adds something
                            self.cv.wait()
                        item = self.queue.get()
                        if item.data:
                            self.size -= len(item.data)
                        self.cv.notify_all()
                        return item

                def interrupt(self, terminator):
                    """ Drop everything queued and leave only the `terminator` chunk for the reader. """
                    with self.cv:
                        try:
                            while not self.queue.empty():
                                self.queue.get_nowait()
                            self.size = 0
                        except Queue.Empty:
                            pass
                        self.queue.put(terminator)
                        self.cv.notify_all()

            def __init__(self):
                self._wbuffer = self.WriteBuffer()
                self._cache = self.Cache()
                self._sha1 = hashlib.sha1()
                self._rewound = None
                self._stopping = False
                self._consumed = 0
                self._produced = 0

            def __iter__(self):
                """
                Stream data chunk fetcher. Rewound chunks (if any) are replayed first.
                Raises `ValueError` if the stream has been aborted with a `None` chunk.
                """
                while True:
                    if self._rewound:
                        yield self._rewound.pop()
                        continue

                    item = self._wbuffer.fetch()
                    payload = item.data
                    if payload is None:
                        self._stopping = True
                        raise ValueError("Incomplete stream")
                    if not payload:
                        # Empty chunk is a regular end-of-stream marker
                        self._stopping = True
                        return

                    self._produced += len(payload)
                    self._cache.append(item)
                    yield item

            def __call__(self, chunk):
                """ Stream data chunk putter. Does nothing after the stream has been stopped. """
                if not self._stopping:
                    self._wbuffer.append(self.Chunk(chunk, self._consumed))
                    if chunk:
                        self._sha1.update(chunk)
                        self._consumed += len(chunk)

            def rewind(self, offset):
                """ Schedule cached chunks for replaying; returns them or `None` if the cache cannot cover `offset`. """
                rewound = self._cache.rewind(offset)
                self._rewound = rewound
                return rewound

            def interrupt(self):
                """ Abort the stream: replace all the buffered chunks with a single `None` chunk. """
                self._wbuffer.interrupt(self.Chunk(None, 0))
                self._stopping = True

            def __int__(self):
                """ Amount of bytes fetched from the write buffer so far. """
                return self._produced

            @property
            def sha1(self):
                """ Hexadecimal SHA1 digest of all the data put into the stream so far. """
                return self._sha1.hexdigest()

        def __init__(self, url, session, progress, files, logger):
            """
            :param url:         URL to send the data stream to.
            :param session:     HTTP session object to perform the requests with.
            :param progress:    progress state object, which is updated and reported while streaming.
            :param files:       files (or single stream) metadata to be uploaded.
            :param logger:      logger to report the activity to.
            """
            # Runtime state
            self.threads = None
            self.checksum = None
            self.stopping = False
            # Collaborators passed by the caller
            self.url = url
            self.session = session
            self.progress = progress
            self.files = files
            self.logger = logger
            # Channels: progress reports flow back via the queue, the data flows via the stream
            self.backqueue = Queue.Queue()
            self.datastream = self.DataStream()
            super(HTTPHandle.ReporterStreamer, self).__init__()

        def __iter__(self):
            self.logger.debug("Starting service threads.")
            self.threads = [th.Thread(target=self.producer), th.Thread(target=self.conductor)]
            map(th.Thread.start, self.threads)

            while True:
                progress = self.backqueue.get()
                if not progress:
                    break
                if isinstance(progress, Exception):
                    self.stopping = True
                    map(th.Thread.join, self.threads)
                    raise type(progress)(*progress.args)
                yield progress

            self.logger.debug("Waiting for all threads finish.")
            map(th.Thread.join, self.threads)
            yield self.progress

        def stop(self):
            self.logger.debug("Stopping the data stream machinery.")
            self.stopping = True
            self.datastream(None)

        def producer(self):
            """
            Producer thread body: feeds the data stream with the data to be uploaded.

            If the first entry of `self.files` is a `StreamMeta`, its handle is read chunk by chunk
            and passed through as is. Otherwise all the files are packed into an uncompressed tar
            stream, which is written into this object (see :meth:`write`).
            The stream is finished with an empty chunk. On any error the stream is aborted with
            a `None` chunk, the exception is reported via the back queue and the stop flag is raised.
            """
            try:
                stream = self.files[0]
                if isinstance(stream, HTTPHandle.StreamMeta):
                    self.logger.info("Producer thread started. Initializing direct stream.")
                    while True:
                        chunk = stream.handle.read(0xFFFFF)
                        if not chunk:
                            break
                        self.write(chunk)
                else:
                    self.logger.info("Producer thread started. Initializing tar stream.")
                    with tarfile.open(mode="w|", bufsize=0xFFFFF, fileobj=self) as tar:
                        for f in self.files:
                            self.logger.debug("Adding %r to the tar stream.", f.name)
                            handle = f.handle
                            opened_here = isinstance(handle, basestring)
                            if opened_here:
                                self.logger.debug("Opening file %r.", handle)
                                # The content is archived as raw bytes, so no text mode translation is wanted
                                handle = open(handle, "rb")
                            try:
                                tar.addfile(tar.gettarinfo(arcname=f.name, fileobj=handle), fileobj=handle)
                            finally:
                                # Close only what has been opened here, even if archiving has failed
                                if opened_here:
                                    handle.close()
                self.logger.debug("Flushing the tar stream.")
                self.datastream("")
            except Exception as ex:
                self.logger.exception("Error while creating the tar stream.")
                self.datastream(None)
                self.backqueue.put(ex)
                self.stopping = True

        def conductor(self):
            """
            Conductor thread body: sends the data stream to `self.url` with a single HTTP PUT request,
            using :meth:`consumer` generator as the request body.

            On success the stripped response text is stored as `self.checksum` and `None` is put
            into the back queue. On a connection error, if some data has been sent already,
            the upload is tried to be resumed: the position is asked from the server with a HEAD request,
            the data stream is rewound to it and the method calls itself again.
            On any failure the stop flag is raised, the exception is put into the back queue
            and the data stream is interrupted.
            """
            try:
                self.logger.info("Conductor thread started. Opening data channel to %r.", self.url)
                r = self.session.put(self.url, data=self.consumer(), timeout=HTTPHandle.OPERATION_TIMEOUT)
                self.logger.debug(
                    "HTTP response code is %r, text: %r, headers: %r",
                    r.status_code, r.text, r.headers
                )
                r.raise_for_status()
                self.checksum = r.text.strip()
                # `None` in the back queue signals the successful finish
                self.backqueue.put(None)
            except (requests.exceptions.ConnectionError, socket.error) as ex:
                self.logger.exception("Error sending data.")
                # Resuming makes sense only if something has been sent already
                if self.progress.done:
                    # NOTE(review): presumably `progressive_yielder` paces the retries within
                    # `OPERATION_TIMEOUT` seconds - confirm against `utils`.
                    for _ in utils.progressive_yielder(.1, 1, HTTPHandle.OPERATION_TIMEOUT):
                        try:
                            self.logger.info("Reconnecting the data stream.")
                            r = self.session.head(self.url, timeout=HTTPHandle.OPERATION_TIMEOUT)
                            r.raise_for_status()
                            resume_at = int(r.headers[ctm.HTTPHeader.RESUME_AT])
                            if not self.datastream.rewind(resume_at):
                                self.logger.error("Unable to rewind the cache to %d.", resume_at)
                                break
                            self.logger.info(
                                "Rewound the cache to %d (%d chunks to resend).",
                                resume_at, len(self.datastream._rewound)
                            )
                            # Start over - the consumer will replay the rewound chunks first
                            return self.conductor()
                        except (requests.exceptions.ConnectionError, socket.error):
                            # Still no connection - try again
                            pass
                        except Exception as ex:
                            # Rebinds `ex`, so this error is the one to be reported below
                            break
                self.stopping = True
                self.backqueue.put(ex)
                self.datastream.interrupt()
            except Exception as ex:
                self.logger.exception("Fatal error while sending the data stream.")
                if isinstance(ex, requests.HTTPError):
                    # Enrich the message with the server's identification headers and the response text
                    ex = requests.HTTPError(
                        "Server {} {{{}}} respond {}: {}".format(
                            ex.response.headers.get(ctm.HTTPHeader.BACKEND_NODE),
                            ex.response.headers.get(ctm.HTTPHeader.INT_REQUEST_ID),
                            str(ex),
                            ex.response.text,
                        ),
                        response=ex.response
                    )
                self.stopping = True
                self.backqueue.put(ex)
                self.datastream.interrupt()

        def write(self, chunk):
            if self.stopping:
                self.datastream(None)
                raise RuntimeError("Threads stop signal")
            self.datastream(chunk)

        def consumer(self):
            ts = 0
            for chunk in self.datastream:
                yield chunk.data
                now = time.time()
                self.progress.done = int(self.datastream)
                if now - ts > .3:
                    self.backqueue.put(self.progress)
                    ts = now
            yield self.datastream.sha1
            self.progress.checksum = self.datastream.sha1

    @staticmethod
    def rest_proxy(base_url, auth=None, total_wait=60):
        """
        Create a REST API client.

        :param base_url:    base URL of the service; "/api/v1.0" is appended to it.
                            A false value results in `None` passed to the client as the URL.
        :param auth:        authentication object to pass to the client.
        :param total_wait:  `total_wait` argument to pass to the client.
        :return:            the client with the read preference header set to "primary".
        """
        api_url = base_url + "/api/v1.0" if base_url else None
        client = rest.Client(api_url, auth=auth, total_wait=total_wait)
        client.DEFAULT_TIMEOUT = 5
        client.reset()
        headers = {
            ctm.HTTPHeader.READ_PREFERENCE: ctd.ReadPreference.val2str(ctd.ReadPreference.PRIMARY),
        }
        return client << rest.Client.HEADERS(headers)

    def __init__(self, rmeta, auth, url=None, proxy_url=None, total_wait=None, *files):
        """
        Default constructor.

        :param rmeta:       Metadata of the resource to be created (see :class:`HTTPHandle.ResourceMeta`).
        :type rmeta:        HTTPHandle.ResourceMeta
        :param auth:        OAuth token or :class:`common.proxy.Authentication` instance
                            to use in communication with Sandbox API.
        :param url:         Sandbox base URL to communicate with.
        :param proxy_url:   Sandbox proxy base URL to upload data via.
        :param total_wait:  Time limit in seconds for single API call.
        :param files:       Files to upload - tuples of (file handler, file size, file name)
                            (see :class:`HTTPHandle.FileMeta`).
        :raises ValueError: If there are more than `MAX_FILES` files passed.
        """
        amount = len(files)
        if amount > self.MAX_FILES:
            raise ValueError(
                "Amount of files to be uploaded ({}) bigger than allowed ({})".format(amount, self.MAX_FILES)
            )

        self.rmeta = rmeta
        self.files = files
        self._auth = auth

        # Both URLs are stored without trailing slashes; empty ones are stored as `None`
        self.url = None
        if url:
            self.url = url.rstrip("/")
        self.proxy_url = None
        if proxy_url:
            self.proxy_url = proxy_url.rstrip("/")

        self._results = {}
        self._stage = self._check
        self._proxy_session = None
        self.logger = logging.getLogger(__name__)

        # Suppress "no handlers for "common.rest" logger" warning message.
        root_handlers = logging.getLogger().handlers
        fallback = root_handlers[0] if root_handlers else logging.NullHandler()
        logging.getLogger("common.rest").addHandler(fallback)

        self._srv = self.rest_proxy(self.url, auth, total_wait)

    def __call__(self, *args, **kwargs):
        """
        Run the stages one after another, yielding every state they report.
        The last state of each stage is stored in `self._results` under the state's type.
        Passed arguments are not used.
        """
        while self._stage:
            # Each stage switches `self._stage` to its successor (or to `None`) when it is done
            current_stage = self._stage
            for state in current_stage():
                self.logger.info("New state %r", state)
                yield state
            self._results[type(state)] = state

    @staticmethod
    def _task_creation_common_params(rmeta, resource_size):
        """
        Build task creation parameters which do not depend on the upload method.

        :param rmeta:           resource metadata - `owner` and `description` are taken from it.
        :param resource_size:   total size of the data in bytes.
        :return:                `dict` with the parameters.
        """
        limits = ctu.DEFAULT_PRIORITY_LIMITS.api
        # Assume 10Mbps minimum average upload speed + 20% but not less than 5 minutes
        estimated_seconds = int(resource_size * 1.2 * 8 / 10 / 1024 / 1024)
        return {
            "fail_on_any_error": True,
            "notifications": [],
            "owner": rmeta.owner,
            "description": rmeta.description,
            "priority": {"class": limits.cls, "subclass": limits.scls},
            "kill_timeout": max(300, estimated_seconds),
            # The data itself plus 11MiB on top of it
            "requirements": {"disk_space": resource_size + (11 << 20)},
        }

    def _task_creation_params(self, resource_size):
        """
        Build parameters for the uploading task creation.

        Resource attributes are passed as is if they are a `dict` or `None`. Otherwise they are
        treated as a string of the legacy "key1=value1,key2=value2" format. Blank items of it are
        skipped and only the first "=" of an item separates the key, so values may contain "=".

        :param resource_size:   total size of the data to be uploaded in bytes.
        :return:                `dict` with the task creation parameters.
        :raises ValueError:     if a non-blank item of legacy attributes string has no "=".
        """
        rmeta = self.rmeta

        # Backward-compatibility attributes format parsing
        attrs = rmeta.attributes
        if not isinstance(attrs, (dict, types.NoneType)):
            pairs = (item.split("=", 1) for item in attrs.split(",") if item.strip())
            attrs = {k.strip(): v.strip() for k, v in pairs}

        context = {
            "version": ctm.Upload.VERSION,
            "stream": ctm.Upload.Stream.PLAIN,
            "amount": resource_size,
        }
        if isinstance(self.files[0], self.StreamMeta):
            # Single stream upload: pass the resulting file name and the stream size limit instead
            context["stream"] = {"name": self.files[0].name, "limit": self.files[0].size}

        params = self._task_creation_common_params(rmeta, resource_size)
        params.update(
            type="HTTP_UPLOAD_2",
            custom_fields=[
                dict(name="resource_arch", value=rmeta.arch),
                dict(name="resource_type", value=rmeta.type),
                dict(name="resource_attrs", value=attrs),
            ],
            context={"upload": context},
        )
        return params

    def _increase_task_priority(self, task_id):
        """
        Create the proxy HTTP session (stored as `self._proxy_session`) and send a HEAD request
        to the task's upload URL on the proxy.

        The request is tried up to 6 times with 3 seconds pause on connection errors and timeouts;
        the error of the last attempt is re-raised.

        :param task_id: identifier of the uploading task.
        """
        self.logger.debug("Increasing task #%s priority.", task_id)
        self._proxy_session = requests.Session()
        if self._auth:
            # A plain string is treated as an OAuth token
            self._proxy_session.auth = rest.Client.Auth(
                proxy.OAuth(self._auth) if isinstance(self._auth, basestring) else self._auth
            )
        # Counting from -5 up to 0 makes the last attempt the only one with a false number
        for attempt in xrange(-5, 1):
            try:
                self._proxy_session.head(
                    "/".join([self.proxy_url, "upload", str(task_id)]),
                    timeout=self._srv.DEFAULT_TIMEOUT
                )
                break
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as ex:
                if not attempt:
                    raise
                self.logger.warning(str(ex))
                time.sleep(3)
            except requests.HTTPError as ex:
                # NOTE(review): `raise_for_status()` is not called on the response above, so it is
                # unclear whether `HTTPError` can be raised here at all - confirm. Also a suppressed
                # "method not allowed" error does not break the loop, the request is repeated.
                if ex.response.status_code != requests.codes.METHOD_NOT_ALLOWED:
                    raise

    def _get_resource_id(self, task_meta):
        """ Take the "resource" output parameter out of the task's metadata (`None` if it is absent). """
        output_parameters = task_meta["output_parameters"]
        return output_parameters.get("resource")

    def _validate_uploading(self):
        """
        Compare the amount of bytes and the checksum registered in the task's upload context
        with the local ones collected at the data transfer stage.

        :raises AssertionError: if either the received size or the checksum differs.
        """
        prepared = self._results[self.Stages.Prepare]
        transferred = self._results[self.Stages.DataTransfer]

        upload_ctx = self._srv.task[prepared.task_id].context[:]["upload"]
        self.logger.debug("Task #%s upload context is %r", prepared.task_id, upload_ctx)

        received = upload_ctx["received"]
        if int(received) != transferred.done:
            raise AssertionError(
                "Task received stream of {} bytes, while script sent {} bytes".format(received, transferred.done)
            )

        remote_checksum = upload_ctx["checksum"]
        if remote_checksum != transferred.checksum:
            raise AssertionError(
                "Task calculated stream SHA1 {}, while script's one is {}".format(
                    remote_checksum, transferred.checksum
                )
            )

    def _check(self):
        """ First stage: report the amount of files and their total size, then switch to preparation. """
        report = self.Stages.Check()
        report.amount = len(self.files)
        self.logger.debug("Checking %d files.", report.amount)
        total = 0
        for meta in self.files:
            total += meta.size
        report.size = total
        yield report
        self._stage = self._prepare

    def _prepare(self):
        """
        Preparation stage: create and enqueue the uploading task, then wait for the resource registration.

        The same state object is yielded three times: right at the start, after the task is created
        (with `task_id` set) and after the wait for the resource is over (with `resource_id` set).

        :raises RuntimeError:   if the task cannot be enqueued.
        :raises AssertionError: if the task switches to an unexpected status while waiting.
        """
        r = self.Stages.Prepare()
        yield r

        self.logger.debug("Creating a new task.")
        size = self._results[self.Stages.Check].size
        task_id = self._srv.task(**self._task_creation_params(size))["id"]
        r.task_id = task_id
        yield r

        self.logger.debug("Enqueuing task #%s.", task_id)
        # Take the first item of the batch operation result or a stub if the result is empty
        res = next(
            iter(self._srv.batch.tasks.start.update([r.task_id])),
            {"status": "UNKNOWN", "message": "Unknown status returned"}
        )
        if res["status"] != ctm.BatchResultStatus.SUCCESS:
            self.logger.warning(
                "Task start operation finished with status %s, message: %s", res["status"], res["message"]
            )
            # Unsuccessful start is tolerated if the task is queued or being executed anyway
            resp = self._srv.task[task_id].read()
            if resp["status"] not in utils.chain(ctt.Status.Group.QUEUE, ctt.Status.Group.EXECUTE):
                raise RuntimeError("Cannot queue task #{}, its current status is {}".format(task_id, resp["status"]))

        self._increase_task_priority(task_id)

        _TS = ctt.Status
        task_status = None
        self.logger.debug("Waiting for task #%s execution start and new resource registration.", task_id)
        # NOTE(review): presumably `progressive_yielder` limits the wait with `OPERATION_TIMEOUT * 2`
        # seconds; if so, on timeout the code below goes on with empty `r.resource_id` - confirm.
        for _, _ in utils.progressive_yielder(1, 1, self.OPERATION_TIMEOUT * 2):
            task_meta = self._srv.task[r.task_id][:]
            if task_status != task_meta["status"]:
                task_status = task_meta["status"]
                self.logger.debug("Task status changed to %r", task_status)
            if task_status not in list(utils.chain(_TS.Group.QUEUE, _TS.ASSIGNED, _TS.PREPARING, _TS.STOPPING)):
                raise AssertionError("Task #{} is in wrong '{}' state.".format(task_id, task_status))
            r.resource_id = self._get_resource_id(task_meta)
            if r.resource_id:
                break
        self.logger.debug("Registered resource #%s.", r.resource_id)
        yield r

        self._stage = self._upload

    def _upload(self):
        """
        Data transfer stage: stream the data to the task's upload URL on the proxy,
        yielding progress reports. The streamer is stopped whatever the outcome is.
        """
        progress = self.Stages.DataTransfer()
        progress.total = self._results[self.Stages.Check].size
        task_id = self._results[self.Stages.Prepare].task_id
        yield progress

        upload_url = "/".join([self.proxy_url, "upload", str(task_id)])
        streamer = self.ReporterStreamer(upload_url, self._proxy_session, progress, self.files, self.logger)
        try:
            for report in streamer:
                if isinstance(report, Exception):
                    raise report
                yield report
        finally:
            streamer.stop()
        self._stage = self._share

    def _share(self):
        """
        Final stage: wait for the task to finish, validate the upload and report the resource's data.
        The state is yielded on every task status poll and once more with the resource's data filled.

        :raises AssertionError: if the task switches to an unexpected status.
        """
        state = self.Stages.Share()
        statuses = ctt.Status
        prepared = self._results[self.Stages.Prepare]
        in_progress = [statuses.PREPARING, statuses.EXECUTING, statuses.FINISHING]

        self.logger.debug("Waiting for task #%s finish.", prepared.task_id)
        for _slept, _tick in utils.progressive_yielder(1, 3, 300, False):
            state.task_state = self._srv.task[prepared.task_id][:]["status"]
            yield state
            if state.task_state == statuses.SUCCESS:
                break
            elif state.task_state not in in_progress:
                raise AssertionError(
                    "Task #{} switched to incorrect state '{}' (see task's logs for details).".format(
                        prepared.task_id, state.task_state
                    )
                )

        self._validate_uploading()

        resource = self._srv.resource[prepared.resource_id][:]
        state.skynet_id = resource["skynet_id"]
        state.md5sum = resource["md5"]
        state.meta = resource
        yield state

        # No more stages to run
        self._stage = None


class SkynetHandle(HTTPHandle):
    """
    Data uploading via skynet.

    The data is not sent by this class: a task is created to fetch it by the given skynet ID,
    so this class only creates the task and tracks its state.
    """

    def __init__(self, rmeta, auth, skynet_id, url=None, total_wait=None, *files):
        """
        :param rmeta:       Metadata of the resource to be created (see :class:`HTTPHandle.ResourceMeta`).
        :param auth:        Authentication to use in communication with Sandbox API.
        :param skynet_id:   Skynet ID of the data to be fetched by the task.
        :param url:         Sandbox base URL to communicate with.
        :param total_wait:  Time limit in seconds for single API call.
        :param files:       Files metadata; the name of the first one is used as the resource name.
        """
        super(SkynetHandle, self).__init__(rmeta, auth, url, None, total_wait, *files)
        self.skynet_id = skynet_id

    def _task_creation_params(self, resource_size):
        """
        Build parameters for the remote copy task creation.

        Resource attributes are passed in the "key1=value1,key2=value2" string format:
        a string is passed as is, a `dict` is converted, `None` stays `None`.

        :param resource_size:   total size of the data in bytes.
        :return:                `dict` with the task creation parameters.
        """
        rmeta = self.rmeta
        attrs = rmeta.attributes
        # `None` is a valid "no attributes" value (the base class accepts it too), so it is kept
        # as is instead of failing on the conversion below
        if attrs is not None and not isinstance(attrs, basestring):
            attrs = ",".join("{}={}".format(*_) for _ in attrs.iteritems())

        context = {
            "version": ctm.Upload.VERSION,
            "amount": resource_size,
        }

        params = self._task_creation_common_params(rmeta, resource_size)
        params.update(
            type="REMOTE_COPY_RESOURCE",
            custom_fields=[
                dict(name=n, value=v)
                for n, v in (
                    ("resource_type", rmeta.type),
                    ("resource_arch", rmeta.arch),
                    ("created_resource_name", self.files[0].name),
                    ("remote_file_name", self.skynet_id),
                    ("remote_file_protocol", "skynet"),
                    ("resource_attrs", attrs),
                )
            ],
            context={"upload": context},
        )
        return params

    def _increase_task_priority(self, task_id):
        """ Nothing to do here: no request to the proxy is made for this kind of upload. """
        pass

    def _get_resource_id(self, task_meta):
        """ Take the resource ID from the "result_resource_id" key of the task's context. """
        ctx = self._srv.task[task_meta["id"]].context[:]
        return ctx.get("result_resource_id")

    def _validate_uploading(self):
        """ Nothing to validate: no data is sent by this class. """
        pass

    def _upload(self):
        """
        Data transfer stage: wait for the task to start and report the transfer as completed.

        :raises AssertionError: if the task switches to an unexpected status.
        """
        task_id = self._results[self.Stages.Prepare].task_id

        self.logger.debug("Waiting for task #%s start.", task_id)
        target_statuses = (ctt.Status.PREPARING, ctt.Status.EXECUTING, ctt.Status.FINISHING)
        statuses_to_skip = list(utils.chain(ctt.Status.Group.QUEUE, ctt.Status.ASSIGNED))

        for slept, _ in utils.progressive_yielder(1, 3, 300, False):
            status = self._srv.task[task_id][:]["status"]
            if status in target_statuses:
                break
            if status not in statuses_to_skip:
                raise AssertionError("Task #{} switched to incorrect state '{}'".format(task_id, status))

        # There is nothing to send from here, so the whole amount is reported as done at once
        progress = self.Stages.DataTransfer()
        progress.total = progress.done = self._results[self.Stages.Check].size
        yield progress
        self._stage = self._share
