from __future__ import absolute_import

import io
import os
import sys
import uuid
import errno
import gevent
import select
import signal
import struct
import inspect
import logging
import itertools as it
import threading as th
import gevent.lock
import gevent.event

import six
from six.moves import cPickle

# noinspection PyUnresolvedReferences,PyPackageRequirements
from kernel.util import console

from . import user

# Separator string. It is not referenced anywhere in this module; presumably callers use it to
# join the parts of a process title (e.g. the ``title`` argument of ``Subprocess``) - TODO confirm.
PROCESS_TITLE_DELIMITER = " :: "


def reinit_gevent():
    """
    Reset signal handlers and rebuild the gevent hub.

    Called by ``Subprocess.__enter__`` in the forked child when gevent is in use.

    Every signal exposed by the ``signal`` module, except ``SIGSTOP``, ``SIGKILL`` and ``SIGPIPE``,
    gets its handler reset to ``signal.SIG_DFL``. After that gevent is re-initialised, the current
    hub is destroyed together with its event loop and a new hub is requested with
    ``get_hub(default=True)``; the function asserts that the new hub runs on the default loop.
    """
    # reset signal handlers
    for s in (
        getattr(signal, s) for s in
        {s for s in dir(signal) if s.startswith("SIG")} -
        # exclude constants that are not signals such as SIG_DFL and SIG_BLOCK.
        {s for s in dir(signal) if s.startswith("SIG_")} -
        # leave handlers for SIG(STOP/KILL/PIPE) untouched.
        {"SIGSTOP", "SIGKILL", "SIGPIPE"}
    ):
        signal.signal(s, signal.SIG_DFL)

    gevent.reinit()
    hub = gevent.get_hub()
    # Drop the hub's thread pool before the hub is destroyed.
    # NOTE(review): presumably needed because the pool's threads do not exist in a forked child - confirm.
    del hub.threadpool
    hub._threadpool = None
    # FIXME: workaround to avoid infinite hang, fix after https://github.com/gevent/gevent/issues/1669
    # ``hub.throw`` is replaced with a no-op for the duration of ``destroy()`` and restored afterwards.
    orig_throw, hub.throw = hub.throw, lambda *_: None
    try:
        hub.destroy(destroy_loop=True)
    finally:
        hub.throw = orig_throw
    # The previous hub is destroyed, so this call is expected to produce a fresh one.
    h = gevent.get_hub(default=True)
    assert h.loop.default, "Could not create libev default event loop."


def reset_logging_locks():
    """
    Replace the ``logging`` module lock and the locks of all reachable handlers with fresh ones.

    Called by ``Subprocess.__enter__`` in a freshly forked child: a lock which some other thread of
    the parent held at the moment of ``fork()`` would otherwise stay locked forever in the child.

    Handlers are collected from every registered logger and all of its ancestors. The root logger
    is visited explicitly, so its handlers are covered even when no named logger has been
    registered yet. Handlers whose ``lock`` is ``None`` are left untouched.
    """
    # noinspection PyUnresolvedReferences,PyProtectedMember
    logging._lock = type(logging._lock)()

    # The root logger is not listed in ``loggerDict``; start from it explicitly.
    loggers = [logging.getLogger()]
    # noinspection PyUnresolvedReferences
    loggers.extend(logging.Logger.manager.loggerDict.values())

    handlers = set()
    for logger in loggers:
        # Placeholders stand for intermediate names without a real logger: nothing to collect.
        if isinstance(logger, logging.PlaceHolder):
            continue
        while logger is not None:
            handlers.update(
                handler for handler in logger.handlers
                if getattr(handler, "lock", None) is not None
            )
            logger = logger.parent

    # Re-create every lock with its original type (plain or gevent-patched).
    for handler in handlers:
        handler.lock = type(handler.lock)()


class SubprocessAborted(Exception):
    """Raised in the parent by ``Subprocess.__exit__`` when the child delivered no result data."""


