# coding: utf-8
import logging
import random

import abc
import gevent.event
import gevent.lock
import gevent.queue
import os
import six
import uhashring
from sepelib.core import config as appconfig

from awacs.lib import context
from awacs.model.storage_modern import PARTIES_NODE_ZK_PREFIX
from awacs.model.events import PartyMemberUpdate, PartyMemberRemove, TreeEvent
from infra.swatlib import metrics
from infra.swatlib.gevent import greenthread, exclusiveservice2
from infra.swatlib.gevent.singletonparty import SingletonParty
from infra.swatlib.logutil import rndstr
from infra.swatlib.zk import treecache
from infra.swatlib.zk.client import LossExceptions


# Exceptions treated as "unexpected" by the manager run loops. gevent.Timeout is listed
# explicitly because it is not caught by `except Exception` (it derives from BaseException).
UNEXPECTED_EXCEPTIONS = (Exception, gevent.Timeout)

# Durations, in seconds.
MINUTE = 60
HOUR = 60 * MINUTE
HALF_A_MINUTE = MINUTE // 2
ONE_MINUTE = MINUTE
FIVE_MINUTES = 5 * MINUTE
TEN_MINUTES = 10 * MINUTE
TWELVE_HOURS = 12 * HOUR

# Sentinel returned by the managers' _get_event() when the manager has been asked to stop.
STOPPED = object()
# Sentinel object; not referenced in this part of the file (presumably used elsewhere).
EMPTY = object()


def default_standoff_strategy():
    """
    Generate the standoff delays (seconds) handed to ExclusiveService as ``standoff_strategy``.

    The first value is ``random.randint(FIVE_MINUTES, 2 * TEN_MINUTES)`` with probability 0.4
    and ``TWELVE_HOURS`` otherwise; every subsequent value is ``TWELVE_HOURS``.
    """
    roll = random.random()
    yield random.randint(FIVE_MINUTES, 2 * TEN_MINUTES) if roll > .6 else TWELVE_HOURS
    while True:
        yield TWELVE_HOURS


def default_acquire_timeout_strategy():
    """
    Endlessly generate acquire timeouts (seconds) handed to ExclusiveService.

    Each value is drawn independently:
      * p=0.1 -> HALF_A_MINUTE
      * p=0.2 -> uniform int in [ONE_MINUTE, FIVE_MINUTES]
      * p=0.2 -> uniform int in [FIVE_MINUTES, TEN_MINUTES]
      * p=0.2 -> uniform int in [TEN_MINUTES, 3 * TEN_MINUTES]
      * p=0.3 -> TWELVE_HOURS
    """
    while True:
        roll = random.random()
        if roll < 0.1:
            timeout = HALF_A_MINUTE
        elif roll < 0.3:
            timeout = random.randint(ONE_MINUTE, FIVE_MINUTES)
        elif roll < 0.5:
            timeout = random.randint(FIVE_MINUTES, TEN_MINUTES)
        elif roll < 0.7:
            timeout = random.randint(TEN_MINUTES, 3 * TEN_MINUTES)
        else:
            timeout = TWELVE_HOURS
        yield timeout


def get_lock_release_timeout():
    """
    Return the configured ``run.exclusive_service_lock_release_timeout``.

    Falls back to ``ExclusiveService.DEFAULT_LOCK_RELEASE_TIMEOUT`` when the option is not set.
    """
    fallback = exclusiveservice2.ExclusiveService.DEFAULT_LOCK_RELEASE_TIMEOUT
    return appconfig.get_value('run.exclusive_service_lock_release_timeout', default=fallback)


def get_lock_disable_soft_timeouts():
    """Return the configured ``run.exclusive_service_disable_soft_timeouts`` flag (False if unset)."""
    option_name = 'run.exclusive_service_disable_soft_timeouts'
    return appconfig.get_value(option_name, default=False)


def get_jittered_value(value, jitter):
    """
    Return ``value`` shifted by a uniformly random integer in ``[-jitter, jitter]``.

    :param value: base value
    :param jitter: maximum absolute offset; passed to ``random.randint`` so it must be an int >= 0
    """
    offset = random.randint(-jitter, jitter)
    return value + offset


