from __future__ import absolute_import

import os
import sys
import abc
import time
import stat
import errno
import Queue
import signal
import struct
import socket
import select
import getpass
import cPickle
import logging
import platform
import threading as th
import collections

from . import os as common_os
from . import utils
from . import config
from .types import misc as ctm


class InterProcessQueue(object):
    """
    Queue connecting the main process with its forked worker processes over UNIX socket pairs.

    Every message on a socket is a 4-byte little-endian unsigned payload length followed by a pickle
    payload (see `__pack()`).

    In the main process:
      * `put()` stores packed items in a local priority queue; a per-worker thread started by `fork()`
        sends them to the worker processes;
      * `start()` runs a thread collecting payloads sent back by workers into a local output queue,
        from which `get()` reads.
    In a forked worker process, `put()` and `get()` talk to the main process directly over the socket.
    """
    SIZE_ST = struct.Struct("<I")

    Worker = collections.namedtuple("Worker", "socket thread")

    def __init__(self, logger=None):
        self.__logger = logger or logging
        # Packed items waiting to be forwarded to workers (used by the main process).
        # NOTE(review): items are ordered by their packed bytes rather than FIFO, and under Python 2 the
        # `None` stop markers sort before any packed item - confirm this ordering is intended.
        self.__input_queue = Queue.PriorityQueue()
        # Pickle payloads received from workers (or values given to `put_to_output()`).
        self.__output_queue = Queue.Queue()
        # Set in a forked worker process only: `Worker` holding its end of the socket pair.
        self.__worker = None
        # Main process only: worker PID -> `Worker`, or `None` once the worker's socket was found closed.
        self.__workers = {}
        self.__result_thread = None

    def __worker_thread(self, pid, sock):
        """
        Starts a daemon thread forwarding items from the input queue to worker process `pid` over `sock`.

        The thread exits after forwarding a `None` stop marker (see `stop()`). If sending fails, the item
        being sent is put back into the input queue (so another worker may take it) and the thread exits.

        :return: the started thread
        """
        def thread():
            self.__logger.info("Starting thread for worker process #%s", pid)
            data = None
            try:
                while True:
                    data = self.__input_queue.get()
                    if data is None:
                        # Stop markers are enqueued unpacked by `stop()`, so pack it here.
                        self.__send(sock, self.__pack(data))
                        break
                    self.__send(sock, data)
            except Exception as ex:
                if data is not None:
                    self.__input_queue.put(data)
                self.__logger.error("Error occurred in worker thread for process #%s: %s", pid, ex)
            finally:
                self.__logger.info("Terminating thread for worker process #%s", pid)

        t = th.Thread(target=thread)
        t.daemon = True
        t.start()
        return t

    def __pack(self, value):
        """
        Serializes `value` into a wire message: 4-byte little-endian payload length + pickle payload.
        """
        data = cPickle.dumps(value, -1)
        return b"".join((self.SIZE_ST.pack(len(data)), data))

    @staticmethod
    def __unpack(data):
        """
        Deserializes a pickle payload (without the length header).
        """
        # NOTE(review): unpickling is only safe for trusted data; socket payloads come from processes
        # forked by `fork()`, but `put_to_output()` values come from the caller.
        return cPickle.loads(data)

    @staticmethod
    def __send(sock, data):
        """
        Sends the whole `data` over `sock`, retrying partial sends and EINTR.

        :return: number of bytes sent (always `len(data)`)
        """
        sent = 0
        while sent != len(data):
            try:
                sent += sock.send(data[sent:])
            except socket.error as ex:
                if ex.errno == errno.EINTR:
                    continue
                raise
        return sent

    @staticmethod
    def __recv_exactly(sock, size):
        """
        Reads exactly `size` bytes from `sock`, retrying short reads and EINTR.

        :return: the bytes read, or an empty string if the peer closed the connection first
        """
        chunks = []
        while size:
            try:
                chunk = sock.recv(size)
            except socket.error as ex:
                if ex.errno == errno.EINTR:
                    continue
                raise
            if not chunk:
                return chunk
            size -= len(chunk)
            chunks.append(chunk)
        return b"".join(chunks)

    def __recv(self, sock):
        """
        Receives one message from `sock`.

        :return: the pickle payload (length header stripped), or an empty string if the peer closed the
                 connection
        """
        # `recv()` may return fewer bytes than requested, so the header is read exactly as well -
        # a partial header would otherwise fail to unpack.
        header = self.__recv_exactly(sock, self.SIZE_ST.size)
        if not header:
            return header
        return self.__recv_exactly(sock, self.SIZE_ST.unpack(header)[0])

    def __worker_get(self, timeout=None):
        """
        Worker side of `get()`: waits up to `timeout` seconds for a message from the main process.

        :raises Queue.Empty: if nothing arrived in time
        """
        ready = None
        while True:
            try:
                ready = select.select([self.__worker.socket], [], [], timeout)[0]
                break
            except select.error as ex:
                if ex.args[0] == errno.EINTR:
                    continue
                raise
        if not ready:
            raise Queue.Empty
        return self.__recv(self.__worker.socket)

    def __worker_put(self, data):
        """
        Worker side of `put()`: sends an already packed message to the main process.
        """
        self.__send(self.__worker.socket, data)

    def put(self, value):
        """
        Enqueues `value`: for the workers when called in the main process, for the main process when
        called in a forked worker.
        """
        (self.__worker_put if self.__worker else self.__input_queue.put)(self.__pack(value))

    def get(self, timeout=None):
        """
        Dequeues the next value (from workers in the main process, from the main process in a worker).

        :param timeout: maximum wait in seconds, `None` to wait forever
        :return: the unpickled value, or a falsy raw value as is (e.g. an empty string on closed socket)
        :raises Queue.Empty: if nothing arrived within `timeout`
        """
        data = (self.__worker_get if self.__worker else self.__output_queue.get)(timeout=timeout)
        return data and self.__unpack(data)

    def put_to_output(self, value):
        """
        Puts `value` directly into the output queue; `get()` will unpickle it, so it is expected to be a
        pickle payload.
        """
        self.__output_queue.put(value)

    def qsize(self):
        """
        :return: approximate number of items waiting in the input queue
        """
        return self.__input_queue.qsize()

    def start(self):
        """
        Starts (in the main process) a daemon thread collecting payloads sent by forked workers.

        Received payloads are put into the output queue for `get()`. A worker whose socket yields no data
        (closed by the peer, EPIPE or ECONNRESET) is marked as removed; the thread exits once no worker is
        left, or on any other error.
        """
        def thread():
            pid = os.getpid()
            self.__logger.info("Starting result thread of main process #%s", pid)
            try:
                while True:
                    try:
                        rlist = {
                            worker.socket: pid
                            for pid, worker in self.__workers.iteritems()
                            if worker
                        }
                        if not rlist:
                            break
                        ready = select.select(rlist, [], [])[0]
                    except select.error as ex:
                        if ex.args[0] == errno.EINTR:
                            continue
                        raise
                    if ready:
                        for sock in ready:
                            try:
                                data = self.__recv(sock)
                            except socket.error as ex:
                                if ex.errno not in (errno.EPIPE, errno.ECONNRESET):
                                    raise
                                data = None
                            if not data:
                                worker_pid = rlist[sock]
                                self.__logger.info(
                                    "Removing worker #%s, probably socket closed by remote peer", worker_pid
                                )
                                self.__workers[worker_pid] = None
                                continue
                            self.__output_queue.put(data)
            except Exception as ex:
                self.__logger.error("Error occurred in result thread of main process #%s: %s", pid, ex)
            finally:
                self.__logger.info("Terminating result thread of main process #%s", pid)

        t = th.Thread(target=thread)
        t.daemon = True
        t.start()
        self.__result_thread = t

    def stop(self):
        """
        Enqueues one `None` stop marker per worker PID ever forked (removed ones included); each
        forwarding thread exits after sending one marker.
        """
        for _ in self.__workers:
            self.__input_queue.put(None)

    def fork(self):
        """
        Forks a worker process connected to this queue through a fresh socket pair.

        In the child, this object switches to worker mode (`put()`/`get()` use the socket). In the parent,
        a thread forwarding input items to the child is started.

        :return: `os.fork()` result - 0 in the child, the child's PID in the parent
        """
        assert self.__worker is None, "Can fork from main process only"
        main_socket, worker_socket = socket.socketpair()
        # Ask for the smallest possible socket buffers, presumably so that items are not piled up in
        # the buffer of a busy worker - confirm.
        main_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 0)
        worker_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 0)
        main_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 0)
        worker_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 0)
        pid = os.fork()
        if pid == 0:
            main_socket.close()
            # The child inherits the main-side sockets of previously forked siblings; close them so they
            # do not leak and do not keep those connections open on the child's behalf.
            for worker in self.__workers.values():
                if worker:
                    worker.socket.close()
            self.__worker = self.Worker(worker_socket, None)
            self.__workers = {}
        else:
            worker_socket.close()
            self.__workers[pid] = self.Worker(main_socket, self.__worker_thread(pid, main_socket))
        return pid

    def join(self):
        """
        Waits for every forked worker process to exit.

        Workers which are not (or no longer) children of this process are skipped; `waitpid` calls
        interrupted by a signal are retried.

        :return: `os.waitpid()` result of the first worker successfully waited for, or `None`
        """
        result = None
        for pid in list(self.__workers):
            while True:
                try:
                    status = os.waitpid(pid, 0)
                except OSError as ex:
                    if ex.errno == errno.EINTR:
                        continue
                    if ex.errno != errno.ECHILD:
                        raise
                    status = None
                break
            if result is None:
                result = status
        return result


