# coding: utf-8

import os
import abc
import time
import struct
import signal
import socket
import random
import logging
import datetime as dt
import Queue as queue
import threading as th

import common.types.misc as ctm

import common.zk
import common.log
import common.utils
import common.config
import common.patterns
import common.statistics

from yasandbox.database import mapping
from sandbox.yasandbox import controller


class ThreadWithZK(th.Thread):
    """
    Service thread with (optional) Zookeeper lock.
    In case of lock used, the thread will store its state in the database.

    The thread does not run the payload itself: it forks a worker subprocess
    (see `_run_subprocess`) and drives it over a socket pair with the one-byte
    protocol described by `Command`. The worker answers every `Command.RUN`
    with the number of seconds to wait before the next run, packed with `WaitST`.
    Derived classes implement `_proc`, which is executed in the worker subprocess.
    """
    __metaclass__ = abc.ABCMeta

    Model = mapping.Service
    # Wire format of the worker's "seconds to wait" reply (native unsigned long).
    WaitST = struct.Struct("L")
    # Maximum amount of time in seconds to wait for child process graceful exit
    CHILD_WAIT_TIMEOUT = 15
    # Timeout for juggler monitoring script, in minutes
    NOTIFICATION_TIMEOUT = 30
    # Process-wide lock to be acquired on fork
    __fork_lock = th.RLock()
    # Process-wide collection with children's control sockets
    __sockets = []

    class Command(common.utils.Enum):
        """ Child process IPC protocol specification :) """
        SLEEP = "0"     # Never actually sent -- it is used for process title suffixes only
        RUN = "1"       # Worker should execute a next iteration loop and return wait time
        DIE = "4"       # Worker should exit

    PREFIX = "[sandbox] Service "
    SUFFIX = {
        Command.RUN: " (running)",
        Command.SLEEP: " (sleeping)",
        Command.DIE: " (exiting)",
    }

    class LockStats(common.patterns.Abstract):
        """ Process-wide (singleton) registry of the Zookeeper locks created by service threads. """
        __metaclass__ = common.utils.SingletonMeta
        __slots__ = ("instances",)
        __defs__ = (set(),)

        @property
        def acquired(self):
            """ Number of registered locks which are currently acquired. """
            return sum(_.is_acquired for _ in set(self.instances))

    def __init__(self, *args, **kwargs):
        """
        Keyword arguments consumed here (all required), the rest goes to `threading.Thread`:

        :param stopping: callable returning True when the service is shutting down
        :param logger: core log to report thread start/stop to
        :param rwlock: object whose `reader` attribute is used as a context manager
            around every communication with the worker subprocess
        """
        # The stopping flag function should be replaced on thread run.
        self.stopping = self.service_stopping = kwargs.pop('stopping')
        self.core_log = kwargs.pop('logger')
        self.rwlock = kwargs.pop('rwlock')
        self.alarm = None
        self._model = None
        super(ThreadWithZK, self).__init__(*args, **kwargs)
        self.zk = common.zk.Zookeeper()
        self.run_interval = 3  # sec
        self.logger = common.log.get_core_log('jobs')
        common.statistics.Signaler(
            common.statistics.ServerSignalHandler(),
            component=ctm.Component.SERVICE,
            update_interval=common.config.Registry().server.statistics.update_interval,
        )

    def start(self):
        """ Create the wake-up event and start the thread. """
        self.alarm = th.Event()
        super(ThreadWithZK, self).start()

    def wakeup(self):
        """ Interrupt the current `wait()` early. """
        self.alarm.set()

    def wait(self, timeout):
        """
        Sleep up to `timeout` seconds or until `wakeup()` is called.

        :return: True if woken up by `wakeup()`, False on timeout
        """
        if self.alarm.wait(timeout):
            self.logger.info('B-Z-Z-Z-Z-Z-Z!!!')
            self.alarm.clear()
            return True
        return False

    @property
    def model(self):
        """
        Database document with this service's state, keyed by the class name.
        Created (with this host's id) and saved if it doesn't exist yet; cached in `_model`.
        """
        if not self._model:
            settings = common.config.Registry()
            name = self.__class__.__name__
            self._model = self.Model.objects.with_id(name)
            if not self._model:
                self.logger.info('Creating new database model object.')
                self._model = self.Model()
                self._model.name = name
                self._model.time = self.Model.Time()
                self._model.host = settings.this.id
                self._model.save()
        return self._model

    @property
    def first_run_delay(self):
        """
        Seconds to sleep before the first run: 0 with Zookeeper, otherwise a random value
        between half and twice `run_interval` (to spread services started together).
        """
        if self.zk.enabled:
            return 0
        # `randint` requires integer bounds; `run_interval` may be overridden with a float.
        return random.randint(int(self.run_interval / 2), int(self.run_interval * 2))

    @classmethod
    def read_only(cls):
        """
        Check whether Sandbox is in READ_ONLY operation mode.
        The fork lock is held so that the query cannot overlap the database
        disconnect/reconnect window around `fork()`.
        """
        with cls.__fork_lock:
            return controller.Settings.mode() == controller.Settings.OperationMode.READ_ONLY

    def __check_operational_mode(self):
        """
        Block while Sandbox is in READ_ONLY mode, polling once a second.
        Returns early when the service is stopping, so read-only mode can't hold up shutdown.
        Parent side only: here `service_stopping` is the callable passed to the constructor.
        """
        if not self.read_only():
            return
        self.logger.warn('Sandbox is in READONLY mode. Waiting for normal operation mode.')
        while not self.service_stopping() and self.read_only():
            self.wait(1)

    def _on_subprocess_start(self):
        """ Hook: called in the worker subprocess before it starts serving commands. """
        pass

    def _on_subprocess_stop(self):
        """ Hook: called in the worker subprocess before it exits (only if `_on_subprocess_start` succeeded). """
        pass

    def _finalize_jobs(self):
        """ Hook: called in the worker subprocess after every `_proc` call. """
        pass

    def _on_proc_done(self):
        self._finalize_jobs()

    def _on_signal(self):
        """ Hook: called from the worker subprocess's SIGINT/SIGTERM handler. """
        pass

    def _run_subprocess(self):
        """
        Fork a worker subprocess and wait until it reports readiness.

        In the parent (service thread) returns the control socket and the child's PID.
        The forked child never returns from here: it runs the worker loop and
        terminates the process with ``os._exit()``.
        A child which does not report readiness within `CHILD_WAIT_TIMEOUT` seconds
        is killed and a new one is forked, until one starts successfully.

        :return: tuple of (parent side control socket, child PID)
        """
        # Imported in the parent, before forking, so an import failure surfaces here.
        from kernel.util import console

        while True:
            self.logger.info("Forking a worker subprocess.")
            parent_socket, child_socket, pid = self.__fork_subprocess()
            if not pid:
                self.__subprocess_main(child_socket, console)  # never returns

            parent_socket.settimeout(self.CHILD_WAIT_TIMEOUT)
            try:
                # The child sends a single byte once it is ready to serve commands.
                if not parent_socket.recv(1):
                    raise socket.error("No data received on a socket")
            except socket.error as ex:
                self.logger.error("Child process didn't respond after start in a reasonable time: %s", ex)
                self._kill_subprocess(None, pid)
                # Unregister the socket, otherwise the process-wide list grows on every failed start.
                self.__sockets.remove(parent_socket)
                parent_socket.close()
                continue
            parent_socket.settimeout(None)
            return parent_socket, pid

    def __fork_subprocess(self):
        """
        Create a control socket pair and fork the process under the process-wide fork lock.

        The database connection is dropped before the fork. The parent reconnects before
        releasing the lock; the child has to reconnect by itself (see `__subprocess_main`).

        :return: tuple of (parent socket, child socket, pid); `pid` is 0 in the child.
            In the parent the child socket is already closed.
        """
        with self.__fork_lock:
            # Disconnect from the database before forking.
            mapping.disconnect()

            parent_socket, child_socket = socket.socketpair()
            for sock in (parent_socket, child_socket):
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 0)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 0)
            try:
                # Hold the logging module lock across fork(), so the child doesn't
                # inherit it in a state acquired by some other thread.
                logging._acquireLock()
                pid = os.fork()
            except Exception:
                # Don't leak the socket pair if the fork itself failed.
                parent_socket.close()
                child_socket.close()
                raise
            finally:
                try:
                    logging._releaseLock()
                except RuntimeError:
                    pass

            # Registered on both sides: the child closes every registered socket on start.
            self.__sockets.append(parent_socket)
            if pid:
                child_socket.close()
                # Restore database connection after fork
                self.__connect_db("Parent")
        return parent_socket, child_socket, pid

    def __connect_db(self, who):
        """ Re-establish the database connection after fork; `who` is used in log messages only. """
        self.logger.debug("[%s] Connecting to DB ...", who)
        mapping.ensure_connection()
        self.logger.debug("[%s] Connection to DB successfully established", who)

    def __subprocess_main(self, child_socket, console):
        """
        Body of the forked worker subprocess; never returns.

        The worker reports readiness by sending `Command.RUN`, then serves commands read
        from `child_socket`: on `Command.RUN` it runs `_proc`, stores the run times in the
        service model and replies with the number of seconds to wait (packed with `WaitST`).
        It exits on `Command.DIE`, on a closed/broken socket, or after SIGINT/SIGTERM.
        Every path -- setup failures included -- ends with ``os._exit(0)``, so no exception
        can escape into the inherited copy of the parent's service loop.

        :param child_socket: worker side of the control socket pair
        :param console: `kernel.util.console` module (used to set the process title)
        """
        # In the worker the stopping flag is a plain boolean, raised by the signal handler.
        self.service_stopping = False
        started = False
        try:
            self.__connect_db("Child")

            def sighandler(*_):
                self.service_stopping = True
                self._on_signal()

            self._model = None
            common.zk.Zookeeper().client = None
            # Close all parent's control sockets inherited through fork().
            for sock in self.__sockets:
                sock.close()

            prefix = self.PREFIX + self.__class__.__name__
            console.setProcTitle(prefix)

            signal.signal(signal.SIGINT, sighandler)
            signal.signal(signal.SIGTERM, sighandler)
            self.stopping = lambda: self.service_stopping

            self.model.reload()
            self.model.host = common.config.Registry().this.id
            self.model.save()

            self._on_subprocess_start()
            started = True
            self.logger.info("Subprocess started.")
            console.setProcTitle(prefix + self.SUFFIX[self.Command.SLEEP])
            # Report readiness to the parent.
            child_socket.send(self.Command.RUN)
            while not self.service_stopping:
                try:
                    data = child_socket.recv(1)
                except socket.error:
                    data = None
                if not data or data == self.Command.DIE:
                    self.logger.info("Signal socket closed.")
                    break
                console.setProcTitle(prefix + self.SUFFIX[self.Command.RUN])
                wait = self._proc()
                self._on_proc_done()
                if wait is None:
                    wait = self._wait_before_next_run()
                now = dt.datetime.utcnow()
                self.model.time.last_run = now
                self.model.time.next_run = now + (wait or dt.timedelta(seconds=self.run_interval))
                self.model.timeout = self.NOTIFICATION_TIMEOUT
                self.model.save()
                # `WaitST` is unsigned: a negative wait would make `pack()` fail and kill the worker.
                seconds = max(int(wait.total_seconds()), 0) if wait else 0
                child_socket.send(self.WaitST.pack(seconds))
                console.setProcTitle(prefix + self.SUFFIX[self.Command.SLEEP])
        except BaseException:
            self.logger.exception("Unhandled exception in a service process.")
        finally:
            if started:
                self._on_subprocess_stop()
            self.logger.info(
                "Subprocess %s.", "finished" if not self.service_stopping else "signaled to exit"
            )
            os._exit(0)

    def _kill_subprocess(self, sock, pid):
        """
        Stop a worker subprocess.

        If `sock` is given, unregister and close it after sending `Command.DIE`, and poll for
        up to `CHILD_WAIT_TIMEOUT` seconds for the child to exit. If the child is still
        there (or no socket was given), it is killed with SIGKILL. The child is always reaped.

        :param sock: parent side control socket or None
        :param pid: child PID or None
        :return: ``(None, None)``, so callers can reset their ``sock, pid`` pair in one assignment
        """
        kill = True
        if sock:
            if sock in self.__sockets:
                self.__sockets.remove(sock)
            self.logger.info("Asking child process to exit gracefully.")
            try:
                sock.send(self.Command.DIE)
                for _ in xrange(self.CHILD_WAIT_TIMEOUT * 10):
                    try:
                        reaped = os.waitpid(pid, os.WNOHANG)[0]
                    except OSError:
                        # Not our child (anymore) -- nothing left to wait for.
                        reaped = True
                    if reaped:
                        kill = False
                        break
                    time.sleep(0.1)
            except socket.timeout:
                self.logger.warning("Child process didn't respond in a reasonable time.")
            except socket.error:
                self.logger.warning("Child process stopped unexpectedly.")
            finally:
                sock.close()

        if pid and kill:
            self.logger.warning("Killing subprocess #%d", pid)
            try:
                os.kill(pid, signal.SIGKILL)
            except OSError:
                pass
        if pid:
            try:
                os.waitpid(pid, 0)
            except OSError:
                pass

        return None, None

    def run(self):
        """ Thread entry point: `_run_with_zk` if Zookeeper is enabled, `_run` otherwise. """
        self.logger.info('Thread started. Run period: %s', common.utils.td2str(self.run_interval))
        self.core_log.info('Thread %s started.', self.__class__.__name__)

        if self.zk.enabled:
            self._run_with_zk()
        else:
            self._run()

        self.core_log.info('Thread %s stopped.', self.__class__.__name__)
        self.logger.info('Thread stopped.')

    def _run(self):
        """
        Service loop without Zookeeper: repeatedly ask the worker subprocess to run an
        iteration and sleep for the interval it returns (or `run_interval`).
        The worker is (re)started on demand and is always stopped on exit.
        """
        sock, pid = None, None
        sleep = self.first_run_delay
        if sleep:
            self.logger.info('Sleeping %s on start.', common.utils.td2str(sleep))
            self.wait(sleep)

        try:
            while not self.service_stopping():
                self.__check_operational_mode()
                self.stopping = lambda: self.service_stopping() or self.read_only()
                while not self.stopping():
                    # Reset every iteration: a failed iteration must fall back to `run_interval`
                    # rather than reuse the previous iteration's value.
                    wait = None
                    if not sock:
                        sock, pid = self._run_subprocess()
                    with self.rwlock.reader:
                        try:
                            sock.send(self.Command.RUN)
                            data = sock.recv(self.WaitST.size)
                            if not data:
                                raise RuntimeError("Child process closed connection unexpectedly.")
                            wait = self.WaitST.unpack(data)[0]
                        except Exception:
                            self.logger.exception("Unexpected exception on subprocess communication.")
                            sock, pid = self._kill_subprocess(sock, pid)

                    wait = wait or self.run_interval
                    if wait:
                        self.logger.debug('Waiting for %s', common.utils.td2str(wait))
                        self.wait(wait)
        finally:
            self._kill_subprocess(sock, pid)

    def _run_with_zk(self):
        """
        Service loop with the Zookeeper lock 'jobs/<class name>'.

        While holding the lock, the thread drives the worker subprocess, starting from the
        `next_run` time stored in the database. During the first 5 iterations the lock may be
        released for a while to balance locks between contenders. Losing the Zookeeper
        session, READ_ONLY mode or service shutdown make it stop the worker and release the lock.
        """
        lock = None
        iteration = 0
        sock, pid = None, None
        lst = self.LockStats()
        settings = common.config.Registry()
        lock_name = self.__class__.__name__

        while not self.service_stopping():
            try:
                self.__check_operational_mode()
                lst.instances.discard(lock)
                lock = self.zk.lock("jobs", lock_name)
                lst.instances.add(lock)
                self.logger.info(
                    "Acquiring 'jobs/%s' lock (acquired %d of %d). Session ID is 0x%x",
                    lock_name, lst.acquired, len(lst.instances), self.zk.client.client_id[0]
                )
                # Don't take the lock (and fork a worker) if the service is being shut down.
                if self.service_stopping():
                    break
                lock.acquire()

                client_id = lock.client.client_id
                self.logger.info(
                    "Lock 'jobs/%s' has been acquired (acquired %d of %d). Session ID is 0x%x",
                    lock_name, lst.acquired, len(lst.instances), client_id[0]
                )
                if not sock:
                    sock, pid = self._run_subprocess()
                self.stopping = lambda: (
                    not lock.client.connected or lock.client.client_id != client_id or
                    self.service_stopping() or self.read_only()
                )

                with self.__fork_lock:
                    self.model.reload()
                    self.model.host = settings.this.id
                    self.model.save()
                    now = dt.datetime.utcnow()
                    next_run = self.model.time.next_run
                    # NOTE(review): `next_run` is presumably unset for a freshly created model
                    # (`model` never assigns it); subtracting None used to raise TypeError here.
                    wait = (next_run - now).total_seconds() if next_run is not None else 0
                while not self.stopping():
                    acquired = lst.acquired
                    contenders = len(lock.contenders())
                    wait = wait or self.run_interval
                    release = min(
                        0 if contenders < 2 else (acquired - 1) * float(len(lst.instances)) / (contenders - 1),
                        wait
                    )
                    self.logger.debug(
                        'Iteration #%d. Locks acquired %d of %d. Lock contenders %d. Release %.2f',
                        iteration, acquired, len(lst.instances), contenders, release
                    )
                    if iteration < 5 and release > 0:
                        self.logger.info("Releasing the lock and sleeping %.2fs.", release)
                        sock, pid = self._kill_subprocess(sock, pid)
                        lock.release()
                        self.wait(release)
                        break

                    if wait > 0:
                        self.logger.debug('Waiting for %s', common.utils.td2str(wait))
                        self.wait(wait)
                        if self.stopping():
                            break
                    else:
                        self.logger.info('Delayed for for %s', common.utils.td2str(abs(wait)))

                    with self.rwlock.reader:
                        wait = None
                        try:
                            sock.send(self.Command.RUN)
                            data = sock.recv(self.WaitST.size)
                            if not data:
                                raise RuntimeError("Child process closed connection unexpectedly.")
                            wait = self.WaitST.unpack(data)[0]
                        except Exception:
                            lock.release()
                            self.logger.exception("Unexpected exception on subprocess communication.")
                            sock, pid = self._kill_subprocess(sock, pid)
                            break
                    iteration += 1
            except Exception:
                if self.service_stopping():
                    return
                self.logger.exception("Unexpected exception in service thread loop.")
                self.wait(10)
            finally:
                sock, pid = self._kill_subprocess(sock, pid)
                if lock:
                    lock.release()

    @abc.abstractmethod
    def _proc(self):
        """
        An abstract method, which should be overridden by derived class.
        Can return `None` or `datetime.timedelta` which will report to the main loop the period,
        which the loop should wait before running the method again.
        """
        pass

    def _wait_before_next_run(self):
        """
        Used when `_proc` returns None. The default implementation returns None,
        so `run_interval` is used.

        :return: timedelta to wait before next run
        :rtype: `datetime.timedelta` or None
        """


