from __future__ import absolute_import

import os
import time
import errno
import struct
import signal
import socket
import gevent
import logging
import msgpack
import threading as th

import kazoo.client
import kazoo.exceptions

from sandbox.common import zk as common_zk
from sandbox.common import enum as common_enum
from sandbox.common import patterns as common_patterns
from sandbox.common import itertools as common_itertools


class Contender(object):
    """
    Primary election helper built on top of a Zookeeper lock (via `kazoo`).

    When `do_fork` is true, the Zookeeper-facing logic runs in a forked subprocess and the parent
    talks to it over a socket pair using a tiny protocol: one command byte followed by
    a length-prefixed msgpack payload. When `do_fork` is false, the same logic runs in-process.
    """
    # Length prefix (little-endian unsigned 32-bit integer) used to frame msgpack payloads on the sockets
    SIZE_ST = struct.Struct("<I")
    # Name of the node (under `zk_root`) holding the comma-separated list of contenders
    CONTENDERS_NODE = "contenders"
    # Name of the node (under `zk_root`) used for the primary lock
    LOCK_NODE = "lock"
    WATCHDOG_JOIN_TIMEOUT = 5  # time in seconds to wait for termination of watchdog thread
    STOPPING_TIMEOUT = 1  # time in seconds to wait for terminating contender subprocess
    ZK_SUSPENDED_TIMEOUT = 15  # max time in seconds to stay in SUSPENDED state before terminating subprocess

    # One-byte command codes sent from the parent process to the contender subprocess (see `_loop`)
    class Command(common_enum.Enum):
        # Get (no arguments) or set (one argument) the list of contenders
        CONTENDERS = "C"
        # Query the current primary, optionally contending for the primary lock
        PRIMARY = "P"
        # Release Zookeeper resources and terminate the command loop
        STOP = "S"

    # Raised by the request helpers when a request could not be served (no subprocess, empty reply,
    # broken socket, malformed value); presumably callers are expected to repeat the call later.
    class Retry(Exception):
        pass

    def __init__(
        self, hosts, zk_root, name, timeout=5, on_start=None, do_fork=True, logger=logging, lock_cls=th.RLock
    ):
        """
        Set up the contender; nothing is started or connected here.

        :param hosts: Zookeeper hosts specification, normalized with `common_zk.Zookeeper.hosts`
        :param zk_root: root path in Zookeeper under which the lock and contenders nodes live
        :param name: identifier of this contender, used as the lock identifier
        :param timeout: timeout used for starting the Zookeeper client and for progressive waits
        :param on_start: callable invoked with this object in the forked child; its return value
            replaces the logger there
        :param do_fork: run the Zookeeper logic in a forked subprocess instead of in-process
        :param logger: logger-like object used for all messages
        :param lock_cls: factory of the lock serializing requests to the subprocess
        """
        # --- static configuration
        self.__name = name
        self.__zk_root = zk_root
        self.__hosts = common_zk.Zookeeper.hosts(hosts)
        self.__timeout = timeout
        self.__max_interval = timeout / 10.
        self.__do_fork = do_fork
        self.__logger = logger
        self.__on_start = on_start if on_start else (lambda *args: None)

        # --- Zookeeper paths (derived from `zk_root`, which must be set by now)
        self.__lock_path = self.__path(self.LOCK_NODE)
        self.__contenders_path = self.__path(self.CONTENDERS_NODE)

        # --- subprocess bookkeeping
        self.__pid = None
        self.__parent_socket = None
        self.__watchdog_socket = None
        self.__request_lock = lock_cls()

        # --- Zookeeper session state
        self.__zk_session_actual = True
        self.__zk_suspended_kamikadze_thread = None

        # --- primary election state
        self.__primary = None
        self._primary_lock = None
        self._primary_lock_thread = None

    @staticmethod
    def on_fork(joint_server):
        """
        Re-initialize gevent state after a fork.

        Calls `joint_server.on_fork()` first, then re-initializes gevent and destroys the hub
        (together with its event loop) inherited from the parent process.

        :param joint_server: object providing an `on_fork()` hook, which is called before gevent is touched
        """
        # initialize gevent to prevent the execution of the running greenlets after fork,
        # this code borrowed from gipc library
        joint_server.on_fork()
        gevent.reinit()
        hub = gevent.get_hub()
        # Drop the thread pool reference so that destroying the hub does not deal with it
        del hub.threadpool
        hub._threadpool = None
        # FIXME: workaround to avoid infinite hang, fix after https://github.com/gevent/gevent/issues/1669
        # `hub.throw` is temporarily replaced with a no-op and restored afterwards
        orig_throw, hub.throw = hub.throw, lambda *_: None
        try:
            hub.destroy(destroy_loop=True)
        finally:
            hub.throw = orig_throw

    def _exit(self, status):
        """
        Terminate the current process immediately with `status`, but only in forking mode.

        In non-forking mode this is a no-op, so callers must not rely on it never returning.
        """
        if not self.__do_fork:
            return
        self.__logger.info("Exiting with status %s", status)
        # noinspection PyProtectedMember
        os._exit(status)

    def __path(self, *args):
        """ Build a "/"-separated Zookeeper path from the configured root and extra components `args` """
        components = common_itertools.chain(self.__zk_root, args)
        return "/".join(components)

    def __zk_suspended_kamikadze(self):
        """
        Thread body: give the Zookeeper session `ZK_SUSPENDED_TIMEOUT` seconds to recover.

        Checks the session flag once per second; if it never becomes true, the process is
        terminated through `_exit(1)`. In any case the thread unregisters itself on the way out.
        """
        seconds_left = self.ZK_SUSPENDED_TIMEOUT
        recovered = False
        while seconds_left > 0:
            self.__logger.warning(
                "ZK client is in SUSPENDED state: %ss left to subprocess termination",
                seconds_left
            )
            time.sleep(1)
            if self.__zk_session_actual:
                recovered = True
                break
            seconds_left -= 1
        if not recovered:
            self._exit(1)
        self.__logger.info("Stopping ZK suspended state kamikadze thread")
        self.__zk_suspended_kamikadze_thread = None

    def __zk_listener(self, state):
        """
        Connection state listener registered on the kazoo client.

        LOST terminates the process (in forking mode), SUSPENDED arms the kamikadze thread unless
        one is already registered; the session flag always ends up reflecting whether `state` is CONNECTED.
        """
        states = kazoo.client.KazooState
        self.__logger.info("Zookeeper state changed to %r", state)
        if state == states.LOST:
            self._exit(1)
        elif state == states.SUSPENDED and not self.__zk_suspended_kamikadze_thread:
            kamikadze = th.Thread(target=self.__zk_suspended_kamikadze)
            self.__zk_suspended_kamikadze_thread = kamikadze
            kamikadze.daemon = True
            kamikadze.start()
        self.__zk_session_actual = (state == states.CONNECTED)

    def _create_zk_client(self):
        """
        Make one attempt to create and start a kazoo client.

        :return: started client, or `None` if any exception occurred (the exception is logged)
        """
        # noinspection PyBroadException
        try:
            zk_client = kazoo.client.KazooClient(
                hosts=self.__hosts,
                connection_retry={"max_tries": -1},
                randomize_hosts=False
            )
            zk_client.add_listener(self.__zk_listener)
            zk_client.start(timeout=self.__timeout)
            # The session flag is maintained by the state listener registered above
            connected, elapsed = common_itertools.progressive_waiter(
                0, self.__max_interval, self.__timeout, lambda: self.__zk_session_actual
            )
            if connected is None:
                self.__logger.error("Zookeeper client could not be started in %s second(s)", elapsed)
                self._exit(1)
            return zk_client
        except Exception:
            self.__logger.exception("Error occurred while creating of Zookeeper client")
        return None

    @common_patterns.singleton_property
    def __zk(self):
        """
        Zookeeper client; creation is retried with `_create_zk_client` until the timeout expires.

        As a side effect, ensures the contenders node exists, (re)creates the primary lock object
        and resets the cached primary.
        """
        client, elapsed = common_itertools.progressive_waiter(
            0, self.__max_interval, self.__timeout, self._create_zk_client, sleep_first=False
        )
        if client is None:
            self.__logger.error("Zookeeper client could not be started in %s second(s)", elapsed)
            self._exit(1)
        client.ensure_path(self.__contenders_path)
        self.__primary = None
        self._primary_lock = client.Lock(self.__lock_path, self.__name)
        return client

    def __read_args(self, sock):
        """
        Read length-prefixed msgpack-encoded command arguments from `sock`.

        A single `recv(n)` on a stream socket may return fewer than `n` bytes, so both the length
        header and the payload are read in a loop until complete.

        :param sock: connected socket positioned right after the command byte
        :return: decoded arguments
        :raises EOFError: if the peer closes the connection in the middle of a message
        """
        def read_exact(size):
            chunks = []
            while size > 0:
                chunk = sock.recv(size)
                if not chunk:
                    raise EOFError("Connection closed while reading command arguments")
                chunks.append(chunk)
                size -= len(chunk)
            return b"".join(chunks)

        size = int(self.SIZE_ST.unpack(read_exact(self.SIZE_ST.size))[0])
        return msgpack.loads(read_exact(size))

    def _loop(self, sock):
        """
        Command loop of the contender subprocess.

        Reads one command byte plus length-prefixed msgpack arguments from `sock`, executes
        the command and writes back a length-prefixed msgpack result. The loop ends on EOF,
        on the STOP command (which is not answered) or on any error; the process then exits
        with status 0 (in forking mode).

        :param sock: socket connected to the parent process
        """
        self.__logger.info("Started")
        try:
            while True:
                # noinspection PyBroadException
                try:
                    try:
                        cmd = sock.recv(1)
                    except socket.error as ex:
                        # Interrupted by a signal: just retry reading the command byte
                        if ex.errno == errno.EINTR:
                            continue
                        raise
                    if not cmd:
                        # EOF: the parent closed its end of the socket
                        break
                    args = self.__read_args(sock)
                    result = ""
                    if cmd == self.Command.CONTENDERS:
                        if args:
                            result = self._contenders = args[0]
                        else:
                            result = self._contenders
                    elif cmd == self.Command.PRIMARY:
                        result = self._primary(*args)
                    elif cmd == self.Command.STOP:
                        self._stop()
                        break
                    result = msgpack.dumps(result)
                    # `send()` may write only a part of the buffer, `sendall()` writes it completely
                    sock.sendall("".join((self.SIZE_ST.pack(len(result)), result)))
                except:
                    self.__logger.exception("")
                    break
        finally:
            self.__logger.info("Stopped")
            self._exit(0)

    @staticmethod
    def __make_socket_pair():
        """
        Create a connected socket pair with send and receive buffer sizes requested as 0.

        :return: tuple of (parent socket, child socket)
        """
        pair = socket.socketpair()
        for option in (socket.SO_SNDBUF, socket.SO_RCVBUF):
            for endpoint in pair:
                endpoint.setsockopt(socket.SOL_SOCKET, option, 0)
        return pair

    def start(self):
        """
        Start the contender.

        In forking mode a subprocess is forked (unless one is already registered) and two socket
        pairs are set up: one for requests and one used as a watchdog channel. The child calls
        `on_start` and then enters the command loop. In non-forking mode the Zookeeper client
        is simply initialized in-process.

        :return: self
        """
        self.__logger.info("Starting contender")
        if self.__do_fork:
            if self.__pid:
                return self
            parent_socket, child_socket = self.__make_socket_pair()
            watchdog_parent_socket, watchdog_child_socket = self.__make_socket_pair()
            self.__pid = gevent.os.fork()
            if not self.__pid:
                parent_socket.close()
                # `on_start` may return a logger for the child; keep the current one if it returns
                # nothing (the default callback returns None), otherwise `_loop` would fail on its
                # very first logging call and the child would fall through into the parent's code
                self.__logger = self.__on_start(self) or self.__logger
                self.__watchdog_socket = watchdog_child_socket
                # noinspection PyUnresolvedReferences
                self._loop(child_socket.dup())
            self.__parent_socket = parent_socket
            self.__watchdog_socket = watchdog_parent_socket
            child_socket.close()
            watchdog_child_socket.close()
        else:
            # Force initialization of the Zookeeper client
            type(self).__zk.__get__(self)
        return self

    def wait(self):
        """
        Block until something can be read from the watchdog socket (data or EOF).

        Reads interrupted by a signal (EINTR) are retried, any other socket error is propagated.
        """
        finished = False
        while not finished:
            try:
                self.__watchdog_socket.recv(1)
            except socket.error as ex:
                if ex.errno != errno.EINTR:
                    raise
            else:
                finished = True

    def _stop(self):
        """
        Release primary election resources and drop the Zookeeper client.

        :return: always True
        """
        if self._primary_lock_thread_alive:
            self._stop_primary_lock_thread()
        else:
            if self._primary_lock:
                self._primary_lock.release()
        self.__primary = None
        # The client may have never been created, in which case there is nothing to drop
        try:
            del self.__zk
        except AttributeError:
            pass
        return True

    def stop(self):
        """
        Stop the contender.

        In forking mode with a registered subprocess, sends it the STOP command (socket errors
        are ignored) and waits for its termination; otherwise stops the in-process logic.

        :return: True on success, False if any error occurred (the error is logged)
        """
        self.__logger.info("Stopping contender")
        # noinspection PyBroadException
        try:
            if not (self.__do_fork and self.__pid):
                self._stop()
                return True
            with self.__request_lock:
                try:
                    payload = msgpack.dumps(())
                    frame = "".join((self.Command.STOP, self.SIZE_ST.pack(len(payload)), payload))
                    self.__parent_socket.send(frame)
                except socket.error:
                    pass
                try:
                    os.waitpid(self.__pid, 0)
                except OSError as ex:
                    # No such child: treated the same as a terminated one
                    if ex.errno != errno.ECHILD:
                        raise
                self.__pid = None
            return True
        except:
            self.__logger.exception("Error occurred while stopping contender")
            return False

    def restart(self):
        """
        Stop the contender (killing the subprocess if it does not stop in time) and start it again.

        `stop()` is run in a separate greenlet and given `STOPPING_TIMEOUT` seconds to finish.
        `Greenlet.join(timeout)` returns silently when the timeout expires, so the timeout is
        detected by checking whether the greenlet has finished rather than by catching an exception.
        """
        stopper = gevent.spawn(self.stop)
        timed_out = False
        try:
            stopper.join(self.STOPPING_TIMEOUT)
        except gevent.Timeout:
            # Can only come from a timeout armed by the caller; handled like our own timeout
            timed_out = True
        if timed_out or not stopper.ready():
            self.__logger.warning(
                "Contender subprocess is not terminated in %ss", self.STOPPING_TIMEOUT
            )
            stopping_result = False
        else:
            stopping_result = stopper.value
        if not stopping_result and self.__do_fork and self.__pid:
            self.__logger.warning("Killing contender subprocess with PID %s", self.__pid)
            os.kill(self.__pid, signal.SIGKILL)
        self.start()

    def _request(self, command, *args):
        """
        Send `command` with `args` to the contender subprocess and return its decoded reply.

        The request is one command byte followed by length-prefixed msgpack arguments; the reply
        is a length-prefixed msgpack value. Stream sockets may transfer less data than asked for
        in a single call, so the request is written with `sendall()` and the reply is read in
        a loop until complete.

        :param command: one of `Command` codes
        :param args: command arguments, must be msgpack-serializable
        :return: value returned by the subprocess
        :raises Retry: if there is no subprocess, the reply is empty, or the connection is broken
            (in the latter case the subprocess is restarted first)
        """
        with self.__request_lock:
            if not self.__pid or not self.__parent_socket:
                raise self.Retry
            sock = self.__parent_socket

            def recv_exact(size):
                # Returns exactly `size` bytes or None if the peer closed the connection earlier
                chunks = []
                while size > 0:
                    chunk = sock.recv(size)
                    if not chunk:
                        return None
                    chunks.append(chunk)
                    size -= len(chunk)
                return b"".join(chunks)

            try:
                payload = msgpack.dumps(args)
                sock.sendall("".join((command, self.SIZE_ST.pack(len(payload)), payload)))
                header = recv_exact(self.SIZE_ST.size)
                if header is None:
                    self.restart()
                    raise self.Retry
                size = self.SIZE_ST.unpack(header)[0]
                if not size:
                    raise self.Retry
                body = recv_exact(size)
                if body is None:
                    # The subprocess died in the middle of the reply
                    self.restart()
                    raise self.Retry
                return msgpack.loads(body)
            except socket.error as ex:
                if ex.errno in (errno.EPIPE, errno.EBADF):
                    self.restart()
                    raise self.Retry
                raise

    @property
    def _contenders(self):
        """ List of contenders stored comma-separated in the contenders node; empty if unset or missing """
        try:
            raw = self.__zk.get(self.__contenders_path)[0]
        except kazoo.exceptions.NoNodeError:
            return []
        return raw.split(",") if raw else []

    @_contenders.setter
    def _contenders(self, value):
        """ Store the iterable of contender names `value` comma-separated in the contenders node """
        self.__logger.info("Updating contenders to %r", value)
        serialized = ",".join(value)
        self.__zk.set(self.__contenders_path, serialized)

    @property
    def contenders(self):
        """
        List of contenders, obtained from the subprocess in forking mode or directly otherwise.

        The value is validated to be a list of strings. The check is a disjunction: the value is
        rejected if it is not a list OR if any element is not a string. The element check accepts
        any `basestring`, the same way `primary()` validates its value.

        :raises Retry: if a malformed value is received (the contender is restarted first)
        """
        value = (
            self._request(self.Command.CONTENDERS)
            if self.__do_fork else
            self._contenders
        )
        if not isinstance(value, list) or not all(isinstance(_, basestring) for _ in value):
            self.__logger.critical("self._contenders returned wrong value: %r", value)
            self.restart()
            raise self.Retry
        return value

    @contenders.setter
    def contenders(self, value):
        """ Update the list of contenders, through the subprocess in forking mode or directly otherwise """
        if not self.__do_fork:
            # noinspection PyAttributeOutsideInit
            self._contenders = value
            return
        self._request(self.Command.CONTENDERS, value)

    def _primary_lock_watcher(self):
        """
        Thread body run while another host is the potential primary.

        Blocks trying to acquire the primary lock and releases it right away if acquired.
        Whatever the outcome, the `finally` clause terminates the process (in forking mode);
        in non-forking mode it resets the cached primary and closes the watchdog socket instead.
        """
        try:
            self.__logger.info("Trying to acquire primary lock...")
            if self._primary_lock.acquire():
                self.__logger.info("Primary lock acquired, releasing immediately...")
                self._primary_lock.release()
                self.__logger.info("Primary lock released")
            else:
                self.__logger.info("Cannot acquire primary lock, already locked")
        except (kazoo.exceptions.ConnectionLoss, kazoo.exceptions.CancelledError):
            # Connection problems and cancellation (see `_stop_primary_lock_thread`) are not reported
            pass
        except kazoo.exceptions.SessionExpiredError:
            # Drop the client so that it is created anew on next access
            del self.__zk
        finally:
            # In forking mode the process exits here and the lines below are never reached
            self._exit(1)
            self.__primary = None
            if self.__watchdog_socket:
                self.__watchdog_socket.close()

    def _primary_watchdog(self):
        """
        Thread body run while this contender considers itself primary.

        Once per second checks that the first contender of the primary lock is still this one.
        On a mismatch or on any kazoo error the cached primary is reset and the watchdog socket
        is closed, after which the thread ends.
        """
        def resign():
            self.__primary = None
            if self.__watchdog_socket:
                self.__watchdog_socket.close()

        self.__logger.info("Primary watchdog started")
        lock = self.__zk.Lock(self.__lock_path, self.__name)
        while self.__primary == self.__name:
            try:
                holder = next(iter(lock.contenders()), None)
            except kazoo.exceptions.KazooException as ex:
                self.__logger.error("Error while requesting lock contenders: %s", ex)
                resign()
                break
            if holder == self.__name:
                time.sleep(1)
                continue
            self.__logger.warning("Not a primary anymore, actual primary at host %s", holder)
            resign()
        self.__logger.info("Primary watchdog stopped")

    def _start_primary_lock_thread(self):
        """ Start the daemon thread running `_primary_lock_watcher`, unless one is already alive """
        if not self._primary_lock_thread_alive:
            watcher = th.Thread(target=self._primary_lock_watcher)
            self._primary_lock_thread = watcher
            watcher.daemon = True
            watcher.start()

    @property
    def _primary_lock_thread_alive(self):
        """ Truthy if the primary lock watcher thread exists and is alive (a falsy thread attribute is returned as is) """
        watcher = self._primary_lock_thread
        return watcher and watcher.is_alive()

    def _stop_primary_lock_thread(self):
        """ Cancel the pending lock acquisition and wait up to `WATCHDOG_JOIN_TIMEOUT` seconds for the watcher thread """
        self.__logger.info("Stopping primary lock watcher")
        lock = self._primary_lock
        lock.cancel()
        watcher = self._primary_lock_thread
        watcher.join(self.WATCHDOG_JOIN_TIMEOUT)

    def _start_primary_watchdog_thread(self):
        """ Start a daemon thread running `_primary_watchdog`; no reference to the thread is kept """
        watchdog = th.Thread(target=self._primary_watchdog)
        watchdog.daemon = True
        watchdog.start()

    def _primary(self, contend):
        """
        Determine the current primary, optionally contending for the primary lock.

        :param contend: `None` to return the cached primary without touching Zookeeper,
            a true value to try to acquire the primary lock, a false value to only look up
            the potential primary among the lock contenders
        :return: name of the primary (this contender's name if the lock is held by it),
            or `None` if it is unknown
        """
        # Nothing to do if only the cached value is requested or this contender is already primary
        if contend is None or self.__primary == self.__name:
            return self.__primary
        try:
            # Force initialization of the Zookeeper client (this also creates the lock object)
            type(self).__zk.__get__(self)
            self.__primary = None
            if contend:
                # The watcher thread may be blocked on the same lock object, stop it first
                if self._primary_lock_thread_alive:
                    self._stop_primary_lock_thread()
                if self._primary_lock.acquire(blocking=False):
                    # Double check that this contender is actually the first one in the lock queue
                    if next(iter(self._primary_lock.contenders()), None) == self.__name:
                        self.__logger.info("Primary lock acquired")
                        self.__primary = self.__name
                        self._start_primary_watchdog_thread()
                    else:
                        self.__logger.warning("Primary lock not acquired")
            if not self.__primary:
                self.__logger.info("Getting potential primary from lock contenders")
                # Poll the lock contenders until there is one or the timeout expires
                self.__primary = common_itertools.progressive_waiter(
                    0, self.__max_interval, self.__timeout,
                    lambda: next(iter(self._primary_lock.contenders()), None),
                    sleep_first=False
                )[0]
                # This contender being first without the lock confirmed above is not reported as primary
                if self.__primary == self.__name:
                    self.__primary = None
                if self.__primary:
                    self.__logger.info("Got potential primary: %s", self.__primary)
                    self._start_primary_lock_thread()
                else:
                    self.__logger.warning("Got no potential primary")
        except (kazoo.exceptions.ConnectionLoss, kazoo.exceptions.CancelledError):
            pass
        except kazoo.exceptions.SessionExpiredError:
            # Drop the client so that it is created anew on next access
            del self.__zk
        return self.__primary

    def primary(self, contend=None):
        """
        Get the current primary, through the subprocess in forking mode or directly otherwise.

        :param contend: passed as is to `_primary`
        :return: name of the primary or `None`
        :raises Retry: if a value of an unexpected type is received (the contender is restarted first)
        """
        if self.__do_fork:
            value = self._request(self.Command.PRIMARY, contend)
        else:
            value = self._primary(contend)
        if value is None or isinstance(value, basestring):
            return value
        self.__logger.critical("self._primary(%s) returned wrong value: %r", contend, value)
        self.restart()
        raise self.Retry