class Master(object):
    """
    The class manages sub-process(es), which will be created. The instance will handle a queue, which will
    be processed by instance(s) of `Slave` class (i.e., this one is producer).
    Singleton. Does NOT support multi-threaded initialization.
    """
    __metaclass__ = utils.SingletonMeta

    def __init__(self, logger):
        self.logger = logger
        # Communication channel to the sub-processes (or to the master process, from a sub-process' side).
        self.queue = InterProcessQueue(logger=self.logger)
        # Sub-process instance(s); stays `None` until `init()` is called.
        self.processes = None
        super(Master, self).__init__()

    def __del__(self):
        # Best effort: a finalizer must never raise.
        # noinspection PyBroadException
        try:
            self.destroy()
        except:
            pass

    @abc.abstractmethod
    def _slaves_builder(self):
        """
        Generator. Should return an instance(es) of `Slave`-derived object.
        """
        yield  # Just to avoid code inspector warnings

    def init(self):
        """
        Initializes the worker process(es) pool from `_slaves_builder()`.
        Only the first successful call has an effect; subsequent calls do nothing.

        :return: self
        """
        if self.processes is not None:
            # Already initialized - avoid any future initializations.
            return self
        self.logger.debug("Initializing sub-process workers.")
        processes = list(self._slaves_builder())
        for obj in processes:
            assert isinstance(obj, Slave), "First argument should be a subclass of `process.Slave` class."
        self.processes = processes
        return self

    def destroy(self):
        """
        Last-chance method to destroy worker process(es). Just kill "em all!
        Does nothing if the pool was never initialized.
        """
        for process in filter(Slave.is_alive, self.processes or ()):
            try:
                os.kill(process.pid, signal.SIGKILL)
                self.logger.warning("Killed process worker #%d", process.pid)
            except OSError:
                # The process has exited meanwhile.
                pass

    def start(self):
        """
        Starts a pool of worker process(es), then the queue's result-collecting thread.

        :return: self
        """
        for p in self.processes:
            p.start()
            self.logger.info("Started worker sub-process with PID #%d.", p.pid)
        self.queue.start()
        return self

    def stop(self):
        """
        Signals worker process(es) to finish by enqueueing stop markers.

        :return: self
        """
        self.logger.info("Signaling %r sub-process workers.", [_.pid for _ in self.processes or ()])
        self.queue.stop()
        return self

    def process(self, data):
        """
        Sends data given to process by some worker.

        :param data: data to process
        """
        self.queue.put(data)

    def join(self, maxwait=None):
        """
        Waits for worker process(es) to finish. If `maxwait` provided, kills unfinished worker process(es).

        :param maxwait: Maximum wait time in seconds.
        """
        msg = "Waiting for sub-process workers to terminate"
        event = None
        waiter = None
        if maxwait:
            def wait(ev):
                if not ev.wait(maxwait):
                    self.logger.error("Sub-process workers were not exited in reasonable time. Killing them.")
                    self.destroy()

            event = th.Event()
            waiter = th.Thread(target=wait, args=(event,))
            waiter.start()
            msg += " in maximum {}s".format(maxwait)
        self.logger.debug(msg)

        # Explicit loop instead of `map()`: the joins are performed for their side effect.
        for process in self.processes:
            Slave.join(process)
        if event:
            event.set()
            waiter.join()
        self.logger.info("All sub-process workers are stopped.")


