import logging

import abc
import attr
import enum
import gevent.event
import gevent.lock
import gevent.queue
import six
from boltons import typeutils
from typing import Dict, Callable, final, Optional, Tuple

from awacs.lib import context, zookeeper_client, ctlmanager
from awacs.lib.models.classes import ModelObject
from awacs.model.cache import AwacsCache
from infra.swatlib import metrics
from infra.swatlib.gevent import greenthread, exclusiveservice2
from infra.swatlib.logutil import rndstr
from infra.swatlib.zk.client import LossExceptions


@enum.unique
class ControlFlowEvent(enum.Enum):
    """
    Internal values for controllers' main loop flow.

    Implementations of get_event() return one of these members instead of (or in addition to) real payloads.
    Every member wraps its own sentinel object, and callers in this module compare members by identity (`is`).
    """
    STOP = typeutils.make_sentinel(b'STOP')  # ctl needs to shut down
    EMPTY = typeutils.make_sentinel(b'EMPTY')  # event queue is empty, polling was terminated by timeout
    READY = typeutils.make_sentinel(b'READY')  # got something from the queue


@attr.s(slots=True, weakref_slot=True, cmp=False)
class CtlGreenThread(six.with_metaclass(abc.ABCMeta, greenthread.GreenThread)):
    """
    A low-level engine, contains common behavior for ModelCtlManager and ModelCtl (see below).
    The main purpose is to provide a non-blocking interface for subscription-based processing.

    CtlGreenThread runs an infinite loop in self._run_main_loop(): it blocks on self.get_event(),
    and then calls self.process_event() on its return value. Events are provided by subscribing self.event_callback()
    to cache events. Main loop is restarted (after a jittered pause) if an exception occurs during the iteration.

    How to use in a subclass:
    1) Make a communication channel on init (event, queue, etc.)
    2) Implement a non-blocking self.event_callback(event) that sends events into channel[1]
    3) Configure subscriptions for cache events in self.start_event_subscriptions(), using self.event_callback[2]
    4) Implement self.get_event() that uses channel[1] to get events.
    5) Set self.name before calling this class's __attrs_post_init__() - the logger is created from it.
    """
    # __init__
    cache = attr.ib(type=AwacsCache, kw_only=True)

    # defaults: base pause (and its jitter) before the main loop is restarted after an exception
    sleep_after_exception_timeout = attr.ib(type=int, default=20, kw_only=True)
    sleep_after_exception_timeout_jitter = attr.ib(type=int, default=10, kw_only=True)

    # computed
    name = attr.ib(type=six.text_type, init=False, default=None)
    _log = attr.ib(type=logging.Logger, init=False, default=None)
    _started = attr.ib(type=gevent.event.Event, init=False, default=None)
    _stopped = attr.ib(type=gevent.event.Event, init=False, default=None)

    def __attrs_pre_init__(self):
        # otherwise, greenthread fails to start
        super(CtlGreenThread, self).__init__()

    def __attrs_post_init__(self):
        self._log = logging.getLogger(self.name)
        self._started = gevent.event.Event()
        self._stopped = gevent.event.Event()

    @abc.abstractmethod
    def start_event_subscriptions(self, ctx):
        """
        Runs at the start of the main loop, should use self.event_callback() to subscribe to cache events
        """
        raise NotImplementedError

    @abc.abstractmethod
    def stop_event_subscriptions(self, ctx):
        """
        Runs when the main loop raises an exception, or when CtlGreenThread receives a STOP event
        """
        raise NotImplementedError

    @abc.abstractmethod
    def event_callback(self, full_uid, model, pb=None):
        """
        Called with every event received from subscriptions, must be non-blocking
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_event(self):
        """
        A blocking method that's used by the main loop to wait for anything to process.
        :returns a protobuf to process, or a ControlFlowEvent
        """
        raise NotImplementedError

    @abc.abstractmethod
    def process_event(self, ctx, event):
        """
        Receives data from self.get_event(), main business logic should be handled here.
        """
        raise NotImplementedError

    @property
    def started(self):
        """True once subscriptions are set up by the main loop; reset by stop() and at every loop (re)start."""
        return self._started.is_set()

    @property
    def stopped(self):
        """True after stop() was requested; reset at every main loop (re)start."""
        return self._stopped.is_set()

    @final
    def _run_main_loop(self):
        """
        One "life" of the controller: subscribe, then pump events until get_event() returns ControlFlowEvent.STOP.
        Subscriptions are always torn down on exit, whether the loop ended normally or with an exception.
        """
        self._stopped.clear()
        self._started.clear()

        root_ctx = context.BackgroundCtx()
        # only the cancel function is needed here: every iteration below builds its own OpCtx
        _, cancel = root_ctx.with_cancel()

        def cancel_on_stop(_event):
            cancel(u'ctl stopped')

        self._stopped.rawlink(cancel_on_stop)

        try:
            self._log.info(u'starting %s...', self.name)
            self.start_event_subscriptions(context.OpCtx(log=self._log, op_id=rndstr()))
            self._started.set()
            self._log.info(u'started %s', self.name)
            while 1:
                # get_event() blocks until it has something interesting to process (or needs to stop)
                event = self.get_event()
                if event is ControlFlowEvent.STOP:
                    break
                ctx = context.OpCtx(log=self._log, op_id=rndstr())
                self.process_event(ctx, event)
        finally:
            try:
                self.stop_event_subscriptions(context.OpCtx(log=self._log, op_id=rndstr()))
            finally:
                # self._stopped outlives this call (run() restarts the loop after exceptions), so the callback
                # must be removed here; otherwise one more callback stays linked after every restart
                self._stopped.unlink(cancel_on_stop)

    # gevent.GreenThread method override
    @final
    def run(self):
        """
        Keeps the main loop alive: restarts it after any exception, pausing for a jittered timeout in between.
        Returns when the main loop exits normally (STOP event) or when stop() is requested during the pause.
        """
        while 1:
            try:
                self._run_main_loop()
            except LossExceptions:
                timeout = self._get_sleep_after_exception_timeout()
                self._log.debug(u'Lost connection to zk, sleeping for %s seconds', timeout)
            except (Exception, gevent.Timeout):
                timeout = self._get_sleep_after_exception_timeout()
                self._log.exception(u'Unexpected exception while running, sleeping for %s seconds...', timeout)
            else:
                break
            # Pause before the restart, but wake up as soon as stop() is requested: restarting would clear
            # self._stopped and bring the controller back up even though it was asked to shut down.
            if self._stopped.wait(timeout):
                break

    # gevent.GreenThread method override
    @final
    def stop(self, timeout=10, ignore_greenlet_exit=True, log=None, kill_timeout=None):
        """
        Asks the main loop to finish, waits up to `timeout` seconds for it, then delegates to GreenThread.stop().

        :param timeout: how long to wait for the greenthread to finish on its own
        :param ignore_greenlet_exit: passed through to GreenThread.stop()
        :param log: passed through to GreenThread.stop()
        :param kill_timeout: passed through to GreenThread.stop()
        """
        self._started.clear()
        self._stopped.set()
        try:
            if self.is_running():
                self.wait(timeout=timeout)
        except gevent.Timeout:
            # didn't finish in time, the base class' stop() below takes over
            pass
        finally:
            super(CtlGreenThread, self).stop(ignore_greenlet_exit=ignore_greenlet_exit,
                                             log=log,
                                             kill_timeout=kill_timeout)

    @final
    def _get_sleep_after_exception_timeout(self):
        """Returns the pause before the next main loop restart, computed by ctlmanager.get_jittered_value()."""
        return ctlmanager.get_jittered_value(self.sleep_after_exception_timeout,
                                             self.sleep_after_exception_timeout_jitter)


@attr.s(slots=True, weakref_slot=True, cmp=False)
class ModelCtlManager(six.with_metaclass(abc.ABCMeta, CtlGreenThread)):
    """
    This is used to get events for ALL objects of a certain class (for example, L7Balancer, or a Namespace),
    and then start/stop/restart individual controllers for each of them.
    Subclass must provide a method to create controllers, and conditions for when to start/stop them.

    All cache events are stored in a queue for eventual processing. Additionally, self.start_event_subscriptions()
    does a full cache scan to start controllers during the manager's start.

    Controllers are wrapped into ExclusiveService, which uses zk locks to ensure that only 1 controller among any
    instances is able to process events.
    """
    # __init__
    zk_client = attr.ib(type=zookeeper_client.ZookeeperClient, kw_only=True)
    # optional predicate on namespace id; events from namespaces it rejects are dropped in event_callback()
    allowed_namespace_id_matcher = attr.ib(type=Optional[Callable[[six.text_type], bool]], kw_only=True)

    # must be defined in subclasses
    model = attr.ib(type=ModelObject, init=False)

    # internal
    _queue = attr.ib(type=gevent.queue.Queue, init=False, default=None)  # pending (full_uid, model, pb) events
    _ctls = attr.ib(type=dict, init=False, default=None)  # full_uid -> started ExclusiveService
    _ctls_lock = attr.ib(type=gevent.lock.RLock, init=False, default=None)  # guards self._ctls

    def __attrs_post_init__(self):
        # name must be set before the parent's post-init, which creates the logger from it
        self.name = u'{}-ctl-manager'.format(self.model.desc.slugified_name)
        self._queue = gevent.queue.Queue()
        self._ctls = {}
        self._ctls_lock = gevent.lock.RLock()
        super(ModelCtlManager, self).__attrs_post_init__()

    @abc.abstractmethod
    def create_ctl(self, full_uid):
        """
        :type full_uid: (six.text_type, six.text_type)
        :rtype: ModelCtl
        """
        raise NotImplementedError

    @abc.abstractmethod
    def should_ctl_be_running(self, full_uid, model, pb=None):
        """
        :type full_uid: (six.text_type, six.text_type)
        :type model: objects.Model
        :type pb: Any
        :rtype bool
        """
        raise NotImplementedError

    @abc.abstractmethod
    def should_restart_ctl(self, full_uid, model, pb=None):
        """
        :type full_uid: (six.text_type, six.text_type)
        :type model: objects.Model
        :type pb: Any
        :rtype bool
        """
        raise NotImplementedError

    @final
    def event_callback(self, full_uid, model, pb=None):
        """
        Cache subscription callback: enqueues the event unless its namespace is rejected by the matcher.
        """
        if self.allowed_namespace_id_matcher is not None and not self.allowed_namespace_id_matcher(full_uid[0]):
            return
        self._queue.put((full_uid, model, pb))

    @final
    def process_event(self, ctx, event):
        """
        Applies a single event: stops the controller if it must not run (or must be restarted),
        then (re)starts it if it must run. ControlFlowEvent.EMPTY triggers a full cache scan instead.
        """
        if event is ControlFlowEvent.EMPTY:
            return self._process_full_cache(ctx)
        full_uid, model, pb = event
        ctl_should_be_running = self.should_ctl_be_running(full_uid, model, pb)
        if not ctl_should_be_running or self.should_restart_ctl(full_uid, model, pb):
            self._stop_ctl(ctx, full_uid)
        if ctl_should_be_running:
            self._start_ctl(ctx, full_uid)

    @final
    def start_event_subscriptions(self, ctx):
        """
        Starts controllers for everything that is already in cache, then subscribes to updates and removals.
        """
        ctx.log.info(u'Starting all ctl singletons...')
        self._process_full_cache(ctx)
        ctx.log.info(u'Started all ctl singletons...')

        self.cache.subscribe_to_updates(self.event_callback, (self.model.desc.zk_prefix,), self.model)
        self.cache.subscribe_to_removals(self.event_callback, (self.model.desc.zk_prefix,), self.model)

    @final
    def _process_full_cache(self, ctx):
        """
        Starts a controller for every cached object that should have one; never stops anything.
        """
        with self._ctls_lock:
            for pb in self.model.cache.list():
                full_uid = (pb.meta.namespace_id, pb.meta.id)
                if self.should_ctl_be_running(full_uid, self.model, pb):
                    self._start_ctl(ctx, full_uid)

    @final
    def stop_event_subscriptions(self, ctx):
        """
        Unsubscribes from cache events, stops every running controller and drops all queued events.
        """
        self.cache.unsubscribe_from_updates(self.event_callback, (self.model.desc.zk_prefix,), self.model)
        self.cache.unsubscribe_from_removals(self.event_callback, (self.model.desc.zk_prefix,), self.model)

        ctx.log.info(u'Stopping all ctl singletons...')
        with self._ctls_lock:
            # iterate over a snapshot of the keys: self._stop_ctl() deletes entries from the dict
            for ctl_id in list(self._ctls):
                self._stop_ctl(ctx, ctl_id)
            self._ctls.clear()
        # events queued so far are stale: the next start does a full cache scan anyway
        self._queue = gevent.queue.Queue()
        ctx.log.info(u'Stopped all ctl singletons')

    @final
    def _start_ctl(self, ctx, full_uid):
        """
        Creates, registers and starts a controller for full_uid; no-op if one is already registered.

        :type ctx: context.OpCtx
        :type full_uid: (six.text_type, six.text_type)
        """
        name = u'{}("{}:{}") ctl'.format(self.model.desc.slugified_name, *full_uid)
        with self._ctls_lock:
            if full_uid in self._ctls:
                return
            ctx.log.info(u'Starting singleton for %s...', name)
            ctl = self._create_ctl_singleton(full_uid)
            if ctl:
                self._ctls[full_uid] = ctl
        # started outside of the lock so that it isn't held while the service starts
        if ctl:
            ctl.start()
            ctx.log.info(u'Started singleton for %s', name)

    @final
    def _create_ctl_singleton(self, full_uid):
        """
        Wraps the controller made by self.create_ctl() into an ExclusiveService.
        Returns None if self.create_ctl() returned nothing.

        :type full_uid: (six.text_type, six.text_type)
        :rtype: exclusiveservice2.ExclusiveService | None
        """
        ctl = self.create_ctl(full_uid)
        if not ctl:
            return
        # Each strategy is an iterator; ExclusiveService gets a zero-argument callable returning its next value.
        # The builtin next() works on both py2 and py3, no need to pick between .next and .__next__ by hand.
        acquire_timeouts = ctlmanager.default_acquire_timeout_strategy()
        standoffs = ctlmanager.default_standoff_strategy()
        return exclusiveservice2.ExclusiveService(
            coord=self.zk_client,
            name=ctl.name,
            runnable=ctl,
            acquire_timeout_strategy=lambda: next(acquire_timeouts),
            standoff_strategy=lambda: next(standoffs),
            lock_release_timeout=ctlmanager.get_lock_release_timeout(),
            greenthread_aware_mode=True,
            metrics_registry=metrics.ROOT_REGISTRY,
            disable_per_exclusive_service_metrics=True,
        )

    @final
    def _stop_ctl(self, ctx, full_uid):
        """
        Unregisters and stops the controller for full_uid; no-op if there is none.

        :type ctx: context.OpCtx
        :type full_uid: (six.text_type, six.text_type)
        """
        name = u'{}("{}:{}") ctl'.format(self.model.desc.slugified_name, *full_uid)
        with self._ctls_lock:
            if full_uid not in self._ctls:
                return
            ctx.log.info(u'Stopping singleton for %s...', name)
            ctl = self._ctls[full_uid]
            del self._ctls[full_uid]
        # stopped outside of the lock so that it isn't held while the service stops
        ctl.stop()
        ctx.log.info(u'Stopped singleton for %s...', name)

    @final
    def _wait_event(self):
        """
        Body of the helper greenlet used by get_event(): returns once the queue has an item to take.
        """
        try:
            self._queue.peek()
        except gevent.queue.Empty:
            pass

    @final
    def get_event(self):
        """
        Blocks until either stop is requested or the queue has an event.

        :returns ControlFlowEvent.STOP, ControlFlowEvent.EMPTY, or a (full_uid, model, pb) tuple from the queue
        """
        # A queue can't be passed to gevent.wait() directly, so a helper greenlet peeks into it
        peek = gevent.spawn(self._wait_event)
        try:
            gevent.wait([self._stopped, peek], count=1)
        finally:
            # Never leave the helper blocked on the queue: this covers both the regular stop and the case
            # when this greenlet is killed (or fails) while waiting. No-op if the helper has already finished.
            peek.kill(block=False)
        if self._stopped.is_set():
            return ControlFlowEvent.STOP
        elif self._queue.empty():
            return ControlFlowEvent.EMPTY
        else:
            return self._queue.get()


@attr.s(slots=True, weakref_slot=True, frozen=True)
class Subscription(object):
    """
    Immutable per-model settings that tell a ModelCtl how to subscribe to cache events of that model
    and how quickly to react to them. Used as values of ModelCtl.subscriptions.
    """
    # if True, trigger processing right away. Otherwise, processing will be delayed for "ModelCtl.process_interval" sec
    immediate = attr.ib(type=bool, default=False)

    # if True, only receive events for this ModelCtl's exact uid. Otherwise, receive events for all model objects
    check_uid = attr.ib(type=bool, default=True)

    # if True, receive events that create/update records in cache
    watch_updates = attr.ib(type=bool, default=True)

    # if True, receive events that remove records from cache
    watch_removals = attr.ib(type=bool, default=False)


@attr.s(slots=True, weakref_slot=True, cmp=False)
class ModelCtl(six.with_metaclass(abc.ABCMeta, CtlGreenThread)):
    """
    A controller for an individual object, identified by UID.

    Algorithm:
    1) Controller subscribes to cache events according to "self.subscriptions";
    2) When a cache event is received and is accepted by self.should_process(), controller either sets a processing flag
       for [3] (if the event's subscription has immediate=True), or starts a timer for "self.process_interval" seconds
       that will trigger [3], or does nothing if a timer is already started.
    3) Controller runs an infinite loop, waiting for either a cache event, or for "self.force_process_interval" seconds
    4) When the processing flag in [3] is set, controller clears it and runs self.process() - note that it doesn't
       give it an event, instead self.process() must acquire all objects that it needs for processing.
    """
    # __init__
    full_uid = attr.ib(type=Tuple[six.text_type, six.text_type], kw_only=True)
    # only affects the controller's name here: adds an "-order" suffix
    order = attr.ib(type=bool, kw_only=True, default=False)
    # delay (seconds) between a non-immediate event and the processing it triggers
    process_interval = attr.ib(type=int, default=30, kw_only=True)

    # should be defined in subclasses, main object processed by this controller
    model = attr.ib(type=ModelObject, init=False)

    # should be defined in subclasses, all objects relevant for processing ("model" should be included here as well)
    subscriptions = attr.ib(type=Dict[ModelObject, Subscription], init=False)

    # defaults: processing is forced if nothing happened for this long (seconds, jittered)
    force_process_interval = attr.ib(type=int, default=60, kw_only=True)
    force_process_interval_jitter = attr.ib(type=int, default=10, kw_only=True)

    # internal
    _ready_to_process = attr.ib(type=gevent.event.Event, default=None, init=False)  # the "processing flag"
    _processing_waiter = attr.ib(type=gevent.Greenlet, default=None, init=False)  # pending delayed trigger, if any

    def __attrs_post_init__(self):
        # name must be set before the parent's post-init, which creates the logger from it
        self.name = u'{}{}-ctl("{}:{}")'.format(self.model.desc.slugified_name, '-order' if self.order else '', *self.full_uid)
        self._ready_to_process = gevent.event.Event()
        self._processing_waiter = None
        super(ModelCtl, self).__attrs_post_init__()

    @abc.abstractmethod
    def should_process(self, full_uid, model, pb):
        """
        Decides whether a cache event is relevant for this controller; called from self.event_callback().
        :rtype bool
        """
        raise NotImplementedError

    @abc.abstractmethod
    def process(self, ctx):
        """
        Main business logic; receives no event and must get everything it needs by itself.
        """
        raise NotImplementedError

    @final
    def get_event(self):
        """
        Blocks until stop is requested, the processing flag is set, or the forced processing interval expires.
        :returns ControlFlowEvent.STOP if stop was requested, ControlFlowEvent.READY otherwise (including timeout)
        """
        gevent.wait([self._stopped, self._ready_to_process], count=1, timeout=self._get_force_interval())
        if self.stopped:
            return ControlFlowEvent.STOP
        else:
            return ControlFlowEvent.READY

    @final
    def process_event(self, ctx, event):
        # event is always ControlFlowEvent.READY here, we don't really need it.
        # The flag is cleared before processing, so events that arrive during self.process() are not lost.
        self._ready_to_process.clear()
        self.process(ctx)

    @final
    def start_event_subscriptions(self, ctx):
        """
        Subscribes to cache events as described by self.subscriptions and requests an initial processing.
        """
        for model, s in six.iteritems(self.subscriptions):
            zk_path_chunks = [model.desc.zk_prefix]
            if s.check_uid:
                zk_path_chunks.extend(self.full_uid)
            if s.watch_updates:
                self.cache.subscribe_to_updates(self.event_callback, tuple(zk_path_chunks), model)
            if s.watch_removals:
                self.cache.subscribe_to_removals(self.event_callback, tuple(zk_path_chunks), model)
        self._ready_to_process.set()

    @final
    def stop_event_subscriptions(self, ctx):
        """
        Reverts everything done by self.start_event_subscriptions() and cancels the pending delayed trigger.
        """
        try:
            for model, s in six.iteritems(self.subscriptions):
                zk_path_chunks = [model.desc.zk_prefix]
                if s.check_uid:
                    zk_path_chunks.extend(self.full_uid)
                if s.watch_updates:
                    self.cache.unsubscribe_from_updates(self.event_callback, tuple(zk_path_chunks), model)
                if s.watch_removals:
                    self.cache.unsubscribe_from_removals(self.event_callback, tuple(zk_path_chunks), model)
        finally:
            # don't let the timer fire for a controller that is no longer subscribed to anything;
            # nothing is lost: the next start requests processing right away
            self._cancel_processing_waiter()

    @final
    def event_callback(self, full_uid, model, pb=None):
        """
        Cache subscription callback: requests processing either right away or after self.process_interval seconds,
        depending on the "immediate" setting of the event's model subscription. Never blocks.
        """
        if not self.should_process(full_uid, model, pb):
            return
        if self.subscriptions[model].immediate:
            # trigger immediate processing
            return self._finalize_processing_waiter()
        elif self._processing_waiter is None:
            # otherwise, delay processing to collect more updates
            self._processing_waiter = gevent.spawn_later(self.process_interval, self._finalize_processing_waiter)
        else:
            # timer is already set, just wait
            pass

    @final
    def _finalize_processing_waiter(self):
        """
        Sets the processing flag. Called either by the delayed trigger itself, or directly for immediate events -
        in the latter case a pending delayed trigger is cancelled, the upcoming processing covers its events too.
        """
        self._cancel_processing_waiter()
        self._ready_to_process.set()

    @final
    def _cancel_processing_waiter(self):
        """
        Forgets the delayed trigger and, unless called from within that very trigger, cancels it without blocking.
        """
        waiter, self._processing_waiter = self._processing_waiter, None
        # A trigger that is merely forgotten would still fire later: it would cause a redundant self.process()
        # and reset self._processing_waiter, possibly orphaning a newer timer.
        if waiter is not None and waiter is not gevent.getcurrent():
            waiter.kill(block=False)

    @final
    def _get_force_interval(self):
        """Returns the timeout for a single get_event() wait, computed by ctlmanager.get_jittered_value()."""
        return ctlmanager.get_jittered_value(self.force_process_interval, self.force_process_interval_jitter)
