import random

import inject
import monotonic

from awacs.lib import ctlmanager
from awacs.model import events, cache, zk
from awacs.model.l3_balancer import discoverer_v2, transport_v2, validator_v2, l3_balancer


class L3BalancerCtlV2(ctlmanager.ContextedCtl):
    """
    Controller for a single L3 balancer, identified by ``namespace_id/l3_balancer_id``.

    It listens to cache events for its own L3 balancer and for backends / endpoint sets
    in the same namespace. It drives the L3 balancer state through two loops:

    * processing: discovery -> validation -> transport -> skip_stuck;
    * polling: transport config polling, then removal of obsolete versions.

    Each loop stops as soon as one of its steps reports that the state was updated.
    The remaining steps are skipped for that pass.

    Timing (all intervals are in seconds, measured with ``monotonic.monotonic()``):

    * an ``L3BalancerStateUpdate`` event makes processing due on the next empty-queue tick;
    * any other accepted event opens a batching window: processing becomes due
      PROCESS_INTERVAL seconds after the first event of the window (later events in
      the same window do not move the deadline);
    * processing is forced FORCE_PROCESS_INTERVAL +/- FORCE_PROCESS_INTERVAL_JITTER
      seconds after the previous pass;
    * polling runs every POLL_INTERVAL seconds.
    """
    POLL_INTERVAL = 20
    PROCESS_INTERVAL = 30
    FORCE_PROCESS_INTERVAL = 60
    FORCE_PROCESS_INTERVAL_JITTER = 10
    # NOTE(review): not referenced in this class; presumably consumed by ctlmanager.ContextedCtl — confirm
    EVENTS_QUEUE_GET_TIMEOUT = 5

    _zk = inject.attr(zk.IZkStorage)  # type: zk.ZkStorage
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    # L3 balancer spec updates: batched, processed after PROCESS_INTERVAL
    _delayed_l3_events = (events.L3BalancerUpdate,)
    # L3 balancer state updates: processed on the next empty-queue tick
    _immediate_l3_events = (events.L3BalancerStateUpdate,)
    # namespace-wide events that may affect this balancer
    _namespace_events = (events.BackendUpdate, events.EndpointSetUpdate,
                         events.BackendRemove, events.EndpointSetRemove)

    _l3_balancer_events = _delayed_l3_events + _immediate_l3_events
    _subscribed_events = _l3_balancer_events + _namespace_events

    def __init__(self, namespace_id, l3_balancer_id):
        name = u'l3-balancer-ctl-v2("{}:{}")'.format(namespace_id, l3_balancer_id)
        super(L3BalancerCtlV2, self).__init__(name)

        self._namespace_id = namespace_id
        self._l3_balancer_id = l3_balancer_id
        self._l3_balancer_path = u'{}/{}'.format(self._namespace_id, self._l3_balancer_id)
        self._namespace_prefix = self._namespace_id + u'/'

        self._init_processors()

        # processing timers (monotonic seconds); None means "not set"
        self._waiting_for_processing_since = None  # start of the current batching window
        self._processed_at = None
        self._processing_deadline = None  # end of the current batching window
        self._force_processing_deadline = None  # None makes processing due immediately
        # polling timers (monotonic seconds); None makes polling due immediately
        self._polled_at = None
        self._polling_deadline = None

    def _accept_event(self, event):
        """
        Filter cache events relevant to this controller.

        L3 balancer events are accepted only for this balancer's exact path.
        Backend / endpoint set events are accepted for any path inside this namespace.

        :type event: events.*
        :rtype: bool
        """
        if isinstance(event, self._l3_balancer_events):
            return event.path == self._l3_balancer_path
        if isinstance(event, self._namespace_events):
            return event.path.startswith(self._namespace_prefix)
        return False

    def _init_processors(self):
        """Create the discoverer, validator and transport for this L3 balancer."""
        self._discoverer = discoverer_v2.DiscovererV2(self._namespace_id, self._l3_balancer_id)
        self._validator = validator_v2.ValidatorV2(self._namespace_id, self._l3_balancer_id)
        self._transport = transport_v2.TransportV2(self._namespace_id, self._l3_balancer_id)

    def _start(self, ctx):
        """Subscribe ``self._callback`` to all events this controller handles."""
        self._cache.bind_on_specific_events(self._callback, self._subscribed_events)

    def _stop(self):
        """Unsubscribe ``self._callback`` from all events this controller handles."""
        self._cache.unbind_from_specific_events(self._callback, self._subscribed_events)

    def _process_event(self, ctx, event):
        """
        Update the processing timers in response to an accepted event.

        No processing happens here. The actual work is done in ``_process_empty_queue``.
        """
        if isinstance(event, self._immediate_l3_events):
            # trigger immediate processing of state updates, to improve the speed of disc->validate->transport loop
            self._force_processing_deadline = None
        else:
            # otherwise, delay processing to collect more updates
            if self._waiting_for_processing_since is None:
                # if we're not already waiting, set deadline
                self._waiting_for_processing_since = monotonic.monotonic()
                self._processing_deadline = self._waiting_for_processing_since + self.PROCESS_INTERVAL

    def _process_empty_queue(self, ctx):
        """
        Run processing and/or polling if their timers are due.

        The current L3 balancer state is read from ZK only if at least one of them is due.
        Processing runs first. If it updated the state, polling is skipped for this call.
        The polling timers are not reset in that case, so polling stays due.
        """
        current_time = monotonic.monotonic()
        should_process = self._should_process(current_time)
        should_poll = self._should_poll(current_time)

        if not should_process and not should_poll:
            return

        l3_balancer_state_pb = self._zk.must_get_l3_balancer_state(self._namespace_id, self._l3_balancer_id)
        state_handler = l3_balancer.L3BalancerStateHandler(l3_balancer_state_pb)
        vectors = None
        if should_process:
            vectors = state_handler.generate_vectors()
            self._do_process(ctx, state_handler, vectors)
            self._reset_processing_timers()
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after _do_process')
            return

        if should_poll:
            if vectors is None:
                # Vectors from the processing pass can be reused otherwise:
                # we only get here if that pass did not update the state.
                vectors = state_handler.generate_vectors()
            self._do_poll(ctx, state_handler, vectors)
            self._reset_polling_timers()
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after _do_poll')
            return

    def _should_process(self, current_time):
        """
        Decide whether the processing pipeline is due.

        Processing is due if any of the following holds:

        * the forced deadline is unset (before the first pass, or after an
          ``L3BalancerStateUpdate`` event);
        * the forced deadline (FORCE_PROCESS_INTERVAL +/- jitter after the last pass) has passed;
        * a batching deadline was set by a delayed event and has passed.

        :type current_time: float
        :rtype: bool
        """
        return (self._force_processing_deadline is None
                or current_time >= self._force_processing_deadline
                or (self._processing_deadline is not None and current_time >= self._processing_deadline)
                )

    def _do_process(self, ctx, state_handler, vectors):
        """
        Run discovery, validation, transport and skip_stuck in this order.

        Stops right after the first step that reports ``state_handler.was_updated``.
        """
        self._discoverer.discover(ctx, state_handler, vectors)
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after discovery')
            return
        self._validator.validate(ctx, state_handler, vectors)
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after validation')
            return
        self._transport.transport(ctx, state_handler, vectors)
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after transport')
            return
        self._transport.skip_stuck(ctx, state_handler, vectors)
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after skip_stuck')
            return

    def _reset_processing_timers(self):
        """
        Close the batching window and schedule the next forced processing.

        The forced deadline is a random whole number of seconds in
        [FORCE_PROCESS_INTERVAL - JITTER, FORCE_PROCESS_INTERVAL + JITTER] from now.
        The jitter presumably spreads out forced passes of many controllers.
        """
        self._waiting_for_processing_since = None
        self._processing_deadline = None
        self._processed_at = monotonic.monotonic()
        self._force_processing_deadline = self._processed_at + random.randint(
            self.FORCE_PROCESS_INTERVAL - self.FORCE_PROCESS_INTERVAL_JITTER,
            self.FORCE_PROCESS_INTERVAL + self.FORCE_PROCESS_INTERVAL_JITTER)

    def _should_poll(self, current_time):
        """
        Decide whether polling is due: never polled yet, or the polling deadline has passed.

        :type current_time: float
        :rtype: bool
        """
        return self._polling_deadline is None or current_time >= self._polling_deadline

    def _do_poll(self, ctx, state_handler, vectors):
        """
        Poll transport configs, then remove obsolete versions.

        Stops after config polling if it updated the state. Obsolete versions are removed
        only when there is an active L3 balancer version.
        """
        self._transport.poll_configs(ctx, state_handler, vectors)
        if state_handler.was_updated:
            ctx.log.debug(u'state was updated after poll_configs')
            return
        if vectors.active.l3_balancer_version is None:
            return
        removed_versions = state_handler.remove_obsolete_versions(vectors.active)
        if state_handler.was_updated:
            ctx.log.debug(u'removed versions: %s', removed_versions)
            ctx.log.debug(u'state was updated after state_handler.remove_obsolete_versions')

    def _reset_polling_timers(self):
        """Record the poll time and schedule the next poll POLL_INTERVAL seconds from now."""
        self._polled_at = monotonic.monotonic()
        self._polling_deadline = self._polled_at + self.POLL_INTERVAL