class Subprocess(object):
    """
    Context manager which runs the body of its ``with`` block in a forked child process.

    ``__enter__`` forks. The child executes the block body and, in ``__exit__``, pickles the pair
    ``(self.result, exception)`` into a pipe and terminates with ``os._exit``. The parent skips the
    body (a trace function raising ``SkipBodyException`` is installed on the caller's frame), reads
    the pair back in ``__exit__``, stores its first item in ``result`` and re-raises the second one
    if it is set. ``SubprocessAborted`` is raised when the child delivered nothing at all.

    .. note::

        Any variable modifications inside the block will be lost. Assign a picklable value to the
        ``result`` attribute of the context object to pass it back to the parent.

    .. warning::

        The result is decoded with ``pickle``. This is acceptable only because the data is produced
        by a child forked from this very process; never feed the pipe from an untrusted source.
    """

    # Amount of child's stderr (64 KiB) read back for logging. The limit is checked after each
    # read chunk, so somewhat more data may actually be returned.
    MAX_STDERR_SIZE = 64 << 10

    # Sentinel instance raised by the trace function in the parent to skip the ``with`` body.
    SkipBodyException = Exception()
    # Switched to ``True`` inside a forked child: nested ``Subprocess`` blocks then do not fork.
    recursive = False
    # Inside a forked child: the ``Subprocess`` instance the child was forked for.
    current = None
    # Serialises pipe creation and ``fork()`` between threads of the parent.
    lock = th.RLock()

    def __init__(
        self, title=None, check_status=False, logger=None, using_gevent=False, privileged=False, watchdog=False,
        on_fork=None, silent_exceptions=None
    ):
        """
        :param title: process title to set in the child (via ``console.setProcTitle``), if any
        :param check_status: redirect child's stderr into a pipe and, on a non-zero wait status,
            report the captured output through ``logger``
        :param logger: logger used by both the parent and the child; ``None`` disables logging
        :param using_gevent: use gevent-aware ``fork``/``read`` and re-initialise gevent in the child
        :param privileged: enter ``user.User.Privileges()`` in the child
        :param watchdog: timeout passed to ``Event.wait()``; when it expires before the watchdog is
            disarmed the child gets ``SIGKILL``. A false value disables the watchdog
        :param on_fork: optional callable invoked in the child right after the fork
        :param silent_exceptions: exception class (or tuple of classes) which the child does not log
        """
        self.pid = None
        self.__pipe = None
        self.__title = title
        self.__check_status = check_status
        self.__logger = logger
        self.__using_gevent = using_gevent
        self.__privileged = privileged
        self.__watchdog = watchdog
        self.__watchdog_event = gevent.event.Event() if using_gevent else th.Event()
        self.__on_fork = on_fork
        self.__silent_exceptions = silent_exceptions
        self.__stderr = None
        self.__signal_pipe = None
        self.result = None

    def __trace(self, frame, event, arg):
        """Trace function put on the caller's frame in the parent: aborts the ``with`` body."""
        raise self.SkipBodyException

    def __log_error(self):
        """Log the exception being handled, if a logger was supplied."""
        if self.__logger:
            self.__logger.exception("Error in subprocess:")

    def __read_fd(self, fd, max_size=None):
        """
        Read from ``fd`` until EOF, or until at least ``max_size`` bytes are collected.

        :param fd: file descriptor to read from
        :param max_size: optional soft limit; reading stops after the chunk which reaches it
        :return: all the bytes read
        """
        if self.__using_gevent:
            gevent.os.make_nonblocking(fd)
            read_fn = gevent.os.nb_read
        else:
            read_fn = os.read
        buf = io.BytesIO()
        size = 0
        while True:
            # Retry reads interrupted by a signal.
            while True:
                try:
                    data = read_fn(fd, io.DEFAULT_BUFFER_SIZE)
                    break
                except OSError as ex:
                    if ex.errno != errno.EINTR:
                        raise
            if not data:
                break
            size += len(data)
            buf.write(data)
            if max_size and size >= max_size:
                break
        return buf.getvalue()

    def __watchdog_fn(self):
        """Kill the child unless the watchdog event gets set within the configured timeout."""
        if not self.__watchdog_event.wait(self.__watchdog):
            os.kill(self.pid, signal.SIGKILL)

    def __watchdog_stop_fn(self):
        """Disarm the watchdog once the child closes its end of the signal pipe."""
        # EOF arrives when the child calls ``stop_watchdog()`` or terminates.
        self.__read_fd(self.__signal_pipe)
        self.__watchdog_event.set()
        os.close(self.__signal_pipe)

    def __setup_child(self, pr, pw, er, ew, sr, sw):
        """Prepare the freshly forked child: reset inherited state and keep the write ends only."""
        reset_logging_locks()
        if self.__on_fork:
            self.__on_fork()
        if self.__using_gevent:
            reinit_gevent()
        os.close(pr)
        os.close(er)
        os.close(sr)
        self.__pipe = pw
        self.__stderr = ew
        self.__signal_pipe = sw
        type(self).recursive = True
        type(self).current = self
        if self.__check_status:
            # Make everything written to stderr available to the parent.
            os.dup2(ew, sys.stderr.fileno())
        if self.__title:
            console.setProcTitle(self.__title)
        if self.__privileged:
            user.User.Privileges().__enter__()

    def __enter__(self):
        """
        Fork. The parent arranges for the ``with`` body to be skipped, the child prepares to run it.

        Inside a child forked by another ``Subprocess`` (``recursive`` is set) nothing is forked and
        the body simply runs in the current process. Always returns ``None``.
        """
        if type(self).recursive:
            return

        with self.lock:
            pr, pw = os.pipe()
            er, ew = os.pipe()
            sr, sw = os.pipe()
            self.pid = gevent.os.fork() if self.__using_gevent else os.fork()

        if self.pid:
            os.close(pw)
            os.close(ew)
            os.close(sw)
            self.__pipe = pr
            self.__stderr = er
            self.__signal_pipe = sr
            if self.__watchdog:
                th.Thread(target=self.__watchdog_fn).start()
                th.Thread(target=self.__watchdog_stop_fn).start()
            # Parent process - skip body execution.
            # Tracing is switched on and a local trace function is attached to the frame holding
            # the ``with`` statement; it raises ``SkipBodyException`` as soon as the body starts.
            # The frame lookup has to stay in this very method: ``f_back`` must be the caller.
            sys.settrace(lambda *args, **keys: None)
            frame = inspect.currentframe().f_back
            frame.f_trace = self.__trace
        else:
            # noinspection PyBroadException
            try:
                self.__setup_child(pr, pw, er, ew, sr, sw)
            except BaseException:
                # ``__exit__`` is not called when ``__enter__`` fails, so the child would otherwise
                # continue executing the parent's code. Terminate it here; the parent then reports
                # ``SubprocessAborted``.
                try:
                    self.__log_error()
                finally:
                    # noinspection PyProtectedMember,PyUnresolvedReferences
                    os._exit(1)

    def __finish_parent(self, exc_val):
        """Parent side of ``__exit__``: collect the child's result, reap it and release the pipes."""
        try:
            result_data = self.__read_fd(self.__pipe)
            _, exit_status = os.waitpid(self.pid, 0)
            # The child is reaped: disarm the watchdog so that it never signals a recycled pid.
            self.__watchdog_event.set()
            if exit_status and self.__check_status:
                error = self.__read_fd(self.__stderr, self.MAX_STDERR_SIZE)
                if self.__logger:
                    self.__logger.error(
                        "Subprocess '%s' exited with status %s, stderr: %s", self.__title, exit_status, error
                    )
        finally:
            os.close(self.__pipe)
            os.close(self.__stderr)
            # With a watchdog the read end of the signal pipe is closed by ``__watchdog_stop_fn``;
            # without one nobody else would close it.
            if not self.__watchdog and self.__signal_pipe is not None:
                os.close(self.__signal_pipe)
                self.__signal_pipe = None
        if not result_data:
            raise SubprocessAborted()
        self.result, exception = cPickle.loads(result_data)
        if exception:
            raise exception
        # Swallow only the sentinel used to skip the body.
        return exc_val is self.SkipBodyException

    def __finish_child(self, exc_val):
        """Child side of ``__exit__``: send ``(result, exception)`` to the parent and terminate."""
        try:
            silenced = self.__silent_exceptions and isinstance(exc_val, self.__silent_exceptions)
            if exc_val and not silenced:
                self.__log_error()
            # noinspection PyBroadException
            try:
                data = cPickle.dumps((self.result, exc_val))
            except Exception as ex:
                # The result or the exception cannot be pickled: report that instead.
                data = cPickle.dumps((None, ex))
            try:
                # ``os.write`` may accept only a part of the buffer.
                offset = 0
                while offset < len(data):
                    offset += os.write(self.__pipe, data[offset:])
            except OSError:
                self.__log_error()
            for fd in (self.__pipe, self.__stderr):
                try:
                    os.close(fd)
                except OSError:
                    self.__log_error()
        finally:
            # Whatever happened above, the child must never return into the parent's code.
            # noinspection PyProtectedMember,PyUnresolvedReferences
            os._exit(1 if exc_val else 0)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Parent: wait for the child, publish its result in ``result`` and re-raise its exception.
        Child: send the result to the parent and terminate with ``os._exit`` (never returns).

        :return: ``True`` in the parent when the block was left via ``SkipBodyException``
        :raises SubprocessAborted: in the parent, if the child sent no result data
        """
        if self.pid:
            return self.__finish_parent(exc_val)
        if self.pid is not None:
            self.__finish_child(exc_val)

    def stop_watchdog(self):
        """Close this process' end of the signal pipe; in the child this disarms parent's watchdog."""
        if self.__signal_pipe is not None:
            os.close(self.__signal_pipe)
            self.__signal_pipe = None


class PipeQueue(object):
    """
    Two-way message channel between a "main" and a "worker" side built on a pair of OS pipes.

    One pipe carries messages from main to worker (``main_put`` / ``worker_get``), the other one
    carries them back (``worker_put`` / ``main_get``). Every message is a pickled object prefixed
    with its length packed as a 4-byte little-endian unsigned integer.

    The main side keeps a counter of messages sent and not yet answered: ``main_put`` refuses to
    send when ``size`` messages are outstanding and every successful ``main_get`` decrements it.

    .. warning::

        Messages are decoded with ``pickle``; both ends of the pipes must be trusted.
    """

    SIZE_ST = struct.Struct("<I")

    # Class-level defaults keep ``__del__`` safe for an instance whose ``__init__`` did not finish.
    __main_r = __worker_w = __worker_r = __main_w = None

    def __init__(self, size=1, using_gevent=False):
        """
        :param size: maximum number of unanswered messages the main side may have in flight
        :param using_gevent: switch the pipes to non-blocking mode and use gevent's read/write
        """
        self.__max_queue_size = size
        self.__using_gevent = using_gevent
        self.__queue_size = 0
        self.__main_r, self.__worker_w = os.pipe()
        self.__worker_r, self.__main_w = os.pipe()
        if using_gevent:
            for fd in (self.__main_r, self.__worker_w, self.__worker_r, self.__main_w):
                gevent.os.make_nonblocking(fd)

    def __del__(self):
        """Close all pipe ends which were opened, ignoring the errors."""
        for fd in (self.__main_r, self.__worker_w, self.__worker_r, self.__main_w):
            if fd is None:
                continue
            try:
                os.close(fd)
            except (OSError, IOError):
                pass

    def _get(self, fd, timeout=None):
        """
        Wait until ``fd`` is readable and read one raw (still pickled) message from it.

        :param fd: pipe read end
        :param timeout: ``select()`` timeout; ``None`` waits indefinitely
        :return: ``(False, None)`` on timeout, ``(True, <message bytes>)`` otherwise
        """
        ready = None
        while True:
            try:
                ready = select.select([fd], [], [], timeout)[0]
                break
            except select.error as ex:
                if ex.args[0] == errno.EINTR:
                    continue
                raise
        if not ready:
            return False, None
        return True, self.__read(fd)

    def main_put(self, data):
        """
        Send ``data`` to the worker side.

        :return: ``False`` (nothing is sent) if the limit of unanswered messages is reached,
            ``True`` otherwise
        """
        if self.__queue_size >= self.__max_queue_size:
            return False
        self.__queue_size += 1
        try:
            self.__write(self.__main_w, self.__pack(data))
        except BaseException:
            # The message did not go out: give the slot back.
            self.__queue_size -= 1
            raise
        return True

    def main_get(self, timeout=None):
        """
        Receive a message from the worker side.

        :return: ``(False, None)`` on timeout, ``(True, <object>)`` otherwise
        """
        ok, data = self._get(self.__main_r, timeout=timeout)
        if ok:
            self.__queue_size -= 1
        return ok, (data if data is None else self.__unpack(data))

    def worker_put(self, data):
        """Send ``data`` to the main side."""
        self.__write(self.__worker_w, self.__pack(data))

    def worker_get(self, timeout=None):
        """
        Receive a message from the main side.

        :return: ``(False, None)`` on timeout, ``(True, <object>)`` otherwise
        """
        ok, data = self._get(self.__worker_r, timeout=timeout)
        return ok, (data if data is None else self.__unpack(data))

    def __pack(self, value):
        """Pickle ``value`` with the highest protocol and prepend the length header."""
        data = cPickle.dumps(value, -1)
        return b"".join((self.SIZE_ST.pack(len(data)), data))

    @staticmethod
    def __unpack(data):
        """Unpickle a message body."""
        return cPickle.loads(data)

    def __read_exact(self, fd, size):
        """
        Read exactly ``size`` bytes from ``fd``, retrying short and interrupted reads.

        :return: the bytes read, or ``None`` if EOF was hit before ``size`` bytes arrived
        """
        read_fn = gevent.os.nb_read if self.__using_gevent else os.read
        chunks = []
        while size:
            try:
                chunk = read_fn(fd, size)
            except (IOError, OSError) as ex:
                if ex.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                return None
            size -= len(chunk)
            chunks.append(chunk)
        return b"".join(chunks)

    def __read(self, fd):
        """
        Read one length-prefixed message body from ``fd``.

        :return: the message bytes; empty bytes if the pipe was closed before a whole message came
        """
        # The header may arrive in pieces as well, so it is read with the same exact-read loop.
        header = self.__read_exact(fd, self.SIZE_ST.size)
        if header is None:
            return b""
        body = self.__read_exact(fd, self.SIZE_ST.unpack(header)[0])
        return b"" if body is None else body

    def __write(self, fd, data):
        """Write the whole ``data`` to ``fd``, continuing after partial and interrupted writes."""
        size = len(data)
        offset = 0
        write_fn = gevent.os.nb_write if self.__using_gevent else os.write
        while offset < size:
            try:
                offset += write_fn(fd, data[offset:])
            except (IOError, OSError) as ex:
                if ex.errno != errno.EINTR:
                    raise


class PipeRPCException(Exception):
    """Base class of the errors raised by the pipe-based RPC helpers of this module."""


class PipeRPCBusyError(PipeRPCException):
    """Raised by ``PipeRPC.__call__`` when the request cannot be queued because the pipe is busy."""


class PipeRPCUnknownMethod(PipeRPCException):
    """Raised by ``PipeRPCServer`` when the requested method name is not an attribute of the server."""


class PipeRPC(PipeQueue):
    """
    Client side of a synchronous RPC over ``PipeQueue``.

    Calling the object sends ``(method, args, kws)`` to the worker side and blocks until the
    ``(result, exception)`` reply arrives. At most one call is in flight (queue size is 1).
    """

    def __init__(self, using_gevent=False):
        """
        :param using_gevent: use a gevent lock and gevent-aware pipe I/O instead of the threading ones
        """
        super(PipeRPC, self).__init__(size=1, using_gevent=using_gevent)
        if using_gevent:
            self.__lock = gevent.lock.RLock()
        else:
            self.__lock = th.RLock()

    def __call__(self, method, *args, **kws):
        """
        Invoke ``method`` on the serving side and return its result.

        :raises PipeRPCBusyError: if the request could not be queued
        :raises Exception: whatever exception the remote method reported
        """
        request = (method, args, kws)
        with self.__lock:
            if not self.main_put(request):
                raise PipeRPCBusyError("Cannot make call: pipe is busy")
            # No timeout: wait for the reply as long as it takes.
            received, reply = self.main_get()
            assert received
            value, error = reply
            if error is None:
                return value
            raise error


class PipeRPCServer(object):
    """
    Serving side of the pipe RPC: methods of a subclass are exposed to ``PipeRPC`` callers.

    Calling the instance starts the serving loop, which runs until ``__stop__`` is invoked
    (normally through an RPC request).
    """

    def __init__(self, rpc):
        """
        :param rpc: channel object providing ``worker_get(timeout=...)`` and ``worker_put(data)``
        """
        self.__rpc = rpc
        self.__running = False
        super(PipeRPCServer, self).__init__()

    def __call__(self, timeout=None):
        """
        Serve requests until stopped.

        Each request is a ``(method_name, args, kws)`` triple; the reply sent back is a
        ``(result, exception)`` pair where exactly the failed calls carry an exception.

        :param timeout: how long a single wait for a request may last before the loop re-checks
            the stop flag; ``None`` waits indefinitely
        """
        self.__running = True
        while self.__running:
            received, request = self.__rpc.worker_get(timeout=timeout)
            if not received:
                # Timed out: just look at the stop flag again.
                continue
            name, call_args, call_kws = request
            handler = getattr(self, name, None)
            # noinspection PyBroadException
            try:
                if handler is None:
                    raise PipeRPCUnknownMethod("method '{}' not found".format(name))
                reply = (handler(*call_args, **call_kws), None)
            except BaseException as exc:
                # Every failure, including the unknown method, is shipped to the caller.
                reply = (None, exc)
            self.__rpc.worker_put(reply)

    def __stop__(self):
        """Ask the serving loop to finish after the current request is answered."""
        self.__running = False


class SubprocessPool(object):
    """
    Runs jobs in forked subprocesses (see ``Subprocess``), at most ``size`` at a time.

    Every job is supervised by its own thread (or greenlet when ``using_gevent`` is set) which
    forks the subprocess, waits for it and stores the ``(result, exception)`` pair. A job occupies
    its slot until the result is fetched with ``result()`` / ``raw_result()``.
    """

    def __init__(self, size, using_gevent=False, logger=None, silent_exceptions=None):
        """
        :param size: maximum number of jobs which are started and not yet fetched
        :param using_gevent: supervise jobs with greenlets and fork in the gevent-aware way
        :param logger: logger to use; the ``logging`` module itself is used when omitted
        :param silent_exceptions: passed through to ``Subprocess``
        """
        self._size = size
        self._using_gevent = using_gevent
        self._logger = logger or logging
        self._silent_exceptions = silent_exceptions
        self._workers = {}       # job id -> supervising thread / greenlet
        self._subprocesses = {}  # job id -> Subprocess context
        self._results = {}       # job id -> (result, exception) of the finished jobs
        self._pending = set()    # ids of the jobs which are still running

    def _worker(self, job_id, title, watchdog, func, args, kws):
        """Supervisor body: run ``func(*args, **kws)`` in a subprocess and record the outcome."""
        sp_ctx = Subprocess(
            title=title, logger=self._logger, using_gevent=self._using_gevent,
            silent_exceptions=self._silent_exceptions, watchdog=watchdog,
        )
        self._subprocesses[job_id] = sp_ctx
        result = None
        exception = None
        self._pending.add(job_id)
        self._logger.debug("Starting subprocess job %s", job_id)
        # noinspection PyBroadException
        try:
            # The body runs in the child only; the parent continues after the block with the
            # value which the child left in ``sp_ctx.result``.
            with sp_ctx:
                sp_ctx.result = func(*args, **kws)
            result = sp_ctx.result
        except BaseException as exc:
            exception = exc
        self._logger.debug("Subprocess job %s finished", job_id)
        self._pending.remove(job_id)
        self._results[job_id] = (result, exception)

    def _spawn(self, func, *args):
        """Start ``func(*args)`` in a greenlet or in a thread and return the handle."""
        if self._using_gevent:
            return gevent.spawn(func, *args)
        else:
            t = th.Thread(target=func, args=args)
            t.start()
            return t

    def spawn(self, func, args, kws, title=None, watchdog=None):
        """
        Start a new job.

        :param func: callable to run
        :param args: positional arguments for ``func``
        :param kws: keyword arguments for ``func``
        :param title: process title for the subprocess
        :param watchdog: watchdog timeout for the subprocess
        :return: id of the new job or ``None`` if all slots are occupied
        """
        if len(self._workers) >= self._size:
            return None
        job_id = uuid.uuid4().hex
        self._workers[job_id] = self._spawn(self._worker, job_id, title, watchdog, func, args, kws)
        return job_id

    def pending_jobs(self):
        """Return ids of the jobs which are still running."""
        return list(self._pending)

    def ready_jobs(self):
        """Return ids of the finished jobs whose results were not fetched yet."""
        return list(self._results.keys())

    def raw_result(self, job_id):
        """
        Fetch the outcome of a finished job and free its slot.

        :return: ``(result, exception)`` pair
        :raises KeyError: if the job is unknown or not finished yet; the pool state is left intact
            in this case, so the result can be requested again later
        """
        # Take the result first: a premature call must not drop the bookkeeping of a running job.
        outcome = self._results.pop(job_id)
        self._workers.pop(job_id, None)
        self._subprocesses.pop(job_id, None)
        return outcome

    def result(self, job_id):
        """
        Fetch the result of a finished job, re-raising the exception the job ended with.

        :raises KeyError: if the job is unknown or not finished yet
        """
        result, exception = self.raw_result(job_id)
        if exception is not None:
            raise exception
        return result

    def kill(self, job_id, sig=signal.SIGTERM):
        """
        Send ``sig`` to the subprocess of a running job.

        Nothing is done for unknown jobs, for jobs which were not forked yet and for jobs which
        have already finished: their pid is reaped and may belong to an unrelated process by now.
        """
        sp_ctx = self._subprocesses.get(job_id)
        if sp_ctx is None or not sp_ctx.pid or job_id not in self._pending:
            return
        try:
            os.kill(sp_ctx.pid, sig)
        except OSError as ex:
            # The process may exit right before the signal is sent.
            if ex.errno != errno.ESRCH:
                raise


class ThreadPool(SubprocessPool):
    """
    Variant of ``SubprocessPool`` which runs every job directly in its supervising
    thread / greenlet instead of forking a subprocess.
    """

    def _worker(self, job_id, title, watchdog, func, args, kws):
        """Run ``func(*args, **kws)`` in place and record the outcome; ``title`` and ``watchdog`` are unused."""
        self._pending.add(job_id)
        # noinspection PyBroadException
        try:
            outcome = (func(*args, **kws), None)
        except BaseException as exc:
            outcome = (None, exc)
        self._pending.remove(job_id)
        self._results[job_id] = outcome

    def kill(self, job_id, sig=signal.SIGTERM):
        """Do nothing: there is no process to signal."""
        pass


class WorkersPool(object):
    """
    Pool of ``size`` worker subprocesses, each executing ``func`` in its own pool of threads.

    Every worker process runs an ``RPCServer`` which is controlled from this process through
    a dedicated ``PipeRPC`` channel. Worker processes are spawned only once per pool object:
    a stopped pool cannot be started again.
    """

    class RPCServer(PipeRPCServer):
        """RPC endpoint living in a worker process; dispatches jobs to a local ``ThreadPool``."""

        def __init__(self, rpc, func, pool_size, using_gevent=False, logger=None):
            """
            :param rpc: RPC channel to serve
            :param func: the function executed by every job
            :param pool_size: number of slots of the local thread pool
            :param using_gevent: passed through to the thread pool
            :param logger: passed through to the thread pool
            """
            self.__func = func
            self.__thread_pool = ThreadPool(pool_size, using_gevent=using_gevent, logger=logger)
            super(WorkersPool.RPCServer, self).__init__(rpc)

        @staticmethod
        def ping():
            """Liveness probe: answers ``True``."""
            return True

        def spawn(self, args, kws):
            """Start ``func(*args, **kws)``; returns the job id or ``None`` if the pool is full."""
            return self.__thread_pool.spawn(self.__func, args, kws)

        def ready_jobs(self):
            """Return ids of the finished jobs of this worker."""
            return self.__thread_pool.ready_jobs()

        def result(self, job_id):
            """Return the result of a finished job or raise the exception it ended with."""
            return self.__thread_pool.result(job_id)

    def __init__(
        self, func, size, threads_pool_size, using_gevent=False, title=None, logger=None, silent_exceptions=None
    ):
        """
        :param func: the function executed by every job
        :param size: number of worker subprocesses
        :param threads_pool_size: number of concurrent jobs per worker subprocess
        :param using_gevent: use gevent-aware primitives everywhere
        :param title: prefix of the worker process titles (``"<title> #<index>"``)
        :param logger: logger passed to the underlying pools
        :param silent_exceptions: passed through to ``SubprocessPool``
        """
        self.__func = func
        self.__size = size
        self.__threads_pool_size = threads_pool_size
        self.__using_gevent = using_gevent
        self.__title = title
        self.__logger = logger
        self.__pool = SubprocessPool(
            size, using_gevent=using_gevent, logger=logger, silent_exceptions=silent_exceptions
        )
        self.__running = False
        self.__rpc_clients = [PipeRPC(using_gevent) for _ in range(size)]
        self.__process_jobs = []
        self.__jobs = {}  # job id -> RPC client of the worker which runs the job

    def start(self):
        """
        Spawn the worker subprocesses and wait until each of them answers a ping.

        Does nothing if the workers have already been spawned, including after ``stop()``.

        :raises PipeRPCException: if a worker answers the ping with a false value
        """
        if self.__running or self.__process_jobs:
            return
        for i, rpc in enumerate(self.__rpc_clients):
            job_id = self.__pool.spawn(
                lambda rpc_: self.RPCServer(
                    rpc_, self.__func, self.__threads_pool_size, using_gevent=self.__using_gevent, logger=self.__logger
                )(),
                (rpc,), {}, title="{} #{}".format(self.__title, i)
            )
            self.__process_jobs.append(job_id)
            # The ping synchronises with the worker start-up, so it has to be performed
            # unconditionally - an ``assert`` would be stripped under ``python -O``.
            if not rpc("ping"):
                raise PipeRPCException("Worker #{} did not respond to ping".format(i))
        self.__running = True

    def stop(self):
        """Ask every worker to leave its serving loop. Calling it again is a no-op."""
        if not self.__running:
            return
        for rpc in self.__rpc_clients:
            rpc("__stop__")
        # Stopped servers do not answer anymore: a repeated request would block forever.
        self.__running = False

    def kill(self, sig=signal.SIGTERM):
        """Send ``sig`` to every worker subprocess."""
        for job_id in self.__process_jobs:
            self.__pool.kill(job_id, sig)

    def spawn(self, *args, **kws):
        """
        Start ``func(*args, **kws)`` on the first worker which has a free slot.

        :return: id of the job or ``None`` if every worker is fully loaded
        """
        for rpc in self.__rpc_clients:
            job_id = rpc("spawn", args, kws)
            if job_id is not None:
                self.__jobs[job_id] = rpc
                return job_id
        return None

    def ready_jobs(self):
        """Return ids of the finished jobs of all the workers."""
        return list(it.chain.from_iterable((rpc("ready_jobs") for rpc in self.__rpc_clients)))

    def result(self, job_id):
        """
        Fetch the result of a finished job from the worker which ran it.

        :raises KeyError: if the job id is unknown to this pool
        :raises Exception: whatever the worker reports for the job
        """
        value = self.__jobs[job_id]("result", job_id)
        # The result is delivered: forget the job, the mapping would grow forever otherwise.
        # On failure the entry is kept, because the request may legitimately be repeated.
        self.__jobs.pop(job_id, None)
        return value


# Names exported by ``from <module> import *``.
# ``ThreadPool`` is defined in this module but is not part of this list.
__all__ = [
    "PROCESS_TITLE_DELIMITER",
    "reinit_gevent",
    "reset_logging_locks",
    "Subprocess",
    "PipeQueue",
    "PipeRPC",
    "PipeRPCServer",
    "PipeRPCException",
    "PipeRPCBusyError",
    "PipeRPCUnknownMethod",
    "SubprocessPool",
    "SubprocessAborted",
    "WorkersPool",
]