class CtlManager(greenthread.GreenThread):
    """
    Keeps one ExclusiveService-wrapped singleton running per ctl id, driven by cache events.

    On start the manager binds ``_callback`` to the cache and starts a singleton for every id
    returned by ``_list_all_assigned_ctl_ids``. Afterwards, events of ``starting_events`` types
    start a singleton and events of ``stopping_events`` types stop one. Subclasses implement id
    listing, event parsing, human-readable naming and ctl construction.
    """
    SLEEP_AFTER_EXCEPTION_TIMEOUT = 20  # seconds
    SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER = 10

    def __init__(self, coord, cache, name, starting_events, stopping_events):
        """
        :type coord: awacs.lib.zookeeper_client.ZookeeperClient
        :type cache: awacs.model.cache.AwacsCache
        :type name: six.text_type
        :type starting_events: tuple[type, ...]
        :type stopping_events: tuple[type, ...]
        """
        super(CtlManager, self).__init__()
        self.name = name
        self._log = logging.getLogger(name)
        self._coord = coord
        self._cache = cache
        self._starting_events = starting_events
        self._stopping_events = stopping_events
        # Events put by the cache callback and consumed by the run loop.
        self._events_queue = gevent.queue.Queue()
        # ctl_id -> running exclusiveservice2.ExclusiveService
        self._ctls = {}
        # Set by stop() to make the run loop exit.
        self._stopped = gevent.event.Event()
        # Reentrant because _start/_stop hold it while calling _start_ctl/_stop_ctl, which take it too.
        self._ctl_start_stop_lock = gevent.lock.RLock()

    def _get_sleep_after_exception_timeout(self):
        """Return the jittered back-off delay (seconds) used after a failed run."""
        base = self.SLEEP_AFTER_EXCEPTION_TIMEOUT
        spread = self.SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER
        return get_jittered_value(base, spread)

    def _get_standoff_strategy(self, ctl_id):
        """Return a zero-argument callable yielding successive standoff delays."""
        strategy = default_standoff_strategy()
        return strategy.__next__ if six.PY3 else strategy.next

    def _get_acquire_timeout_strategy(self, ctl_id):
        """Return a zero-argument callable yielding successive acquire timeouts."""
        strategy = default_acquire_timeout_strategy()
        return strategy.__next__ if six.PY3 else strategy.next

    def _get_ctl_human_readable_name(self, ctl_id):
        """Return a name for ``ctl_id`` used in log messages. Must be overridden."""
        raise NotImplementedError

    def _list_all_assigned_ctl_ids(self):
        """Return the ctl ids this manager should run; by default every id."""
        return self._list_all_ctl_ids()

    def _list_all_ctl_ids(self):
        """Return all known ctl ids. Must be overridden."""
        raise NotImplementedError

    def _create_ctl(self, ctl_id):
        """Build the runnable ctl for ``ctl_id`` (a falsy result means "nothing to run"). Must be overridden."""
        raise NotImplementedError

    def _get_ctl_id_from_event(self, event):
        """Extract the ctl id an event refers to. Must be overridden."""
        raise NotImplementedError

    def _create_ctl_singleton(self, ctl_id):
        """
        Wrap the ctl for ``ctl_id`` into an ExclusiveService.

        :returns: the ExclusiveService, or None when ``_create_ctl`` returned a falsy value
        """
        ctl = self._create_ctl(ctl_id)
        if not ctl:
            return None
        acquire_timeout_strategy = self._get_acquire_timeout_strategy(ctl_id)
        standoff_strategy = self._get_standoff_strategy(ctl_id)
        return exclusiveservice2.ExclusiveService(
            coord=self._coord,
            name=ctl.name,
            runnable=ctl,
            acquire_timeout_strategy=acquire_timeout_strategy,
            standoff_strategy=standoff_strategy,
            lock_release_timeout=get_lock_release_timeout(),
            disable_soft_lock_release_timeouts=get_lock_disable_soft_timeouts(),
            greenthread_aware_mode=True,
            metrics_registry=metrics.ROOT_REGISTRY,
            disable_per_exclusive_service_metrics=True,
        )

    def _start_ctl(self, ctl_id):
        """Start a singleton for ``ctl_id`` unless one is already registered."""
        with self._ctl_start_stop_lock:
            name = self._get_ctl_human_readable_name(ctl_id)
            if ctl_id in self._ctls:
                self._log.info('Singleton for {} is already running'.format(name))
                return
            self._log.info('Starting singleton for {}...'.format(name))
            singleton = self._create_ctl_singleton(ctl_id)
            if not singleton:
                return
            self._ctls[ctl_id] = singleton
            singleton.start()
            self._log.info('Started singleton for {}'.format(name))

    def _stop_ctl(self, ctl_id):
        """Stop and unregister the singleton for ``ctl_id`` if one is registered."""
        with self._ctl_start_stop_lock:
            name = self._get_ctl_human_readable_name(ctl_id)
            if ctl_id not in self._ctls:
                self._log.info('Singleton for {} is not running'.format(name))
                return
            self._log.info('Stopping singleton for {}...'.format(name))
            # Unregister only after stop() returns: if it raises, the entry stays for a later retry.
            self._ctls[ctl_id].stop()
            del self._ctls[ctl_id]
            self._log.info('Stopped singleton for {}...'.format(name))

    def _start(self):
        """Subscribe to cache events and start singletons for all assigned ids."""
        self._log.info('Starting all ctl singletons...')
        self._cache.bind(self._callback)
        with self._ctl_start_stop_lock:
            for assigned_id in self._list_all_assigned_ctl_ids():
                self._start_ctl(assigned_id)
        self._log.info('Started all ctl singletons...')

    def _stop(self):
        """Unsubscribe from the cache, stop every running singleton and drop pending events."""
        self._cache.unbind(self._callback)
        self._log.info('Stopping all ctl singletons...')
        with self._ctl_start_stop_lock:
            # Snapshot the keys: _stop_ctl deletes entries from self._ctls.
            running_ids = list(self._ctls)
            for running_id in running_ids:
                self._stop_ctl(running_id)
            self._ctls.clear()
        self._events_queue = gevent.queue.Queue()
        self._log.info('Stopped all ctl singletons')

    def _is_starting_event(self, event):
        """Tell whether ``event`` is one of the configured starting event types."""
        return isinstance(event, self._starting_events)

    def _is_stopping_event(self, event):
        """Tell whether ``event`` is one of the configured stopping event types."""
        return isinstance(event, self._stopping_events)

    def _process_event(self, event):
        """Start or stop the singleton the event refers to."""
        ctl_id = self._get_ctl_id_from_event(event)
        if self._is_starting_event(event):
            self._start_ctl(ctl_id)
        elif self._is_stopping_event(event):
            self._stop_ctl(ctl_id)

    def _callback(self, event):
        """Cache callback: enqueue only starting/stopping events for the run loop."""
        if not (self._is_starting_event(event) or self._is_stopping_event(event)):
            return
        self._events_queue.put(event)

    def _get_event(self):
        """
        Block until either an event is queued or stop() is requested.

        :returns: the next queued event, or STOPPED
        """
        def _block_until_event_available():
            try:
                self._events_queue.peek()
            except gevent.queue.Empty:
                pass

        waiter = gevent.spawn(_block_until_event_available)
        gevent.wait([self._stopped, waiter], count=1)
        if not self._stopped.is_set():
            return self._events_queue.get()
        # The waiter greenlet may still be blocked on peek(); give it something to see.
        self._events_queue.put(STOPPED)
        return STOPPED

    def _run(self):
        """Start everything, process events until STOPPED, and always clean up on the way out."""
        self._log.info('Running...')
        try:
            self._start()
            event = self._get_event()
            while event is not STOPPED:
                self._process_event(event)
                event = self._get_event()
        finally:
            self._stop()

    def stop(self, timeout=10, ignore_greenlet_exit=True, log=None, kill_timeout=None):
        """Ask the run loop to exit, wait up to ``timeout`` seconds, then delegate to GreenThread.stop()."""
        self._stopped.set()
        try:
            if self.is_running():
                self.wait(timeout=timeout)
        except gevent.Timeout:
            pass  # did not exit in time; fall through to GreenThread.stop() below
        finally:
            super(CtlManager, self).stop(
                ignore_greenlet_exit=ignore_greenlet_exit,
                log=log,
                kill_timeout=kill_timeout,
            )

    def run(self):
        """Run ``_run`` and retry it after a jittered sleep on ZK loss or unexpected errors."""
        self._stopped.clear()
        while True:
            try:
                self._run()
            except LossExceptions:
                delay = self._get_sleep_after_exception_timeout()
                self._log.debug(u'Lost connection to zk, sleeping for {} seconds'.format(delay))
                gevent.sleep(delay)
            except UNEXPECTED_EXCEPTIONS:
                delay = self._get_sleep_after_exception_timeout()
                self._log.exception(u'Unexpected exception while running, sleeping for {} seconds...'.format(delay))
                gevent.sleep(delay)
            else:
                return