class Process(object):
    """
    Minimal fork-based process: `start()` forks and runs `run()` in the child, which then exits.
    """

    def __init__(self, queue):
        # Optional queue object; when set, its `fork()` is used instead of `os.fork()`.
        self.__queue = queue
        # Child PID in the parent, own PID in the child; `None` until started.
        self.__pid = None

    @property
    def pid(self):
        """
        PID of the forked process, or `None` if it was not started yet.
        """
        return self.__pid

    @property
    def queue(self):
        """
        Queue object given at construction (may be `None`).
        """
        return self.__queue

    def join(self):
        """
        Waits for the forked process to exit, retrying if interrupted by a signal.

        :return: `os.waitpid()` result, or `None` if the process was never started or is not a child
                 of the current process (anymore)
        """
        if not self.__pid:
            return None
        while True:
            try:
                return os.waitpid(self.__pid, 0)
            except OSError as ex:
                if ex.errno == errno.EINTR:
                    continue
                if ex.errno != errno.ECHILD:
                    raise
                return None

    def is_alive(self):
        """
        Checks whether the process exists by sending it signal 0.

        :return: `True` if the signal could be delivered, `False` otherwise or if not started
        """
        try:
            return (os.kill(self.__pid, 0) or True) if self.__pid else False
        except OSError:
            return False

    def start(self):
        """
        Forks the process (via `queue.fork()` when a queue is set, `os.fork()` otherwise).

        The parent stores the child PID and returns. The child calls `run()` and terminates through
        `os._exit()` without returning to the caller: with status 0 when `run()` returns normally, with
        status 1 when it raises an exception (whose traceback is printed to stderr first).
        """
        self.__pid = (os.fork if self.queue is None else self.queue.fork)()
        if not self.__pid:
            status = 0
            try:
                self.__pid = os.getpid()
                self.run()
            except Exception:
                status = 1
                # `os._exit()` below would otherwise discard the failure silently.
                import traceback
                traceback.print_exc()
                sys.stderr.flush()
            finally:
                # noinspection PyProtectedMember
                os._exit(status)

    def run(self):
        """
        Code executed in the forked child; meant to be overridden.
        """
        pass


