import os
import abc
import time
import errno
import socket
import random
import logging
import inspect
import hashlib
import tarfile
from six.moves.urllib.parse import urlparse
import contextlib

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.client as ctc

from sandbox import sdk2

from sandbox.projects.common.nanny import nanny
from sandbox.projects.common.ya_deploy import release_integration


class SocketHandler(object):
    """
    Socket handler for input data processing. Declared outside of main class to avoid pickling problems.

    Instances are created through ``common.utils.SingletonMeta``. The constructor opens a listening
    ``AF_INET6`` TCP socket on an ephemeral port. Iterating over the instance continues the packet stream
    of the currently accepted peer, or accepts a new peer if there is no active stream.
    """
    __metaclass__ = common.utils.SingletonMeta

    # Timeout on any operation on socket.
    TIMEOUT = 60

    Proto = common.upload.Proto

    def __init__(self):
        # Packet type ID -> packet class, for every concrete packet class declared in the protocol.
        self.pid2cls = {
            cls.ID: cls
            for cls in self.Proto.__dict__.itervalues()
            if inspect.isclass(cls) and issubclass(cls, self.Proto.PacketIface) and cls is not self.Proto.PacketIface
        }

        # Find out which local address is used to reach the REST endpoint by actually connecting to it.
        # When the URL carries no explicit port, use the scheme's default one.
        url = urlparse(common.config.Registry().client.rest_url)
        port = url.port or (443 if url.scheme == "https" else 80)
        addr = socket.getaddrinfo(url.hostname, port)[0]  # If IPv6 address available, it should be first
        with contextlib.closing(socket.socket(addr[0])) as sock:
            logging.debug("Determining own address by connecting to %r", addr[4][:2])
            # Counts up from -150 to 0; the final (zero) attempt re-raises instead of retrying.
            for countdown in xrange(-150, 1):
                try:
                    sock.connect(addr[4][:2])
                    break
                except socket.error as ex:
                    if not countdown or ex.errno not in (errno.ECONNRESET, errno.ECONNABORTED, errno.ECONNREFUSED):
                        raise
                    time.sleep(0.1)
            self._addr = sock.getsockname()[0]
        self.sock = socket.socket(socket.AF_INET6)
        self.sock.settimeout(self.TIMEOUT)
        self.sock.bind(("", 0))
        self.sock.listen(1)
        self._gen = None

    @property
    def addr(self):
        """``(host, port)`` pair the uploader should connect to."""
        return self._addr, self.sock.getsockname()[1]

    def parser(self):
        """
        Accept one peer connection and yield decoded protocol packets until the peer disconnects.

        The first packet must be ``Proto.Greetings``; the peer address returned by ``accept()`` is attached
        to it as ``addr``. A ``Proto.Break`` packet aborts with :class:`common.errors.TaskStop`.
        Whenever the generator finishes (normally or by an exception), the peer socket is closed and
        ``self._gen`` is reset, so the next iteration accepts a fresh connection.
        """
        try:
            peer, addr = self.sock.accept()
            with contextlib.closing(peer):
                peer.settimeout(self.TIMEOUT)
                logging.debug("Peer socket accepted.")

                def reliable_read(amount):
                    """
                    Read exactly ``amount`` bytes from the peer.

                    Returns an empty string on EOF or on timeout (partially read data is dropped).
                    ``EAGAIN`` is retried up to 30 times per ``recv`` call.
                    """
                    data = ""
                    while len(data) < amount:
                        for attempt in xrange(30):
                            try:
                                chunk = peer.recv(amount - len(data))
                                if not chunk:
                                    return chunk
                                data = chunk if not data else data + chunk
                                break
                            except socket.timeout:
                                return ""
                            except socket.error as ex:
                                if ex.errno != errno.EAGAIN or attempt == 29:
                                    raise
                                time.sleep(.1)
                    return data

                greets = None
                pr = self.Proto()
                while True:
                    hdr = reliable_read(pr.FMT.size)
                    if not hdr:
                        break
                    pid, sz = pr.FMT.unpack(hdr)
                    data = reliable_read(sz)
                    if sz and not data:
                        break
                    obj = self.pid2cls[pid]().__setstate__(data)
                    if greets is None:
                        assert isinstance(obj, self.Proto.Greetings)
                        obj.addr = addr
                        greets = obj
                    elif isinstance(obj, pr.Break):
                        raise common.errors.TaskStop("Stream interrupted")
                    yield obj
            logging.debug("Peer socket disconnected.")
        finally:
            # Drop the (now finished) generator even on errors, otherwise later iterations would be empty.
            self._gen = None

    def __iter__(self):
        """Yield packets of the current connection, starting a new ``parser`` if none is active."""
        if not self._gen:
            self._gen = self.parser()
        for packet in self._gen:
            yield packet