class CtlManagerV2(six.with_metaclass(abc.ABCMeta, greenthread.GreenThread)):
    """
    Keeps one ExclusiveService-wrapped singleton running per ctl id, driven by cache events.

    Ctl ids are two-element tuples parsed from ``event.path`` ("/<namespace_id>/<id>"); the first
    element is the namespace id, which may be filtered via ``allowed_namespace_id_matcher``.
    Subclasses provide the initial events (``_yield_starting_events``), decide per event whether
    the ctl should run (``_should_ctl_be_running``) and build the runnable (``_create_ctl``).
    """
    SLEEP_AFTER_EXCEPTION_TIMEOUT = 20  # seconds
    SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER = 10

    def __init__(self, coord, cache, entity_name, subscribed_events, allowed_namespace_id_matcher=None):
        """
        :type coord: awacs.lib.zookeeper_client.ZookeeperClient
        :type cache: awacs.model.cache.AwacsCache
        :type entity_name: six.text_type
        :param subscribed_events: event classes to receive from the cache. It is passed to
            isinstance(), so it must be a type or a tuple of types; a list is converted to a tuple.
        :type subscribed_events: tuple[type, ...] | list[type]
        :type allowed_namespace_id_matcher: callable | None
        """
        super(CtlManagerV2, self).__init__()
        if isinstance(subscribed_events, list):
            # isinstance() in _process_event raises TypeError when given a list.
            subscribed_events = tuple(subscribed_events)
        self.name = '{}-ctl-manager'.format(entity_name)
        self._coord = coord
        self._cache = cache
        self._entity_name = entity_name
        self._subscribed_events = subscribed_events
        self._allowed_namespace_id_matcher = allowed_namespace_id_matcher
        self._log = logging.getLogger(self.name)
        # Events put by the cache callback and consumed by the run loop.
        self._events_queue = gevent.queue.Queue()
        # Set by stop() to make the run loop exit.
        self._stopped = gevent.event.Event()
        # ctl_id -> running exclusiveservice2.ExclusiveService
        self._ctls = {}
        # Reentrant because _start/_stop hold it while calling _start_ctl/_stop_ctl, which take it too.
        self._ctl_start_stop_lock = gevent.lock.RLock()

    def _get_sleep_after_exception_timeout(self):
        """Return the jittered back-off delay (seconds) used after a failed run."""
        return get_jittered_value(self.SLEEP_AFTER_EXCEPTION_TIMEOUT, self.SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER)

    @abc.abstractmethod
    def _yield_starting_events(self):
        """
        Yield the events that describe the initial state; each is fed to _process_event on start.

        :rtype: collections.Iterable[events.*]
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _create_ctl(self, ctl_id):
        """
        Build the runnable ctl for ``ctl_id``; a falsy result means "nothing to run".

        :type ctl_id: (six.text_type, six.text_type)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _should_ctl_be_running(self, event):
        """
        :param event: guaranteed to be one of self._subscribed_events
        :type event: events.*
        :rtype bool
        """
        raise NotImplementedError

    def _should_restart_ctl(self, ctl_id):
        """
        Hook: return True to force a stop/start cycle of a ctl that should keep running.

        :type ctl_id: (six.text_type, six.text_type)
        :rtype bool
        """
        return False

    def _process_event(self, event, ctx=None):
        """
        Bring the singleton for the event's ctl id in line with ``_should_ctl_be_running``.

        Events for namespaces rejected by ``allowed_namespace_id_matcher`` are ignored.

        :type ctx: Optional[context.OpCtx]
        :type event: events.*
        """
        assert isinstance(event, self._subscribed_events)
        ctl_id = self._get_ctl_id_from_event(event)
        if not self._is_allowed_namespace(self._get_namespace_id_from_ctl_id(ctl_id)):
            return
        ctx = ctx or context.OpCtx(log=self._log, op_id=rndstr())
        ctl_should_be_running = self._should_ctl_be_running(event)
        if not ctl_should_be_running or self._should_restart_ctl(ctl_id):
            self._stop_ctl(ctx, ctl_id)
        if ctl_should_be_running:
            self._start_ctl(ctx, ctl_id)

    def _callback(self, event):
        """
        Cache callback: enqueue the event for the run loop.

        :type event: events.*
        """
        self._events_queue.put(event)

    def _is_allowed_namespace(self, namespace_id):
        """
        :type namespace_id: six.text_type
        :rtype: bool
        """
        return self._allowed_namespace_id_matcher is None or self._allowed_namespace_id_matcher(namespace_id)

    def _get_ctl_id_from_event(self, event):
        """
        Parse ``event.path`` ("/<namespace_id>/<id>") into a two-element tuple.

        :type event: events.*
        :rtype: (six.text_type, six.text_type)
        :raises AssertionError: if the path does not consist of exactly two non-empty parts
        """
        ctl_id = tuple(event.path.strip('/').split('/'))
        if len(ctl_id) != 2 or not ctl_id[0] or not ctl_id[1]:
            raise AssertionError('"{}" cannot be parsed into a correct {} id'.format(event.path, self._entity_name))
        return ctl_id

    @staticmethod
    def _get_namespace_id_from_ctl_id(ctl_id):
        """
        Return the first element of a tuple ctl id, or the id itself if it is not a tuple.

        :type ctl_id: (six.text_type, six.text_type)
        :rtype: six.text_type
        """
        if isinstance(ctl_id, tuple):
            return ctl_id[0]
        return ctl_id

    @staticmethod
    def _make_event_path_from_ctl_id(*ctl_id):
        """
        Inverse of _get_ctl_id_from_event: join the parts into "/<part>/<part>...".

        :type ctl_id: six.text_type
        :rtype: six.text_type
        """
        return '/{}'.format('/'.join(ctl_id))

    def _get_ctl_human_readable_name(self, ctl_id):
        """
        :type ctl_id: (six.text_type, six.text_type)
        """
        return '{}("{}:{}") ctl'.format(self._entity_name, *ctl_id)

    def _get_standoff_strategy(self, ctl_id):
        """
        :type ctl_id: (six.text_type, six.text_type)
        :returns: zero-argument callable yielding successive standoff delays (seconds)
        """
        if six.PY3:
            return default_standoff_strategy().__next__
        else:
            return default_standoff_strategy().next

    def _get_acquire_timeout_strategy(self, ctl_id):
        """
        :type ctl_id: (six.text_type, six.text_type)
        :returns: zero-argument callable yielding successive acquire timeouts (seconds)
        """
        if six.PY3:
            return default_acquire_timeout_strategy().__next__
        else:
            return default_acquire_timeout_strategy().next

    def _create_ctl_singleton(self, ctl_id):
        """
        Wrap the ctl for ``ctl_id`` into an ExclusiveService.

        Uses the same lock-release settings as CtlManager, including the
        ``run.exclusive_service_disable_soft_timeouts`` config flag.

        :type ctl_id: (six.text_type, six.text_type)
        :returns: the ExclusiveService, or None when ``_create_ctl`` returned a falsy value
        """
        ctl = self._create_ctl(ctl_id)
        if not ctl:
            return
        return exclusiveservice2.ExclusiveService(
            coord=self._coord,
            name=ctl.name,
            runnable=ctl,
            acquire_timeout_strategy=self._get_acquire_timeout_strategy(ctl_id),
            standoff_strategy=self._get_standoff_strategy(ctl_id),
            lock_release_timeout=get_lock_release_timeout(),
            disable_soft_lock_release_timeouts=get_lock_disable_soft_timeouts(),
            greenthread_aware_mode=True,
            metrics_registry=metrics.ROOT_REGISTRY,
            disable_per_exclusive_service_metrics=True,
        )

    def _start_ctl(self, ctx, ctl_id):
        """
        Start a singleton for ``ctl_id`` unless one is already registered.

        :type ctx: context.OpCtx
        :type ctl_id: (six.text_type, six.text_type)
        """
        with self._ctl_start_stop_lock:
            name = self._get_ctl_human_readable_name(ctl_id)
            if ctl_id in self._ctls:
                ctx.log.info('Singleton for %s is already running', name)
                return
            ctx.log.info('Starting singleton for %s...', name)
            ctl = self._create_ctl_singleton(ctl_id)
            if ctl:
                self._ctls[ctl_id] = ctl
                ctl.start()
                ctx.log.info('Started singleton for %s', name)

    def _stop_ctl(self, ctx, ctl_id):
        """
        Stop and unregister the singleton for ``ctl_id`` if one is registered.

        :type ctx: context.OpCtx
        :type ctl_id: (six.text_type, six.text_type)
        """
        with self._ctl_start_stop_lock:
            name = self._get_ctl_human_readable_name(ctl_id)
            if ctl_id not in self._ctls:
                ctx.log.info('Singleton for %s is already not running', name)
                return
            ctx.log.info('Stopping singleton for %s...', name)
            # Unregister only after stop() returns: if it raises, the entry stays for a later retry.
            self._ctls[ctl_id].stop()
            del self._ctls[ctl_id]
            ctx.log.info('Stopped singleton for %s...', name)

    def _start(self):
        """Subscribe to the cache and process every starting event."""
        self._cache.bind_on_specific_events(self._callback, self._subscribed_events)
        ctx = context.OpCtx(log=self._log, op_id=rndstr())
        ctx.log.info('Starting all ctl singletons...')
        with self._ctl_start_stop_lock:
            for event in self._yield_starting_events():
                self._process_event(event, ctx)
        ctx.log.info('Started all ctl singletons...')

    def _stop(self):
        """Unsubscribe from the cache, stop every running singleton and drop pending events."""
        self._cache.unbind_from_specific_events(self._callback, self._subscribed_events)
        ctx = context.OpCtx(log=self._log, op_id=rndstr())
        ctx.log.info('Stopping all ctl singletons...')
        with self._ctl_start_stop_lock:
            # Iterate over a snapshot of the keys: _stop_ctl deletes entries from self._ctls.
            for ctl_id in list(six.iterkeys(self._ctls)):
                self._stop_ctl(ctx, ctl_id)
            self._ctls.clear()
        self._events_queue = gevent.queue.Queue()
        ctx.log.info('Stopped all ctl singletons')

    def _wait_event(self):
        """Block until the events queue is non-empty (without consuming anything)."""
        try:
            self._events_queue.peek()
        except gevent.queue.Empty:
            pass

    def _get_event(self):
        """
        Block until either an event is queued or stop() is requested.

        :returns: the next queued event, or STOPPED
        """
        peek = gevent.spawn(self._wait_event)
        gevent.wait([self._stopped, peek], count=1)
        if self._stopped.is_set():
            self._events_queue.put(STOPPED)  # to cause `peek` greenlet to stop
            return STOPPED
        else:
            return self._events_queue.get()

    def _run(self):
        """Start everything, process events until STOPPED, and always clean up on the way out."""
        self._log.info('Running...')
        try:
            self._start()
            while 1:
                event = self._get_event()
                if event is STOPPED:
                    break
                self._process_event(event)
        finally:
            self._stop()

    def stop(self, timeout=10, ignore_greenlet_exit=True, log=None, kill_timeout=None):
        """Ask the run loop to exit, wait up to ``timeout`` seconds, then delegate to GreenThread.stop()."""
        self._stopped.set()
        try:
            if self.is_running():
                self.wait(timeout=timeout)
        except gevent.Timeout:
            pass
        finally:
            super(CtlManagerV2, self).stop(ignore_greenlet_exit=ignore_greenlet_exit,
                                           log=log,
                                           kill_timeout=kill_timeout)

    def run(self):
        """Run ``_run`` and retry it after a jittered sleep on ZK loss or unexpected errors."""
        self._stopped.clear()
        while 1:
            try:
                self._run()
            except LossExceptions:
                timeout = self._get_sleep_after_exception_timeout()
                self._log.debug(u'Lost connection to zk, sleeping for %s seconds', timeout)
                gevent.sleep(timeout)
            except UNEXPECTED_EXCEPTIONS:
                timeout = self._get_sleep_after_exception_timeout()
                self._log.exception(u'Unexpected exception while running, sleeping for %s seconds...', timeout)
                gevent.sleep(timeout)
            else:
                break