class Slave(Process):
    """
    Sub-process worker base class. Instances are built by `Master._slaves_builder()`, forked by
    `Master.start()` and operate on master's queue (i.e., this one(es) will be consumer(s)).
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, log, pidfile, queue):
        self.logger = log
        # Own PID; set by `run()`, i.e. inside the forked process only.
        self.mypid = None
        # File holding the PID of the running worker instance (blanked when the worker stops).
        self.pidfile = pidfile
        super(Slave, self).__init__(queue)

    def _check_pid(self):
        """
        Kills the previous worker instance recorded in the PID file (if any) and records the current PID.

        The previous instance gets SIGTERM, then SIGKILL once a second until it is gone. This step is best
        effort: a missing or blank PID file, an invalid PID or a non-existent process just ends it.
        """
        # noinspection PyBroadException
        try:
            with open(self.pidfile) as fh:
                pid = int(fh.readline().strip())
            # Non-positive values address process groups (or every process) in `os.kill()` - never do that.
            if pid <= 0:
                raise ValueError("Invalid PID #{}.".format(pid))
            if pid == self.mypid:
                raise ValueError("Cannot kill myself.")
            self.logger.warning("Killing previous instance with PID #%d", pid)
            os.kill(pid, signal.SIGTERM)
            # The loop ends with `OSError` as soon as the process cannot be signalled anymore.
            while True:
                time.sleep(1)
                os.kill(pid, signal.SIGKILL)
        except Exception:
            pass
        # Make sure the file exists, then overwrite its first 16 bytes with the space-padded current PID.
        open(self.pidfile, "a").close()
        with open(self.pidfile, "r+") as f:
            data = str(os.getpid())
            data += " " * (16 - len(data))
            f.write(data)

    @abc.abstractmethod
    def process(self, data):
        """
        Processes a given data.
        This method is abstract and should be implemented in derived class.

        :param data: data to process
        """

    def sighandler(self, signum, *_):
        """
        Default signal handler: logs the signal and terminates the process immediately via `os._exit()`,
        using the signal number as the exit status.

        :param signum: signal number
        :param _: unused parameters
        """
        self.logger.warning("Process #%s worker: caught signal %d. Terminating.", self.mypid, signum)
        # noinspection PyProtectedMember
        os._exit(signum)

    def on_start(self):
        """
        Method should be overridden to perform custom actions just after worker start.
        """

    def on_stop(self):
        """
        Method should be overridden to perform custom actions just before worker stop.
        """

    def main(self):
        """
        Main loop: processes items from the queue until the stop marker arrives.
        """
        while True:
            data = self.queue.get()
            # `None` is the stop marker; with `InterProcessQueue` an empty string means the connection to the
            # main process was closed. Any other value - including falsy ones like 0 or [] - is a work item.
            if data is None or (isinstance(data, str) and not data):
                break
            self.process(data)

    def run(self):
        """
        Worker process body: installs signal handlers, claims the PID file, then runs `on_start()`,
        `main()` and `on_stop()`. The PID file is blanked on exit, whichever way the worker stops
        (except termination by a signal).
        """
        for sig in (signal.SIGHUP, signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.sighandler)

        self.mypid = os.getpid()
        self._check_pid()
        try:
            # noinspection PyBroadException
            try:
                self.on_start()
            except:
                self.logger.exception("Process #%s worker: unhandled exception during initialization.", self.mypid)
                return

            self.logger.info("Process #%s worker: started.", self.mypid)
            # noinspection PyBroadException
            try:
                self.main()
            except:
                self.logger.exception("Process #%s worker: unhandled exception in the main loop.", self.mypid)
            finally:
                # noinspection PyBroadException
                try:
                    self.on_stop()
                except:
                    self.logger.exception("Process #%s worker: unhandled exception during finalization.", self.mypid)

            self.logger.info("Process #%s worker: stopped.", self.mypid)
        finally:
            # Blank out the PID, so a later instance will not try to kill this (possibly reused) PID.
            with open(self.pidfile, "r+") as f:
                f.write(" " * 16)


def processes():
    """
    Yields basic information (PID, UID, GID and executable) about every process in the system.

    On Cygwin, the information is collected by listing the `/proc` directory (the executable is not
    available there and is reported as `None`). Elsewhere, :py:mod:`psutil` is used; both its pre-2.0
    (properties) and 2.0+ (methods) APIs are supported. Processes which disappear or cannot be inspected
    meanwhile are skipped.

    :return: generator of `Proc(pid, uid, gid, exe)` named tuples
    """
    proc_t = collections.namedtuple("Proc", "pid uid gid exe")
    if platform.system().startswith("CYGWIN"):
        for dname in os.listdir("/proc"):
            try:
                pid = int(dname)
                st = os.stat(os.path.join("/proc", dname))
            except (ValueError, OSError):
                # Not a PID entry, or the process has gone away meanwhile.
                continue
            # `S_ISDIR` checks the whole file-type field; masking with `S_IFDIR` alone also matches
            # block devices and sockets.
            if stat.S_ISDIR(st.st_mode):
                yield proc_t(pid, st.st_uid, st.st_gid, None)
    else:
        # noinspection PyUnresolvedReferences
        import psutil

        def _value(p, name):
            # psutil < 2.0 exposes process information as properties, newer versions as methods.
            attr = getattr(p, name)
            return attr() if callable(attr) else attr

        def _safe_exe(p):
            try:
                return _value(p, "exe")
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                return _value(p, "name")

        for p in psutil.process_iter():
            try:
                info = proc_t(p.pid, _value(p, "uids").real, _value(p, "gids").real, _safe_exe(p))
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
            yield info


def run_as_root():
    """
    Re-executes the current script through password-less `sudo`, so it runs with root privileges.

    Nothing happens unless all of these hold: the installation is local, the service-user environment
    variable is not set yet (the re-executed command gets it set), `User.has_root` is false and
    `User.can_root` is true. On success the current process image is replaced and the function never
    returns; an `OSError` (e.g. from `execv`) is ignored and the caller continues unprivileged.
    """
    try:
        is_local = config.Registry().common.installation == ctm.Installation.LOCAL
        if not is_local:
            return
        if os.environ.get(common_os.User.SERVICE_USER_ENV):
            # Already re-executed (or started as the service user) - nothing to do.
            return
        if common_os.User.has_root or not common_os.User.can_root:
            return

        service_user = "=".join((common_os.User.SERVICE_USER_ENV, getpass.getuser()))
        # NOTE(review): only the script path is forwarded; any further command-line arguments are
        # dropped - confirm this is intended.
        sudo_argv = [
            "[sandbox] New Service",
            "-En",
            "PYTHONPATH=/skynet",
            service_user,
            sys.executable,
            sys.argv[0],
        ]
        os.execv("/usr/bin/sudo", sudo_argv)
    except OSError:
        pass