class ABCStreamReader(object):
    """
    Base class for readers that consume ``DataChunk`` packets from a packet iterator and expose them
    through a file-like :meth:`read`. Every byte returned by :meth:`read` is fed into ``self.sha1``.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, task, expected, reader):
        """
        :param task: owning task; ``update_stream_status`` and ``handle_greetings`` are called on it on reconnect
        :param expected: expected stream size in bytes, used for progress logging only (falsy means unknown)
        :param reader: iterator over received protocol packets
        """
        self.task = task
        self.expected = expected
        self.reader = reader

        self.sha1 = hashlib.sha1()
        self.common_prefix = ""
        self.finished = False
        self.progress = 0
        self.received = 0
        self.buffer = ""

    @abc.abstractmethod
    def extractall(self, dst):
        """Consume the stream and materialize its contents; implementations define how ``dst`` is used."""

    def _next_packet(self):
        """
        Return the next packet, waiting for the uploader to reconnect each time the current connection ends.

        A ``StopIteration`` must never escape from here: ``read()`` is called from inside tarfile's member
        iteration, where (in Python 2) a leaked ``StopIteration`` silently ends extraction early.
        """
        while True:
            try:
                return self.reader.next()
            except StopIteration:
                pass
            logging.warning("Stream interrupted at %d bytes. Waiting for upload resume.", self.received)
            self.task.update_stream_status(self.received)
            self.reader = iter(SocketHandler())
            try:
                greets = self.reader.next()
            except StopIteration:
                # The new connection closed before greeting us - wait for another one.
                continue
            self.task.handle_greetings(greets)

    def read(self, amount):
        """
        File-like read: return at most ``amount`` bytes of stream data, or an empty string at end of stream.

        At most one packet is pulled per call, so fewer bytes than requested may be returned.
        """
        if self.buffer and amount <= len(self.buffer):
            # Enough data is already buffered from a previous packet.
            chunk, self.buffer = self.buffer[:amount], self.buffer[amount:]
        elif self.finished:
            # End of stream already seen: hand out whatever is left (possibly nothing).
            chunk, self.buffer = self.buffer, ""
        else:
            pkg = self._next_packet()
            # Validate the packet type before touching its payload.
            assert isinstance(pkg, common.upload.Proto.DataChunk)
            self.received += len(pkg.data)

            curp = self.received * 100 / self.expected if self.expected else 100
            if curp != self.progress:
                logging.debug(
                    "Received so far %s (%d bytes, %s%%)",
                    common.utils.size2str(self.received), self.received, curp
                )
            self.progress = curp
            if not pkg.data:
                # An empty data chunk is treated as the end of the stream.
                self.finished = True

            chunk = self.buffer + pkg.data
            if len(chunk) > amount:
                chunk, self.buffer = chunk[:amount], chunk[amount:]
            else:
                self.buffer = ""
        self.sha1.update(chunk)
        return chunk


class DirectStreamReader(ABCStreamReader):
    """ A helper class which will consume plain file data stream. """

    def __init__(self, task, meta, reader):
        """
        :param meta: stream description; ``meta["limit"]`` is the expected size in bytes and
                     ``meta["name"]`` is the name of the file to create
        """
        super(DirectStreamReader, self).__init__(task, meta["limit"], reader)
        self.common_prefix = meta["name"]

    def extractall(self, _):
        """
        Write the whole stream into a file named ``self.common_prefix``.

        The destination argument is ignored: the file is opened relative to the current working directory.
        The name is supplied by the uploader, so absolute paths and ``..`` components are rejected.

        :raises common.errors.TaskFailure: if the file name would escape the working directory
        """
        name = self.common_prefix
        if os.path.isabs(name) or os.pardir in name.split(os.path.sep):
            raise common.errors.TaskFailure("Unsafe name of uploaded file: {!r}".format(name))
        with open(name, "wb") as fh:
            while True:
                chunk = self.read(0x2FFFFF)
                if not chunk:
                    break
                fh.write(chunk)


class TarballStreamReader(ABCStreamReader):
    """ A helper class, which will consume data queue and will provide a data for tarball stream processor. """

    def extractall(self, path):
        """
        Extract the uncompressed tar stream under ``path`` and record its single top-level directory
        in ``self.common_prefix``.
        """
        with tarfile.open(mode="r|", fileobj=self, bufsize=0x2FFFFF) as tar:
            tar.extractall(path, members=self.members(tar))
        self.common_prefix = self.common_prefix.strip(os.path.sep)

    @staticmethod
    def _is_unsafe_path(name):
        """Return True if archive path ``name`` is absolute or contains a ``..`` component."""
        return os.path.isabs(name) or os.pardir in name.split(os.path.sep)

    def members(self, members):
        """
        Validate archive members on the fly and make them world-readable.

        The archive comes from an untrusted upload and ``tarfile.extractall`` does not sanitize paths itself,
        so members that could be written (or hard-linked) outside the destination are rejected.

        :raises common.errors.TaskFailure: on an unsafe member path or hard link target, or when members
                                           do not share a single top-level directory
        :raises ValueError: when the top-level directory is named ``tmp``
        """
        for tarinfo in members:
            logging.debug("Extracting %r", tarinfo.name)
            if self._is_unsafe_path(tarinfo.name):
                raise common.errors.TaskFailure(
                    "Uploaded archive member {!r} has an unsafe path".format(tarinfo.name)
                )
            if tarinfo.islnk() and self._is_unsafe_path(tarinfo.linkname):
                raise common.errors.TaskFailure(
                    "Uploaded archive member {!r} is a hard link to an unsafe path {!r}".format(
                        tarinfo.name, tarinfo.linkname
                    )
                )
            prefix = tarinfo.name.split(os.path.sep)[0]
            if prefix == "tmp":
                raise ValueError("'tmp' is reserved directory name, use different name")
            self.common_prefix = self.common_prefix or prefix
            if prefix != self.common_prefix:
                raise common.errors.TaskFailure(
                    "There are no common directory of uploaded files: '{}' differs with '{}'".format(
                        self.common_prefix, prefix
                    )
                )
            tarinfo.mode |= 0o444  # Allow anybody read the file
            yield tarinfo


class HTTPUpload2(nanny.ReleaseToNannyTask2, release_integration.ReleaseToYaDeployTask2, sdk2.Task):
    """
    "Backend" task for resource upload via proxy.sandbox.yandex-team.ru using HTTP protocol.
    It should **not** be used directly, without :module:`common.upload` module.
    """
    name = "HTTP_UPLOAD_2"

    class Requirements(sdk2.Requirements):
        client_tags = (ctc.Tag.GENERIC | ctc.Tag.STORAGE) & ctc.Tag.HDD

    class Parameters(sdk2.Parameters):
        resource_type = sdk2.parameters.String("Resource type to be created", required=True)
        resource_attrs = sdk2.parameters.Dict("Resource attributes", required=True)
        resource_arch = sdk2.parameters.String("Resource arch to be set on resource creation", required=True)

        with sdk2.parameters.Output:
            resource = sdk2.parameters.Resource("Resource with upload data", required=True)

    class Context(sdk2.Context):
        upload = {}
        resource_id = None

    def handle_greetings(self, greets):
        """
        Accept a greetings packet from the proxy.

        Drops any previously recorded ``received`` offset, checks that the greeting's ``cnonce`` matches
        the one generated in :meth:`on_prepare`, records the proxy identity and pushes the ``upload``
        context to the server.
        """
        self.Context.upload.pop("received", None)
        logging.info("Received %s", greets)
        assert greets.cnonce == self.Context.upload["cnonce"]
        self.Context.upload["proxy"] = {"node_id": greets.node_id, "request_id": greets.request_id}
        self.server.task.current.context.value.update(key="upload", value=self.Context.upload)

    def update_stream_status(self, received):
        """Publish the number of bytes received so far (used when the stream is interrupted) and return it."""
        self.Context.upload["received"] = received
        self.server.task.current.context.value.update(key="upload", value=self.Context.upload)
        return received

    def on_prepare(self):
        """
        Open the listening socket, publish its address and a client nonce to the server,
        register the output resource and wait for the proxy's greetings.
        """
        assert ctm.Upload.VERSION >= self.Context.upload["version"]

        logging.debug("Creating a listening socket...")
        sh = SocketHandler()

        logging.info("Listening on %s:%r. Waiting for greetings.", *sh.addr)
        # NOTE: the mask keeps only the low 20 of the 24 generated bits.
        self.Context.upload["cnonce"] = random.getrandbits(24) & 0xFFFFF
        self.Context.upload["target"] = ":".join(map(str, sh.addr))
        self.server.task.current.context.value.update(key="upload", value=self.Context.upload)

        logging.debug("Registering a resource...")
        attrs = {"backup_task": True}
        attrs.update(self.Parameters.resource_attrs)
        self.Parameters.resource = sdk2.Resource[self.Parameters.resource_type](
            self,
            self.Parameters.description,
            "UNKNOWN",
            self.Parameters.resource_arch,
            **attrs
        )

        logging.debug("Waiting for proxy to connect...")
        self.handle_greetings(iter(sh).next())

    def on_execute(self):
        """
        Consume the uploaded data stream, verify it against the proxy's stream summary and
        point the output resource at the extracted data.

        A ``dict`` in ``Context.upload["stream"]`` selects a single plain file upload; otherwise the
        stream is treated as a tarball of ``Context.upload["amount"]`` bytes.

        :raises common.errors.TaskFailure: if no common prefix was detected, or the received size or
                                           SHA1 differs from the summary
        """
        sh = SocketHandler()
        reader = (
            DirectStreamReader(self, self.Context.upload["stream"], iter(sh))
            if isinstance(self.Context.upload["stream"], dict) else
            TarballStreamReader(self, self.Context.upload["amount"], iter(sh))
        )
        logging.info(
            "Consuming data stream of about %s and extract files to %r",
            common.utils.size2str(reader.expected), self.path()
        )
        reader.extractall(str(self.path()))
        logging.debug(
            "Stream processing finished at %d bytes. Consuming the rest of the stream to provide correct checksum.",
            reader.received
        )
        # The extractor may stop before the end-of-stream marker; drain the rest so the checksum covers everything.
        while not reader.finished and reader.read(0xFFFFF):
            pass
        sha1 = self.Context.upload["checksum"] = reader.sha1.hexdigest()
        self.Context.upload["received"] = str(reader.received)  # Poor XMLRPC
        logging.info(
            "Received %s (%s bytes), SHA1: %s",
            common.utils.size2str(reader.received), reader.received, sha1
        )
        logging.info("Common name is %r", reader.common_prefix)

        if not reader.common_prefix:
            raise common.errors.TaskFailure("No common prefix of uploaded files detected.")

        summary = iter(sh).next()
        assert isinstance(summary, sh.Proto.StreamSummary)
        if reader.received != summary.size:
            raise common.errors.TaskFailure(
                "Received tarball stream of {} bytes, while it should be {} bytes.".format(
                    reader.received, summary.size
                )
            )
        if self.Context.upload["version"] > 1 and summary.sha1 and sha1 != summary.sha1:
            # Report what was actually received first, then what the summary expected.
            raise common.errors.TaskFailure(
                "Received tarball stream SHA1 '{}', while it should be '{}'.".format(
                    sha1, summary.sha1
                )
            )

        self.Parameters.resource.path = reader.common_prefix

    def on_release(self, additional_parameters):
        """
        Release to Nanny. Then release to Ya.Deploy, but only if a YP OAuth token can be obtained;
        otherwise log a warning and skip the Ya.Deploy release.
        """
        nanny.ReleaseToNannyTask2.on_release(self, additional_parameters)
        try:
            self.get_yp_oauth_token()
        except Exception:
            logging.warning("Can't get yp oauth token", exc_info=True)
            self.set_info("Can't get yp oauth token, skipping deploy release")
        else:
            release_integration.ReleaseToYaDeployTask2.on_release(self, additional_parameters)