class PartyingCtlManager(CtlManager):
    """
    CtlManager that shares ctls among several processes ("party members").

    Each manager joins a SingletonParty and builds a ``uhashring.HashRing`` over the party
    member ids. A ctl id is handled by this member only if the member is among the
    ``CONTENDING_MEMBERS`` ring nodes returned for that id. Running ctls are re-balanced on
    party membership changes (PartyMemberUpdate / PartyMemberRemove) and on TreeEvent
    connection state changes.
    """
    CONTENDING_MEMBERS = 3

    def __init__(self, coord, cache, member_id, party_suffix, name, starting_events, stopping_events):
        """
        :type coord: awacs.lib.zookeeper_client.ZookeeperClient
        :type cache: awacs.model.cache.AwacsCache
        :type member_id: six.text_type
        :type party_suffix: six.text_type
        :type name: six.text_type
        :type starting_events: tuple[type, ...]
        :type stopping_events: tuple[type, ...]
        """
        super(PartyingCtlManager, self).__init__(coord, cache, name, starting_events, stopping_events)
        self.member_id = member_id
        party_path = os.path.join(PARTIES_NODE_ZK_PREFIX, self.name + party_suffix)
        self.party = SingletonParty(coord.client, party_path, identifier=self.member_id)
        # Built in _start / on reconnect; None while the manager is not running.
        self.hashring = None

    def _get_ctl_human_readable_name(self, ctl_id):
        raise NotImplementedError

    def _list_all_ctl_ids(self):
        raise NotImplementedError

    def _create_ctl(self, ctl_id):
        raise NotImplementedError

    def _get_ctl_id_from_event(self, event):
        raise NotImplementedError

    def _is_assigned(self, ctl_id):
        """
        Tell whether this member is one of the CONTENDING_MEMBERS ring nodes for ``ctl_id``.

        Requires ``self.hashring`` to be built.

        :type ctl_id: tuple
        :rtype: bool
        """
        key = '-'.join(ctl_id)
        assigned_member_ids = {node['nodename']
                               for node in self.hashring.range(key=key, size=self.CONTENDING_MEMBERS)}
        return self.member_id in assigned_member_ids

    def _list_all_assigned_ctl_ids(self):
        """Yield every known ctl id that the hashring assigns to this member."""
        for ctl_id in self._list_all_ctl_ids():
            if self._is_assigned(ctl_id):
                yield ctl_id

    def _is_starting_event(self, event):
        """A starting event only counts if its ctl id is assigned to this member."""
        if super(PartyingCtlManager, self)._is_starting_event(event):
            ctl_id = self._get_ctl_id_from_event(event)
            return self._is_assigned(ctl_id)
        else:
            return False

    @staticmethod
    def _treecache_event_type_to_name(event_type):
        """
        Return a printable name for the connection-related treecache event types.

        :raises AssertionError: for any other event type
        """
        if event_type == treecache.TreeEvent.CONNECTION_LOST:
            event_type_name = 'CONNECTION_LOST'
        elif event_type == treecache.TreeEvent.CONNECTION_SUSPENDED:
            event_type_name = 'CONNECTION_SUSPENDED'
        elif event_type == treecache.TreeEvent.CONNECTION_RECONNECTED:
            event_type_name = 'CONNECTION_RECONNECTED'
        elif event_type == treecache.TreeEvent.INITIALIZED:
            event_type_name = 'INITIALIZED'
        else:
            raise AssertionError('unexpected TreeEvent type: {!r}'.format(event_type))
        return event_type_name

    def _diff_running_and_assigned_ctl_ids(self):
        """
        Compare running singletons with the current hashring assignment.

        :returns: (ctl ids running but not assigned, ctl ids assigned but not running)
        :rtype: (set, set)
        """
        running_ctl_ids = set(self._ctls)
        assigned_ctl_ids = set(self._list_all_assigned_ctl_ids())
        return running_ctl_ids - assigned_ctl_ids, assigned_ctl_ids - running_ctl_ids

    def _process_event(self, event):
        """
        Handle connection changes, party membership changes and ordinary starting/stopping events.

        * CONNECTION_LOST / CONNECTION_SUSPENDED: stop every running ctl.
        * CONNECTION_RECONNECTED / INITIALIZED: rejoin the party, rebuild the hashring, reconcile.
        * PartyMemberUpdate / PartyMemberRemove: update the hashring and reconcile.
        * anything else: start/stop the ctl the event refers to.
        """
        if isinstance(event, TreeEvent):
            if event.event_type in (treecache.TreeEvent.CONNECTION_LOST, treecache.TreeEvent.CONNECTION_SUSPENDED):
                op_id = rndstr()
                event_type_name = self._treecache_event_type_to_name(event.event_type)
                self._log.info('%s: received TreeEvent(type=%s)', op_id, event_type_name)
                ctl_ids_to_stop = set(self._ctls)
                self._log.info('%s: going to stop %d ctls', op_id, len(ctl_ids_to_stop))
                for ctl_id in ctl_ids_to_stop:
                    self._stop_ctl(ctl_id)
                self._log.info('%s: stopped %d ctls', op_id, len(ctl_ids_to_stop))
            elif event.event_type in (treecache.TreeEvent.CONNECTION_RECONNECTED, treecache.TreeEvent.INITIALIZED):
                op_id = rndstr()
                event_type_name = self._treecache_event_type_to_name(event.event_type)
                self._log.info('%s: received TreeEvent(type=%s)', op_id, event_type_name)
                self._log.info('%s: reconnecting to the party...', op_id)
                # _join_party already lists the members; no need for a second round-trip.
                member_ids = self._join_party(op_id)
                self._log.info('%s: reconnected to the party, member ids: %s', op_id, ', '.join(member_ids))
                self.hashring = uhashring.HashRing(nodes=member_ids)
                ctl_ids_to_stop, ctl_ids_to_start = self._diff_running_and_assigned_ctl_ids()
                self._log.info('%s: going to start %d ctls and stop %d ctls',
                               op_id, len(ctl_ids_to_start), len(ctl_ids_to_stop))
                for ctl_id in ctl_ids_to_stop:
                    self._stop_ctl(ctl_id)
                for ctl_id in ctl_ids_to_start:
                    self._start_ctl(ctl_id)
                self._log.info('%s: started %d ctls and stopped %d ctls',
                               op_id, len(ctl_ids_to_start), len(ctl_ids_to_stop))

        elif isinstance(event, PartyMemberUpdate):
            op_id = rndstr()
            member_id = event.member_id
            self._log.info('%s: received PartyMemberUpdate(member_id=%s)', op_id, member_id)
            nodenames = self.hashring.get_nodes()
            self._log.info('%s: hashring nodes before adding %s: %s', op_id, member_id, nodenames)
            if member_id in nodenames:
                self._log.info('%s: %s is already in hashring, do nothing', op_id, member_id)
                return
            self.hashring.add_node(member_id)
            nodenames = self.hashring.get_nodes()
            self._log.info('%s: hashring nodes after adding %s: %s', op_id, member_id, nodenames)
            ctl_ids_to_stop, ctl_ids_to_start = self._diff_running_and_assigned_ctl_ids()
            self._log.info('%s: added party member %s, going to stop %d ctls',
                           op_id, member_id, len(ctl_ids_to_stop))
            if ctl_ids_to_start:
                # Adding a member is not expected to assign new ctls to us. Ids found here are
                # assigned but have no running singleton (e.g. _create_ctl returned nothing).
                # Reconcile instead of raising, which would restart the whole manager.
                self._log.warning('%s: %d assigned ctls are not running, going to start them',
                                  op_id, len(ctl_ids_to_start))
            n = len(nodenames)
            if n > self.CONTENDING_MEMBERS and not ctl_ids_to_stop:
                self._log.warning('%s: %d > %d and no ctls to stop, how come?', op_id, n, self.CONTENDING_MEMBERS)
            for ctl_id in ctl_ids_to_stop:
                self._stop_ctl(ctl_id)
            for ctl_id in ctl_ids_to_start:
                self._start_ctl(ctl_id)
            self._log.info('%s: stopped %d ctls', op_id, len(ctl_ids_to_stop))

        elif isinstance(event, PartyMemberRemove):
            op_id = rndstr()
            member_id = event.member_id
            self._log.info('%s: received PartyMemberRemove(member_id=%s)', op_id, member_id)
            if member_id in self._cache.list_party_member_ids(event.party_id):
                # ignore expiring entries from previous runs finished with lost connections
                self._log.warning('%s: member is still in the party: %s', op_id, member_id)
                return
            nodenames = self.hashring.get_nodes()
            self._log.info('%s: hashring nodes before removing %s: %s', op_id, member_id, nodenames)
            try:
                self.hashring.remove_node(member_id)
            except KeyError as e:
                self._log.warning('%s: failed to remove %s from hashring: %s', op_id, member_id, e)
            else:
                nodenames = self.hashring.get_nodes()
                self._log.info('%s: hashring nodes after removing %s: %s', op_id, member_id, nodenames)
            ctl_ids_to_stop, ctl_ids_to_start = self._diff_running_and_assigned_ctl_ids()
            if ctl_ids_to_stop:
                # Removing a member is not expected to take ctls away from us. Ids found here are
                # running but no longer assigned/listed (possibly their stopping event is still
                # queued). Reconcile instead of raising, which would restart the whole manager.
                self._log.warning('%s: %d running ctls are not assigned, going to stop them',
                                  op_id, len(ctl_ids_to_stop))
            self._log.info('%s: party member %s left, going to start %d ctls',
                           op_id, member_id, len(ctl_ids_to_start))
            n = len(nodenames)
            if n > self.CONTENDING_MEMBERS and not ctl_ids_to_start:
                self._log.warning('%s: %d > %d and no ctls to start, how come?', op_id, n, self.CONTENDING_MEMBERS)
            for ctl_id in ctl_ids_to_stop:
                self._stop_ctl(ctl_id)
            for ctl_id in ctl_ids_to_start:
                self._start_ctl(ctl_id)
            self._log.info('%s: started %d ctls', op_id, len(ctl_ids_to_start))

        else:
            ctl_id = self._get_ctl_id_from_event(event)
            if self._is_starting_event(event):
                self._start_ctl(ctl_id)
            elif self._is_stopping_event(event):
                self._stop_ctl(ctl_id)

    def _callback(self, event):
        """
        Cache callback: enqueue TreeEvents, membership changes of other members of this party,
        and starting/stopping events (starting ones only if assigned to this member).
        """
        party_changed = isinstance(event, TreeEvent) or (
                isinstance(event, (PartyMemberUpdate, PartyMemberRemove)) and
                event.party_id == self.name and
                event.member_id != self.member_id
        )
        starting_or_stopping = self._is_starting_event(event) or self._is_stopping_event(event)
        if party_changed or starting_or_stopping:
            self._events_queue.put(event)

    def _join_party(self, op_id):
        """Join the party and return the current member ids."""
        self._log.info('%s: joining the party...', op_id)
        self.party.join()
        self._log.info('%s: joined the party', op_id)
        member_ids = self.party.list_member_ids()
        self._log.info('%s: member ids: %s', op_id, ', '.join(member_ids))
        return member_ids

    def _leave_party(self, op_id):
        """Leave the party and return whatever ``SingletonParty.leave`` returned."""
        self._log.info('%s: leaving the party...', op_id)
        rv = self.party.leave()
        self._log.info('%s: left the party, rv: %s', op_id, rv)
        return rv

    def _start(self):
        """Join the party and build the hashring before starting assigned ctls."""
        op_id = rndstr()
        member_ids = self._join_party(op_id)
        self.hashring = uhashring.HashRing(nodes=member_ids)
        self._log.info('%s: starting', op_id)
        super(PartyingCtlManager, self)._start()
        self._log.info('%s: started, running %s ctls', op_id, len(self._ctls))

    def _stop(self):
        """Stop all ctls first, then leave the party and drop the hashring."""
        op_id = rndstr()
        self._log.info('%s: stopping', op_id)
        super(PartyingCtlManager, self)._stop()
        self._leave_party(op_id)
        self.hashring = None
        self._log.info('%s: stopped', op_id)


