# coding: utf-8
import os
import logging
import random
import time

import enum
import gevent
import gevent.lock
import gevent.queue
import gevent.event
import gevent.threading
from kazoo.exceptions import (LockTimeout,
                              KazooException,
                              ConnectionLossException,
                              OperationTimeoutException,
                              ConnectionClosedError,
                              SessionExpiredError)
from sepelib.util import retry
from sepelib.gevent import greenthread
from sepelib.util.exc.format import format_exc
from kazoo.client import KazooState

from infra.swatlib.logutil import rndstr
from .greenthread import GreenThread
from .util import force_kill_greenlet


# Unique sentinel objects put into the result channel by ExclusiveService._wait_disconnect();
# compared by identity (``is``), so they must never be copied or recreated.
_STANDOFF_TIMEOUT_EXCEEDED = object()
_DISCONNECTED = object()

# Default standoff timeout: 12 hours expressed in seconds.
_DEFAULT_STANDOFF_TIMEOUT = 12 * 60 * 60


# Kazoo exceptions treated as connection-level failures (lost or closed connection,
# expired session, operation timeout) when releasing the lock.
LossExceptions = (
    ConnectionLossException,
    OperationTimeoutException,
    ConnectionClosedError,
    SessionExpiredError,
)


class LoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that attaches an ``op_id`` extra to every record it emits."""

    def __init__(self, log, op_id):
        """
        :param log: underlying logger
        :type log: logging.Logger
        :param op_id: operation identifier stored as ``extra['op_id']``
        :type op_id: six.text_type
        """
        extra = {'op_id': op_id}
        logging.LoggerAdapter.__init__(self, log, extra)

    def warn(self, msg, *args, **kwargs):
        """Alias for warning() that bypasses the deprecated LoggerAdapter.warn()."""
        emit = self.warning
        return emit(msg, *args, **kwargs)


class SessionIdFilter(object):
    """Logging filter that appends the hex Zookeeper session id to ``record.name``."""

    def __init__(self, coord):
        """
        :type coord: infra.swatlib.zookeeper_client.ZookeeperClient
        """
        self.coord = coord

    def filter(self, record):
        """Suffix the record's name with ``:<session id in hex>``; never drops a record."""
        current_id = self.coord.client.client_id
        # no client id (not connected yet) or a falsy session id -> report 0
        session_id = 0
        if current_id and current_id[0]:
            session_id = current_id[0]
        record.name = record.name + ':' + hex(session_id)
        return True


class RestartPolicy(enum.Enum):
    """What ExclusiveService.do_run() does after the wrapped service finishes on its own."""

    # stop after the service returns or raises
    DO_NOT_RESTART = 1
    # restart only when the service raised an exception
    RESTART_ON_EXCEPTION = 2
    # restart whether the service returned or raised
    ALWAYS_RESTART = 3


class Runnable(object):
    """Interface of a service object that ExclusiveService can wrap via its ``run()`` method."""

    def run(self):
        """Run the service; subclasses must override this."""
        raise NotImplementedError()


def set_on_disconnect(state, event):
    """Connection-state listener: set ``event`` whenever ``state`` is anything but CONNECTED."""
    if state == KazooState.CONNECTED:
        return
    event.set()


def cancel_on_disconnect(state, lock):
    """Connection-state listener: cancel ``lock`` whenever ``state`` is anything but CONNECTED."""
    if state == KazooState.CONNECTED:
        return
    lock.cancel()


def random_standoff_strategy():
    """
    Return a random standoff timeout between half of and the full default standoff timeout (inclusive).

    Copy-pasted from nanny/lib/singleton_service.py.

    This is a tentative attempt to work around the fact that, after a new release,
    all singleton services end up running in a single instance because of our deployment procedure.
    """
    upper_bound = _DEFAULT_STANDOFF_TIMEOUT
    lower_bound = upper_bound // 2
    return random.randint(lower_bound, upper_bound)