class Multithreaded(ThreadWithZK):
    """
    A service process with a threads pool to process jobs in parallel.

    The pool lives in the worker subprocess: `_on_subprocess_start` creates `self._queue`
    and `WORKERS` threads, and `_on_subprocess_stop` shuts them down. Every item put into
    `self._queue` (presumably by `_proc`) is handed to `_worker_proc` by a pool thread.
    After every `_proc` call the subprocess waits until all queued jobs are done
    (see `_finalize_jobs`).
    """

    # Amount of worker threads pool
    WORKERS = 10

    def __init__(self, *args, **kwargs):
        self._threads = None
        self._queue = None
        super(Multithreaded, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def _worker_proc(self, data):
        """
        A thread worker's main entry point - it will be called for each data, which will be put to the `self._queue`
        """
        pass

    def _worker(self):
        """
        Pool thread loop: take items from the queue and pass them to `_worker_proc`.

        Exits on a `None` sentinel, when the subprocess is stopping, or after an unhandled
        exception. In every case it raises the subprocess stopping flag, so a single failed
        worker stops the whole subprocess.
        """
        try:
            while not self.stopping():
                data = self._queue.get()
                try:
                    if data is None:
                        break
                    self._worker_proc(data)
                finally:
                    # Account for every taken item (sentinel included), so `join_queue` can finish.
                    self._queue.task_done()
        except BaseException:
            self.logger.exception("Unhandled exception in a worker thread")
        finally:
            self.service_stopping = True

    def _workers_alive(self):
        """ True if at least one pool thread is still running. """
        return any(thread.is_alive() for thread in self._threads or ())

    def _on_subprocess_start(self):
        """ Create the jobs queue and start `WORKERS` pool threads. """
        self._queue = queue.Queue()
        self.logger.info("Starting %d worker threads.", self.WORKERS)
        self._threads = [th.Thread(target=self._worker, name="#" + str(i)) for i in xrange(self.WORKERS)]
        for thread in self._threads:
            thread.start()

    def _on_subprocess_stop(self):
        """ Send one `None` sentinel per pool thread and wait for all of them to exit. """
        if self._threads is None:
            # The pool has never been started.
            return
        self.logger.info("Waiting for worker threads exit.")
        for _ in self._threads:
            self._queue.put(None)
        for thread in self._threads:
            thread.join()

    def join_queue(self, queue):
        """
        Block until every item put into `queue` has been marked done.

        Like `Queue.join()`, but wakes up every 5 seconds. For the pool's own queue it
        gives up once all pool threads are gone, because nobody is left to complete the
        remaining tasks and waiting would never end.
        """
        queue.all_tasks_done.acquire()
        try:
            while queue.unfinished_tasks:
                if queue is self._queue and not self._workers_alive():
                    self.logger.warning(
                        "No worker threads left, %d unfinished task(s) will not be completed.",
                        queue.unfinished_tasks
                    )
                    break
                queue.all_tasks_done.wait(5)
        finally:
            queue.all_tasks_done.release()

    def _finalize_jobs(self):
        """ Wait until all jobs queued during the last `_proc` call are processed. """
        self.logger.info("Waiting for workers to finish their jobs.")
        if self._queue is not None:
            self.join_queue(self._queue)
        self.logger.info("All jobs are finished")

    def _on_signal(self):
        """ Drop all jobs still waiting in the queue, marking each one as done. """
        if self._queue is None:
            # The signal came before the pool was started: nothing to drop.
            return
        # NOTE(review): this runs inside a Python signal handler on the main thread. If the
        # main thread was interrupted while holding the queue's mutex (e.g. inside `put()` or
        # `join_queue()`), the calls below may deadlock -- confirm.
        while not self._queue.empty():
            try:
                self._queue.get(block=False)
                self._queue.task_done()
            except queue.Empty:
                pass