class PartyingCtlManagerV2(six.with_metaclass(abc.ABCMeta, CtlManagerV2)):
    """
    Ctl manager that distributes ctls among the members of a party.

    Each manager joins a SingletonParty under PARTIES_NODE_ZK_PREFIX and builds a
    consistent hash ring (uhashring) from the party member ids. A ctl is run by this
    member only if the member is among the CONTENDING_MEMBERS nodes that the ring
    returns for the ctl id. Party membership changes (PartyMemberUpdate /
    PartyMemberRemove) and connection state changes (TreeEvent) cause the set of
    running ctls to be recomputed.
    """
    CONTENDING_MEMBERS = 3

    def __init__(self, coord, cache, member_id, party_suffix, entity_name, subscribed_events,
                 allowed_namespace_id_matcher=None):
        """
        :type coord: awacs.lib.zookeeper_client.ZookeeperClient
        :type cache: awacs.model.cache.AwacsCache
        :param member_id: id of this member in the party
        :type member_id: six.text_type
        :param party_suffix: appended to self.name to form the party node name
        :type party_suffix: six.text_type
        :type entity_name: six.text_type
        :param subscribed_events: entity event types to react to; the internal party and
            tree event types are appended automatically
        :type subscribed_events: tuple[type, ...]
        :type allowed_namespace_id_matcher: Optional[Callable]
        """
        super(PartyingCtlManagerV2, self).__init__(
            coord, cache, entity_name, subscribed_events, allowed_namespace_id_matcher)
        # Events about party membership / connection state, as opposed to entity events.
        # Both party event types must be listed: otherwise PartyMemberRemove is neither
        # subscribed to in _start nor routed to its branch in _process_event.
        self._internal_events = (PartyMemberUpdate, PartyMemberRemove, TreeEvent)
        self._subscribed_events = tuple(list(subscribed_events) + list(self._internal_events))
        self.member_id = member_id
        party_path = os.path.join(PARTIES_NODE_ZK_PREFIX, self.name + party_suffix)
        self.party = SingletonParty(coord.client, party_path, identifier=self.member_id)
        # Built from the party member ids in _start and on reconnect; reset to None in _stop.
        self.hashring = None

    def _is_assigned(self, ctl_id):
        """
        Check whether this member is one of the CONTENDING_MEMBERS ring nodes for *ctl_id*.

        :type ctl_id: (six.text_type, six.text_type)
        :rtype: bool
        """
        key = '-'.join(ctl_id)
        assigned_member_ids = {node['nodename'] for node in self.hashring.range(key=key, size=self.CONTENDING_MEMBERS)}
        return self.member_id in assigned_member_ids

    def _yield_assigned_starting_ctl_ids(self):
        """
        Yield ids of ctls that are in an allowed namespace, are assigned to this member and
        whose starting events say they should be running.
        """
        for event in self._yield_starting_events():
            assert isinstance(event, self._subscribed_events)
            ctl_id = self._get_ctl_id_from_event(event)
            if not self._is_allowed_namespace(self._get_namespace_id_from_ctl_id(ctl_id)):
                continue
            if not self._is_assigned(ctl_id):
                continue
            if self._should_ctl_be_running(event):
                yield ctl_id

    @staticmethod
    def _treecache_event_type_to_name(event_type):
        """
        Return a human-readable name of a connection-related treecache event type.

        :raises AssertionError: for any other event type
        """
        if event_type == treecache.TreeEvent.CONNECTION_LOST:
            event_type_name = 'CONNECTION_LOST'
        elif event_type == treecache.TreeEvent.CONNECTION_SUSPENDED:
            event_type_name = 'CONNECTION_SUSPENDED'
        elif event_type == treecache.TreeEvent.CONNECTION_RECONNECTED:
            event_type_name = 'CONNECTION_RECONNECTED'
        elif event_type == treecache.TreeEvent.INITIALIZED:
            event_type_name = 'INITIALIZED'
        else:
            raise AssertionError('unexpected TreeEvent type: {!r}'.format(event_type))
        return event_type_name

    def _process_event(self, event, ctx=None):
        """
        Handle a single queued event.

        Entity events start or stop the corresponding ctl if it is in an allowed namespace
        and assigned to this member. Internal events rebalance ctls:

        - TreeEvent CONNECTION_LOST / CONNECTION_SUSPENDED: stop all running ctls;
        - TreeEvent CONNECTION_RECONNECTED / INITIALIZED: rejoin the party, rebuild the hash
          ring and start/stop ctls to match the new assignment;
        - PartyMemberUpdate: add the member to the ring and stop ctls no longer assigned here;
        - PartyMemberRemove: remove the member from the ring and start newly assigned ctls.

        :param event: entity event or an instance of one of self._internal_events
        :type ctx: context.OpCtx | None
        :returns: result of _start_ctl / _stop_ctl for handled entity events, None otherwise
        """
        if not isinstance(event, self._internal_events):
            ctl_id = self._get_ctl_id_from_event(event)
            if not self._is_allowed_namespace(self._get_namespace_id_from_ctl_id(ctl_id)):
                return
            elif not self._is_assigned(ctl_id):
                return
            ctx = ctx or context.OpCtx(log=self._log, op_id=rndstr())
            if self._should_ctl_be_running(event):
                return self._start_ctl(ctx, ctl_id)
            else:
                return self._stop_ctl(ctx, ctl_id)
        ctx = ctx or context.OpCtx(log=self._log, op_id=rndstr())
        if isinstance(event, TreeEvent):
            if event.event_type in (treecache.TreeEvent.CONNECTION_LOST, treecache.TreeEvent.CONNECTION_SUSPENDED):
                event_type_name = self._treecache_event_type_to_name(event.event_type)
                ctx.log.info('received TreeEvent(type=%s)', event_type_name)
                ctl_ids_to_stop = set(self._ctls)
                ctx.log.info('going to stop %d ctls', len(ctl_ids_to_stop))
                for ctl_id in ctl_ids_to_stop:
                    self._stop_ctl(ctx, ctl_id)
                ctx.log.info('stopped %d ctls', len(ctl_ids_to_stop))
            elif event.event_type in (treecache.TreeEvent.CONNECTION_RECONNECTED, treecache.TreeEvent.INITIALIZED):
                event_type_name = self._treecache_event_type_to_name(event.event_type)
                ctx.log.info('received TreeEvent(type=%s)', event_type_name)
                ctx.log.info('reconnecting to the party...')
                # _join_party already lists the members after joining; reuse its result
                # instead of querying the party a second time.
                member_ids = self._join_party(ctx)
                ctx.log.info('reconnected to the party, member ids: %s', ', '.join(member_ids))
                self.hashring = uhashring.HashRing(nodes=member_ids)
                running_ctl_ids = set(self._ctls)
                assigned_ctl_ids = set(self._yield_assigned_starting_ctl_ids())
                ctl_ids_to_stop = running_ctl_ids - assigned_ctl_ids
                ctl_ids_to_start = assigned_ctl_ids - running_ctl_ids
                ctx.log.info('going to start %d ctls and stop %d ctls', len(ctl_ids_to_start), len(ctl_ids_to_stop))
                for ctl_id in ctl_ids_to_stop:
                    self._stop_ctl(ctx, ctl_id)
                for ctl_id in ctl_ids_to_start:
                    self._start_ctl(ctx, ctl_id)
                ctx.log.info('started %d ctls and stopped %d ctls', len(ctl_ids_to_start), len(ctl_ids_to_stop))

        elif isinstance(event, PartyMemberUpdate):
            member_id = event.member_id
            ctx.log.info('received PartyMemberUpdate(member_id=%s)', member_id)
            nodenames = self.hashring.get_nodes()
            ctx.log.info('hashring nodes before adding %s: %s', member_id, nodenames)
            if member_id in nodenames:
                ctx.log.info('%s is already in hashring, do nothing', member_id)
                return
            self.hashring.add_node(member_id)
            nodenames = self.hashring.get_nodes()
            ctx.log.info('hashring nodes after adding %s: %s', member_id, nodenames)
            running_ctl_ids = set(self._ctls)
            assigned_ctl_ids = set(self._yield_assigned_starting_ctl_ids())
            ctl_ids_to_stop = running_ctl_ids - assigned_ctl_ids
            ctl_ids_to_start = assigned_ctl_ids - running_ctl_ids
            ctx.log.info('added party member %s, going to stop %d ctls', member_id, len(ctl_ids_to_stop))
            # Adding a node is expected to only take ctls away from this member.
            assert not ctl_ids_to_start
            n = len(nodenames)
            if n > self.CONTENDING_MEMBERS and not ctl_ids_to_stop:
                ctx.log.warn('%d > %d and no ctls to stop, how come?', n, self.CONTENDING_MEMBERS)
            for ctl_id in ctl_ids_to_stop:
                self._stop_ctl(ctx, ctl_id)
            ctx.log.info('stopped %d ctls', len(ctl_ids_to_stop))

        elif isinstance(event, PartyMemberRemove):
            member_id = event.member_id
            ctx.log.info('received PartyMemberRemove(member_id=%s)', member_id)
            if member_id in self._cache.list_party_member_ids(event.party_id):
                # ignore expiring entries from previous runs finished with lost connections
                ctx.log.warn('member is still in the party: %s', member_id)
                return
            nodenames = self.hashring.get_nodes()
            ctx.log.info('hashring nodes before removing %s: %s', member_id, nodenames)
            try:
                self.hashring.remove_node(member_id)
            except KeyError as e:
                ctx.log.warn('failed to remove %s from hashring: %s', member_id, e)
            else:
                nodenames = self.hashring.get_nodes()
                ctx.log.info('hashring nodes after removing %s: %s', member_id, nodenames)
            running_ctl_ids = set(self._ctls)
            assigned_ctl_ids = set(self._yield_assigned_starting_ctl_ids())
            ctl_ids_to_stop = running_ctl_ids - assigned_ctl_ids
            ctl_ids_to_start = assigned_ctl_ids - running_ctl_ids
            # Removing a node is expected to only give ctls to this member.
            assert not ctl_ids_to_stop
            ctx.log.info('party member %s left, going to start %d ctls', member_id, len(ctl_ids_to_start))
            n = len(nodenames)
            if n > self.CONTENDING_MEMBERS and not ctl_ids_to_start:
                ctx.log.warn('%d > %d and no ctls to start, how come?', n, self.CONTENDING_MEMBERS)
            for ctl_id in ctl_ids_to_start:
                self._start_ctl(ctx, ctl_id)
            ctx.log.info('started %d ctls', len(ctl_ids_to_start))

    def _callback(self, event):
        """
        Queue *event* for processing.

        Party events are queued only if they belong to this manager's party and concern
        another member. All other events are queued unconditionally.
        """
        should_process = True
        if isinstance(event, (PartyMemberUpdate, PartyMemberRemove)):
            # NOTE(review): party_id is compared with self.name although the party node is
            # named self.name + party_suffix -- confirm how the cache derives party_id.
            party_changed = event.party_id == self.name and event.member_id != self.member_id
            should_process = party_changed
        if should_process:
            self._events_queue.put(event)

    def _join_party(self, ctx):
        """
        Join the party and return the ids of its members right after joining.

        :type ctx: context.OpCtx
        :returns: member ids as returned by ``self.party.list_member_ids()``
        """
        ctx.log.info('joining the party...')
        self.party.join()
        ctx.log.info('joined the party')
        member_ids = self.party.list_member_ids()
        ctx.log.info('member ids: %s', ', '.join(member_ids))
        return member_ids

    def _leave_party(self, ctx):
        """
        Leave the party.

        :type ctx: context.OpCtx
        :returns: whatever ``self.party.leave()`` returned
        """
        ctx.log.info('leaving the party...')
        rv = self.party.leave()
        ctx.log.info('left the party, rv: %s', rv)
        return rv

    def _start(self):
        """
        Subscribe to cache events, join the party, build the hash ring and start every
        ctl assigned to this member.
        """
        self._cache.bind_on_specific_events(self._callback, self._subscribed_events)
        ctx = context.OpCtx(log=self._log, op_id=rndstr())
        member_ids = self._join_party(ctx)
        self.hashring = uhashring.HashRing(nodes=member_ids)
        ctx.log.info('Starting all ctl singletons...')
        with self._ctl_start_stop_lock:
            for ctl_id in self._yield_assigned_starting_ctl_ids():
                self._start_ctl(ctx, ctl_id)
        ctx.log.info('started, running %s ctls', len(self._ctls))

    def _stop(self):
        """
        Stop the base manager, then leave the party and drop the hash ring.
        """
        ctx = context.OpCtx(log=self._log, op_id=rndstr())
        ctx.log.info('stopping')
        super(PartyingCtlManagerV2, self)._stop()
        self._leave_party(ctx)
        self.hashring = None
        ctx.log.info('stopped')