class ExclusiveService(greenthread.GreenThread):
    """
    A service that runs under a Zookeeper lock.
    Designed to be used as a wrapper for an unaware service.

    The wrapped runnable is started only after this process acquires the lock at
    ``ZK_PATH/<name>``. It is stopped when the Zookeeper connection leaves the
    CONNECTED state or when the standoff timeout expires, and the lock is then
    released so that other contenders get a chance to take over.
    """
    # it's a good practice, mentioned in google chubby papers
    # to give up control once in a while
    DEFAULT_STANDOFF_STRATEGY = lambda: _DEFAULT_STANDOFF_TIMEOUT
    # None is passed as-is to lock.acquire(timeout=...)
    DEFAULT_ACQUIRE_TIMEOUT_STRATEGY = lambda: None
    DEFAULT_KILL_TIMEOUT = 1.0
    SLEEP_AFTER_UNEXPECTED_EXCEPTION_TIMEOUT = 0.5
    DEFAULT_LOCK_RELEASE_TIMEOUT = 30

    ZK_PATH = '/exclusive_services'

    def __init__(self, coord, name, runnable,
                 acquire_timeout_strategy=DEFAULT_ACQUIRE_TIMEOUT_STRATEGY,
                 standoff_strategy=DEFAULT_STANDOFF_STRATEGY,
                 restart_policy=RestartPolicy.DO_NOT_RESTART,
                 kill_timeout=DEFAULT_KILL_TIMEOUT,
                 lock_release_timeout=DEFAULT_LOCK_RELEASE_TIMEOUT,
                 disable_soft_lock_release_timeouts=False,
                 greenthread_aware_mode=False,
                 metrics_registry=None,
                 disable_per_exclusive_service_metrics=False):
        """
        :param coord: zookeeper client instance
        :type coord: infra.swatlib.zookeeper_client.ZookeeperClient
        :param str name: service name; used in the lock path (``ZK_PATH/<name>``),
            in the logger name and in the per-service metrics path
        :param runnable: the wrapped service: a callable, an object with ``run()``,
            or a GreenThread when ``greenthread_aware_mode`` is enabled
        :type runnable: Runnable or callable or GreenThread
        :param callable acquire_timeout_strategy:
            a callable that returns amount of seconds we wait before giving up on acquiring the lock,
            MUST NOT raise any exceptions.
        :param callable standoff_strategy:
            a callable that returns current standoff timeout,
            MUST NOT raise any exceptions.
        :param RestartPolicy restart_policy: RestartPolicy.*
        :param float kill_timeout:
            number of seconds between service kill attempts, passed as is to force_kill_greenlet function
        :param float lock_release_timeout: timeout (seconds) of a single lock release attempt
        :param bool disable_soft_lock_release_timeouts:
            if True, the process is killed after the first timed-out release attempt
            instead of retrying the release once
        :param bool greenthread_aware_mode:
            if True, ``runnable`` must be a GreenThread that is started/stopped directly
        :param metrics_registry: optional registry; if given, lock release timeout counters
            (and, unless disabled, the per-service ``acquired_locks`` gauge) are reported
        :param bool disable_per_exclusive_service_metrics:
            if True, the per-service ``acquired_locks`` gauge is not reported
        :raises RuntimeError: if ``greenthread_aware_mode`` is set and ``runnable`` is not a GreenThread
        """
        super(ExclusiveService, self).__init__()

        self._greenthread = None
        self._service = None
        self._run = None

        if greenthread_aware_mode:
            if isinstance(runnable, GreenThread):
                self._greenthread = runnable
            else:
                raise RuntimeError('greenthread_aware_mode only works with GreenThread runnables')
        else:
            if callable(runnable):
                self._run = runnable
            else:
                self._run = runnable.run

        self._coord = coord

        self.name = 'exclusive({})'.format(name)

        self._log, self._session_id_filter = self._setup_log()

        self._lock = self._coord.lock(os.path.join(self.ZK_PATH, name))
        self._lock.debug_mode = True
        self._lock.cleanup_acquire_on_greenlet_exit = True

        self._standoff_strategy = standoff_strategy
        self._acquire_timeout_strategy = acquire_timeout_strategy
        self._restart_policy = restart_policy
        self._kill_timeout = kill_timeout
        self._lock_release_timeout = lock_release_timeout
        self._disable_soft_lock_release_timeouts = disable_soft_lock_release_timeouts

        self._stopped = False
        # serializes "start service" in do_run() against "stop service" in stop()
        self._stopping_lock = gevent.threading.Lock()

        self._metrics_registry = None
        self._lock_release_soft_timeouts_counter = None
        self._lock_release_hard_timeouts_counter = None
        if metrics_registry is not None:
            # create rarely increased counters in __init__ to avoid n/a in YASM signals,
            # see https://st.yandex-team.ru/SWAT-7046#5ef0b17ab4b91b2222c268fe for details
            common_metrics_registry = metrics_registry.path('exclusive-services')
            self._lock_release_soft_timeouts_counter = common_metrics_registry.get_counter(
                'lock-release-soft-timeouts')
            self._lock_release_hard_timeouts_counter = common_metrics_registry.get_counter(
                'lock-release-hard-timeouts')

            if not disable_per_exclusive_service_metrics:
                self._metrics_registry = metrics_registry.path('exclusive-service', name)

    @property
    def controller(self):
        """The wrapped GreenThread in greenthread-aware mode, otherwise None."""
        return self._greenthread

    def _setup_log(self):
        """
        Return ``(logger, session_id_filter)`` for this service.

        Makes sure the logger has exactly one SessionIdFilter and that it reports
        the session of this service's coord.
        """
        log = logging.getLogger(self.name)

        # just in case we are reusing this logger from previously existed exclusive service...
        session_id_filter = None
        for filter_ in log.filters:
            if isinstance(filter_, SessionIdFilter):
                session_id_filter = filter_
                # the reused filter still references the previous service's client;
                # rebind it so the logged session id belongs to our client
                session_id_filter.coord = self._coord
                break

        if not session_id_filter:
            session_id_filter = SessionIdFilter(self._coord)
            log.addFilter(session_id_filter)

        return log, session_id_filter

    def _wait_disconnect(self, chan, standoff_timeout=None):
        """
        Block until the connection leaves CONNECTED or ``standoff_timeout`` elapses.

        Puts ``_DISCONNECTED`` or ``_STANDOFF_TIMEOUT_EXCEEDED`` into ``chan`` respectively.
        """
        event = gevent.event.Event()
        listener = lambda state: set_on_disconnect(state, event)
        self._coord.add_listener(listener)
        try:
            event.wait(timeout=standoff_timeout)
        finally:
            self._coord.remove_listener(listener)
        if event.is_set():
            chan.put(_DISCONNECTED)
        else:
            chan.put(_STANDOFF_TIMEOUT_EXCEEDED)

    def _wait_service(self, chan):
        """Wait for the running service to finish; put its result or raised Exception into ``chan``."""
        try:
            if self._greenthread:
                rv = self._greenthread.get()
            else:
                rv = self._service.get()
        except Exception as e:
            chan.put(e)
        else:
            chan.put(rv)

    def _wait_disconnect_or_service(self, standoff_timeout=None):
        """
        Race a disconnect/standoff timeout against service completion.

        :returns: ``_DISCONNECTED``, ``_STANDOFF_TIMEOUT_EXCEEDED``, the service's return
            value, or the Exception instance it raised -- whichever arrives first.
        """
        chan = gevent.queue.Queue()
        gs = (
            gevent.spawn(self._wait_disconnect, chan=chan, standoff_timeout=standoff_timeout),
            gevent.spawn(self._wait_service, chan=chan),
        )
        try:
            return chan.get()
        finally:
            # the losing waiter is still blocked; make sure neither outlives this call
            for g in gs:
                force_kill_greenlet(g, kill_timeout=self._kill_timeout)

    def _stop_service(self, log=None):
        """
        Stop the wrapped service if it is running.

        :param log: logger to use; defaults to this service's logger
        :returns: whether the service was running before the call
        :rtype: bool
        """
        if log is None:
            log = self._log
        log.info("stopping service...")
        if self._greenthread:
            log.info("stopping _greenthread with timeout of %s...", self._kill_timeout)
            was_running = self._greenthread.is_running()
            self._greenthread.stop(ignore_greenlet_exit=True, log=log, kill_timeout=self._kill_timeout)
        elif self._service:
            log.info("stopping _service with timeout of %s...", self._kill_timeout)
            was_running = not self._service.ready()
            force_kill_greenlet(self._service, kill_timeout=self._kill_timeout, log=log)
        else:
            was_running = False
        log.info("stopped service (was running: %s)", was_running)
        return was_running

    def run(self):
        """
        Greenlet entry point: call do_run() until it returns normally.

        Unexpected exceptions (including gevent.Timeout) are logged and do_run() is
        called again after a short pause.
        """
        after_unexpected_exception = False
        while True:
            try:
                self.do_run(after_unexpected_exception=after_unexpected_exception)
            except (Exception, gevent.Timeout):
                self._log.exception('unexpected exception in do_run()')
                after_unexpected_exception = True
                # cooperative sleep: time.sleep() would block the whole gevent hub
                # unless the time module is monkey-patched
                gevent.sleep(self.SLEEP_AFTER_UNEXPECTED_EXCEPTION_TIMEOUT)
            else:
                break

    def _set_acquired_locks_gauge(self, v):
        """Report the per-service ``acquired_locks`` gauge, if per-service metrics are enabled."""
        if self._metrics_registry:
            self._metrics_registry.get_summable_gauge('acquired_locks').set(v)

    def _inc_lock_release_soft_timeouts_counter(self):
        """Increment the shared soft lock-release timeout counter, if metrics are enabled."""
        if self._lock_release_soft_timeouts_counter is not None:
            self._lock_release_soft_timeouts_counter.inc(1)

    def _inc_lock_release_hard_timeouts_counter(self):
        """Increment the shared hard lock-release timeout counter, if metrics are enabled."""
        if self._lock_release_hard_timeouts_counter is not None:
            self._lock_release_hard_timeouts_counter.inc(1)

    def do_run(self, after_unexpected_exception=False):
        """
        Acquire the lock, run the service while holding it, release the lock; repeat
        until stopped or until the restart policy says to finish.

        :param bool after_unexpected_exception: True if the previous do_run() call raised;
            the lock may then still be held from that call.
        """
        op_id = rndstr()
        log = LoggerAdapter(log=self._log, op_id=op_id)

        sleeper = retry.RetrySleeper(max_delay=5)

        self._stopped = False

        correctly_released_lock = True
        # a disconnect while we are waiting in acquire() cancels the acquisition
        listener = lambda state: cancel_on_disconnect(state, self._lock)
        while not self._stopped:
            acquire_timeout = self._acquire_timeout_strategy()

            log.info('acquiring lock with timeout of %s seconds...', acquire_timeout)
            if self._lock.is_acquired:
                assert after_unexpected_exception or not correctly_released_lock
                log.warn('lock has not been correctly released previously, trying again...')
                # If we call acquire() on a lock that "considers" itself already acquired,
                # it will hang indefinitely waiting until its is_acquired becomes False,
                # which will never happen in our case.
                # So we make sure is_acquired is false before continuing.
                # correctly_released_lock is local to this call and starts as True, so after an
                # unexpected exception in a previous do_run() it says nothing about the real lock
                # state; reset it to force at least one release attempt.
                correctly_released_lock = False
                release_sleeper = retry.RetrySleeper(max_delay=5)
                while not correctly_released_lock:
                    correctly_released_lock = self._release_lock(log=log)
                    if not correctly_released_lock:
                        # back off only between failed attempts
                        release_sleeper.increment()
                log.info('lock is correctly released')

            self._coord.add_listener(listener)
            try:
                is_acquired = self._lock.acquire(timeout=acquire_timeout)
            except LockTimeout:
                log.info('failed to acquire lock within timeout of {} seconds'.format(acquire_timeout))
            except KazooException as e:
                log.info(format_exc('failed to acquire lock', e))
                sleeper.increment()
            else:
                self._set_acquired_locks_gauge(int(is_acquired))
                if not is_acquired:
                    log.warn('lock.acquire() returned false')
                    sleeper.increment()
                else:
                    # we locked, yey!
                    # now start service and wait for session state change
                    log.info('lock acquired: %s', self._lock.node)
                    sleeper.reset()
                    standoff_timeout = self._standoff_strategy()
                    log.info("became singleton - "
                             "starting service with standoff timeout of {} seconds...".format(standoff_timeout))
                    with self._stopping_lock:
                        # stop() may have run while we were acquiring; don't start the service then
                        if self._stopped:
                            log.info('acquired lock for stopped service')
                            self._release_lock(log=log)
                            return
                        if self._greenthread:
                            self._greenthread.start()
                        else:
                            self._service = gevent.Greenlet(self._run)
                            self._service.start()
                    res = self._wait_disconnect_or_service(standoff_timeout=standoff_timeout)
                    if res is _DISCONNECTED:
                        log.info('disconnected - stopping service')
                        self._stop_service(log=log)
                        correctly_released_lock = self._release_lock(log=log)
                        log.info('service stopped')
                    elif res is _STANDOFF_TIMEOUT_EXCEEDED:
                        log.info("was leading for too long (more than {} seconds), "
                                 "stand off - stopping service".format(standoff_timeout))
                        # First we stop service, then release lock
                        # Thus we encounter less races
                        self._stop_service(log=log)
                        correctly_released_lock = self._release_lock(log=log)
                        log.info("service stopped")
                    elif isinstance(res, Exception):
                        log.info("service raised an exception")
                        correctly_released_lock = self._release_lock(log=log)
                        if self._restart_policy not in (RestartPolicy.ALWAYS_RESTART,
                                                        RestartPolicy.RESTART_ON_EXCEPTION):
                            break
                        log.info("restarting according to the restart policy...")
                    else:
                        log.info("service returned {!r}".format(res))
                        correctly_released_lock = self._release_lock(log=log)
                        if self._restart_policy != RestartPolicy.ALWAYS_RESTART:
                            break
                        log.info("restarting according to the restart policy...")

                    # sleep for sometime to let someone take leadership
                    gevent.sleep(2)
            finally:
                self._coord.remove_listener(listener)
            after_unexpected_exception = False

    @staticmethod
    def _handle_failed_lock_release(log):
        """
        Kill the current process with SIGKILL after a 10-second grace period.

        Called when releasing the lock timed out for good. Presumably the intent is that
        the process's Zookeeper session dies with it and the lock goes away -- confirm.
        """
        # stdlib signal: ``gevent.signal`` is not a module exposing SIGKILL in every gevent version
        import signal

        log.warn("committing suicide with SIGKILL in 10 seconds")
        # sleep to give yasm metrics and sentry events some time to be "sent"
        gevent.sleep(10)
        os.kill(os.getpid(), signal.SIGKILL)

    def _attempt_lock_release(self, log, timeout):
        """
        Make a single lock release attempt.

        :returns: ``(timed_out, is_released)``; ``timed_out`` is True when the attempt
            raised gevent.Timeout (``is_released`` is then False). Connection-loss
            errors yield ``(False, False)``.
        :raises gevent.GreenletExit: propagated after being logged
        """
        try:
            is_released = self._lock.release(timeout=timeout, log=log)
        except gevent.GreenletExit:
            log.warn("releasing lock failed with GreenletExit")
            raise
        except gevent.Timeout:
            return True, False
        except LossExceptions as e:
            log.info("releasing lock failed with %s", e.__class__.__name__)
            return False, False
        log.info("is_released: %s", is_released)
        self._set_acquired_locks_gauge(0 if is_released else 1)
        return False, is_released

    def _release_lock(self, log):
        """
        Release the lock, retrying once after a timeout unless soft timeouts are disabled.

        If the final allowed attempt times out, the process is killed via
        _handle_failed_lock_release().

        :returns: the lock's release result on success, False on connection loss or timeout
        :raises gevent.GreenletExit: if interrupted while releasing
        """
        log.info("releasing lock: %s", self._lock.node)

        timeout = self._lock_release_timeout
        timed_out, is_released = self._attempt_lock_release(log, timeout)
        if not timed_out:
            return is_released

        if self._disable_soft_lock_release_timeouts:
            log.warn("first attempt to release lock exceeded timeout of %d seconds, "
                     "not trying again (soft timeouts are disabled)...", timeout)
            self._inc_lock_release_hard_timeouts_counter()
            self._handle_failed_lock_release(log)
            return False

        self._inc_lock_release_soft_timeouts_counter()
        log.warn("first attempt to release lock exceeded timeout of %d seconds, trying again...", timeout)
        timed_out, is_released = self._attempt_lock_release(log, timeout)
        if not timed_out:
            return is_released

        self._inc_lock_release_hard_timeouts_counter()
        log.warn("second attempt to release lock exceeded timeout of %d seconds", timeout)
        self._handle_failed_lock_release(log)
        return False

    def stop(self):
        """
        Stop the exclusive service: cancel a pending lock acquisition, stop the
        wrapped service, release the lock and stop this greenthread.
        """
        log = LoggerAdapter(log=self._log, op_id=rndstr())

        log.info("stopping exclusive service...")
        self._stopped = True

        log.info("cancelling lock...")
        if not self._lock.is_acquired:
            # do_run() may be blocked in acquire(); cancel it
            cancelled = gevent.event.Event()
            cancel_rv = self._lock.cancel(ev=cancelled)
            timeout = 10
            if cancelled.wait(timeout=timeout):
                log.info("successfully cancelled, is_acquired: %s, cancel_rv: %s",
                         self._lock.is_acquired, cancel_rv)
            else:
                log.warn("failed to cancel lock in %d seconds, is_acquired: %s, cancel_rv: %s!",
                         timeout, self._lock.is_acquired, cancel_rv)
        with self._stopping_lock:
            was_running = self._stop_service(log=log)
        is_released = self._release_lock(log=log)
        # a running service implies we held the lock, so both flags are expected to agree
        if bool(is_released) != bool(was_running):
            log.warn("is_released: %s, was_running: %s, how come?", is_released, was_running)
        super(ExclusiveService, self).stop()
        log.info("stopped exclusive service")


class SingletonObserver(object):
    """
    Read-only helper to inspect exclusive (singleton) services from outside.
    """

    def __init__(self, zookeeper_client):
        self._zookeeper_client = zookeeper_client

    def list_services(self):
        """
        Return the names of the services that use the exclusive-service lock mechanism.
        """
        services_root = ExclusiveService.ZK_PATH
        return self._zookeeper_client.get_children(services_root)

    def list_service_instances(self, service):
        """
        Return the ordered list of instances contending for the lock of ``service``.
        """
        lock_path = os.path.join(ExclusiveService.ZK_PATH, service)
        contended_lock = self._zookeeper_client.lock(lock_path)
        return contended_lock.contenders()