class ContextedCtl(six.with_metaclass(abc.ABCMeta, greenthread.GreenThread)):
    """
    Base class for a controller that runs in its own green thread and processes events
    with an operation context.

    Subclasses implement _accept_event, _start, _stop, _process_event and
    _process_empty_queue. Events accepted by _callback are queued and handled one at a
    time by _run. When no event arrives within EVENTS_QUEUE_GET_TIMEOUT seconds,
    _process_empty_queue is called instead. run() restarts _run after ZooKeeper connection
    loss or unexpected exceptions, sleeping a jittered timeout in between.
    """
    EVENTS_QUEUE_GET_TIMEOUT = .1  # seconds
    SLEEP_AFTER_EXCEPTION_TIMEOUT = 20
    SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER = 10

    def __init__(self, name):
        """
        :param name: ctl name; also used as the logger name and as repr()
        """
        super(ContextedCtl, self).__init__()

        self.name = name
        self._log = logging.getLogger(name)
        self._events_queue = gevent.queue.Queue()
        # True while one of the _do_* wrappers is executing subclass code.
        self._busy = False
        self._started = False
        self._last_processed_at = 0
        # When set, _get_event returns STOPPED and the op contexts created in _run are cancelled.
        self._stopped = gevent.event.Event()
        self._pb = None

    def __repr__(self):
        return self.name

    def _get_sleep_after_exception_timeout(self):
        """Return SLEEP_AFTER_EXCEPTION_TIMEOUT with a random +/- jitter applied."""
        return get_jittered_value(self.SLEEP_AFTER_EXCEPTION_TIMEOUT, self.SLEEP_AFTER_EXCEPTION_TIMEOUT_JITTER)

    @property
    def _pb_generation(self):
        """Generation of the current pb, or the string 'None' if there is no pb."""
        return self._pb.meta.generation if self._pb else 'None'

    @staticmethod
    def _should_discard_event(ctx, event, cached_pb):
        """
        Discard an event with a generation older than on the cached object

        :type ctx: context.OpCtx
        :param event: event carrying a ``pb`` with ``meta.generation``
        :param cached_pb: cached object with ``meta.generation``
        :rtype: bool
        """
        cached_pb_generation = cached_pb.meta.generation
        event_pb_generation = event.pb.meta.generation
        if event_pb_generation < cached_pb_generation:
            ctx.log.debug('Skipped event with stale generation %s, because object in cache has generation %s',
                          event_pb_generation, cached_pb_generation)
            return True
        return False

    @property
    def busy(self):
        return self._busy

    @property
    def started(self):
        return self._started

    @property
    def has_empty_queue(self):
        return self._events_queue.empty()

    @abc.abstractmethod
    def _accept_event(self, event):
        """
        :type event: awacs.model.events.*
        :rtype: bool
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _start(self, ctx):
        """
        :type ctx: context.OpCtx
        :return:
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _stop(self):
        raise NotImplementedError

    @abc.abstractmethod
    def _process_event(self, ctx, event):
        """
        :type ctx: context.OpCtx
        :type event: awacs.model.events.*
        :rtype: bool
        :returns: True if the ctl has to be stopped or nothing otherwise
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _process_empty_queue(self, ctx):
        """
        :type ctx: context.OpCtx
        :rtype: bool
        :returns: True if the ctl has to be stopped or nothing otherwise
        """
        raise NotImplementedError

    def _callback(self, event):
        """
        Queue *event* if the subclass accepts it.

        :type event: awacs.model.events.*
        """
        if self._accept_event(event):
            self._events_queue.put(event)

    def _do_start(self, ctx):
        """
        Run _start while marked busy; mark the ctl started only if _start succeeded.

        :type ctx: context.OpCtx
        """
        self._busy = True
        try:
            self._start(ctx)
        finally:
            self._busy = False
        self._started = True

    def _do_stop(self):
        """
        Run _stop while marked busy and mark the ctl as not started.

        :returns: whatever _stop returned
        """
        self._busy = True
        try:
            rv = self._stop()
        finally:
            self._busy = False
        self._started = False
        return rv

    def _do_process_event(self, ctx, event):
        """
        :type ctx: context.OpCtx
        :type event: awacs.model.events.*
        :rtype: bool
        :returns: True if the ctl has to be stopped or nothing otherwise
        """
        self._busy = True
        try:
            return self._process_event(ctx, event)
        finally:
            self._busy = False

    def _do_process_empty_queue(self, ctx):
        """
        :type ctx: context.OpCtx
        :rtype: bool
        :returns: True if the ctl has to be stopped or nothing otherwise
        """
        self._busy = True
        try:
            return self._process_empty_queue(ctx)
        finally:
            self._busy = False

    def _get_event(self):
        """
        Wait up to EVENTS_QUEUE_GET_TIMEOUT seconds for a queued event or the stop signal.

        :returns: STOPPED if self._stopped is set; otherwise the next queued event, or EMPTY
            if the queue is still empty
        """
        def _wait_event():
            try:
                self._events_queue.peek(timeout=self.EVENTS_QUEUE_GET_TIMEOUT)
            except gevent.queue.Empty:
                pass

        peek = gevent.spawn(_wait_event)
        try:
            gevent.wait([self._stopped, peek], count=1)
        finally:
            # gevent.wait may return (or be interrupted) while the peeking greenlet is still
            # blocked on the queue; kill it so it does not linger. No-op if it already finished.
            peek.kill(block=False)
        if self._stopped.is_set():
            return STOPPED
        elif not self._events_queue.empty():
            return self._events_queue.get()
        else:
            return EMPTY

    def _run(self):
        """
        Start the ctl, then process events until a handler asks to stop or _stopped is set.

        _do_stop is always called on exit, including when _do_start or a handler raises.
        """
        self._log.info('running...')
        self._stopped.clear()
        root_ctx = context.BackgroundCtx()
        ctx, cancel = root_ctx.with_cancel()

        def do_cancel(_):
            cancel('ctl stopped')

        # Cancel the context as soon as _stopped is set so in-flight operations can bail out.
        self._stopped.rawlink(do_cancel)
        try:
            self._log.info('starting...')
            self._do_start(ctx.with_op(op_id=rndstr(), log=self._log))
            self._log.info('started')
            while 1:
                event = self._get_event()
                # STOPPED / EMPTY are sentinels: compare by identity so that event objects'
                # own __eq__ is never invoked.
                if event is STOPPED:
                    need_to_stop = True
                elif event is EMPTY:
                    op_ctx = ctx.with_op(op_id=rndstr(), log=self._log)
                    need_to_stop = self._do_process_empty_queue(op_ctx)
                else:
                    op_ctx = ctx.with_op(op_id=rndstr(), log=self._log)
                    op_ctx.log.debug('processing event %s', type(event).__name__)
                    need_to_stop = self._do_process_event(op_ctx, event)
                if need_to_stop:
                    break
        finally:
            self._stopped.unlink(do_cancel)
            self._do_stop()

    def run(self):
        """
        Run _run, restarting it after a jittered sleep on ZooKeeper connection loss or
        unexpected exceptions. Returns once _run completes without raising.
        """
        while 1:
            try:
                self._run()
            except LossExceptions:
                timeout = self._get_sleep_after_exception_timeout()
                self._log.debug(u'Lost connection to zk, sleeping for %s seconds', timeout)
                gevent.sleep(timeout)
            except UNEXPECTED_EXCEPTIONS as e:
                timeout = self._get_sleep_after_exception_timeout()
                self._log.exception(u'Unexpected exception while running: %s, '
                                    u'sleeping for %s seconds...', e, timeout)
                gevent.sleep(timeout)
            else:
                break
