# coding: utf-8
import random

import gevent.event
import gevent.queue
import gevent.threadpool
import inject
import monotonic
from sepelib.core import config

from awacs.lib import ctlmanager
from awacs.lib.context import OpLoggerAdapter
from awacs.lib.order_processor.model import has_actionable_spec
from awacs.lib.strutils import flatten_full_id
from awacs.model import events, zk, cache, util, objects
from awacs.model.balancer import validator, discoverer, transport, registry
from awacs.model.balancer.vector import EndpointSetVersion
from awacs.model.util import get_balancer_location
from infra.swatlib.logutil import rndstr


class BalancerCtl(ctlmanager.ContextedCtl):
    """
    Controller of a single balancer identified by (namespace_id, balancer_id).

    On start it builds three helpers and feeds them the current balancer state:

    * ``validator.BalancerValidator`` (runs ``validate`` during forced processing);
    * ``discoverer.BalancerDiscoverer`` (processes namespace updates and balancer state updates);
    * ``transport.BalancerTransport`` (started unless ``run.disable_transport`` is set).

    Cache events are filtered by ``_accept_event`` and recorded by ``_process_event``.
    The actual processing is deferred, rate-limited and jittered in ``_process_empty_queue``.
    NOTE(review): the calling order of these hooks is defined by ``ctlmanager.ContextedCtl``,
    which is not visible here.
    """
    # Passed as-is to BalancerTransport; presumably seconds -- confirm in transport.
    TRANSPORT_PROCESSING_INTERVAL = 30
    TRANSPORT_POLLING_INTERVAL = 30
    TRANSPORT_MAIN_LOOP_FREQ = 10

    # Seconds (compared against monotonic.monotonic() deltas): how long pending events wait
    # before being processed, +/- a random jitter.
    PROCESS_INTERVAL = 25
    PROCESS_INTERVAL_JITTER = 10
    # Seconds since the last namespace update processing after which processing is forced
    # even without pending events, +/- a random jitter.
    FORCE_PROCESS_INTERVAL = 500
    FORCE_PROCESS_INTERVAL_JITTER = 120
    # Not used in this class; presumably read by the base controller -- confirm.
    EVENTS_QUEUE_GET_TIMEOUT = 10
    # While the controller is younger than this (seconds), the pending-events branch
    # of _process_empty_queue does not run forced validation.
    JUST_STARTED_PERIOD = FORCE_PROCESS_INTERVAL + 2 * FORCE_PROCESS_INTERVAL_JITTER

    _processed_counter = registry.L7_CTL_REGISTRY.get_counter(u'processed')
    _force_processed_counter = registry.L7_CTL_REGISTRY.get_counter(u'force-processed')

    _zk = inject.attr(zk.IZkStorage)  # type: zk.ZkStorage
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    def __init__(self, namespace_id, balancer_id):
        """
        :param namespace_id: id of the namespace the balancer belongs to
        :param balancer_id: id of the balancer inside the namespace
        """
        name = 'balancer-ctl("{}:{}")'.format(namespace_id, balancer_id)
        super(BalancerCtl, self).__init__(name)

        self._namespace_id = namespace_id
        self._balancer_id = balancer_id
        self._full_balancer_id = (namespace_id, balancer_id)
        self._is_large = util.is_large_balancer(namespace_id, balancer_id)

        # Helpers are created in _start().
        self._validator = None  # type: validator.BalancerValidator or None
        self._transport = None  # type: transport.BalancerTransport or None
        self._discoverer = None  # type: discoverer.BalancerDiscoverer or None
        self._threadpool = None

        # Monotonic timestamps (or None) driving the scheduling in _process_empty_queue.
        self._last_process_namespace_update_at = None
        self._process_pending_since = None
        self._force_revalidation = False

        # Event paths are "<namespace_id>/<object_id>".
        self._balancer_path = '{}/{}'.format(self._namespace_id, self._balancer_id)
        self._namespace_path_prefix = '{}/'.format(self._namespace_id)

    def _accept_event(self, event):
        """
        Decide whether a cache event is relevant to this balancer.

        Accepted events are:
        * updates of this balancer with an actionable spec;
        * updates/removals of this balancer's state;
        * updates of endpoint sets/backends in this namespace, or global ones from any namespace;
        * updates/removals of other namespace objects (certs, domains, upstreams, weight sections,
          knobs, components, endpoint set/backend removals) in this namespace.

        Nothing is accepted once the balancer is missing from the cache or marked deleted.

        :type event: events.*
        :rtype: bool
        """
        if isinstance(event, events.BalancerUpdate):
            should_process = event.path == self._balancer_path and has_actionable_spec(event.pb)
        elif isinstance(event, (events.BalancerStateUpdate, events.BalancerStateRemove)):
            should_process = event.path == self._balancer_path
        elif isinstance(event, (events.EndpointSetUpdate, events.BackendUpdate)):
            # Global endpoint sets/backends may live in other namespaces.
            should_process = event.path.startswith(self._namespace_path_prefix) or event.pb.spec.is_global.value
        elif isinstance(event, (
                events.EndpointSetRemove, events.BackendRemove,
                events.CertUpdate, events.CertRemove,
                events.DomainUpdate, events.DomainRemove,
                events.UpstreamUpdate, events.UpstreamRemove,
                objects.WeightSection.cache.legacy_update_event, objects.WeightSection.cache.legacy_remove_event,
                events.KnobUpdate, events.KnobRemove,
                events.ComponentUpdate, events.ComponentRemove,
        )):
            should_process = event.path.startswith(self._namespace_path_prefix)
        else:
            should_process = False

        if should_process:
            return not self._is_deleted()
        return False

    def _start(self, ctx):
        """
        Build the validator, discoverer and transport, then subscribe to cache updates.

        Each helper is seeded with the current balancer state read from ``self._zk``.
        The transport is started unless ``run.disable_transport`` is set.
        """
        ctx.log.debug('Starting...')

        # Large balancers get a dedicated single-thread pool, handed to the validator.
        self._threadpool = None
        if self._is_large:
            self._threadpool = gevent.threadpool.ThreadPool(1)

        balancer_pb = self._cache.must_get_balancer(namespace_id=self._namespace_id, balancer_id=self._balancer_id)
        self._validator = validator.BalancerValidator(
            namespace_id=self._namespace_id,
            balancer_id=self._balancer_id,
            threadpool=self._threadpool)
        balancer_location = get_balancer_location(balancer_pb, logger=ctx.log)
        self._discoverer = discoverer.BalancerDiscoverer(
            namespace_id=self._namespace_id,
            balancer_id=self._balancer_id,
            balancer_location=balancer_location)
        self._transport = transport.BalancerTransport(
            namespace_id=self._namespace_id,
            balancer_id=self._balancer_id,
            processing_interval=self.TRANSPORT_PROCESSING_INTERVAL,
            polling_interval=self.TRANSPORT_POLLING_INTERVAL,
            main_loop_freq=self.TRANSPORT_MAIN_LOOP_FREQ,
            log=self._log,
        )

        self._cache.bind(self._callback)

        balancer_state_pb = self._zk.must_get_balancer_state(namespace_id=self._namespace_id,
                                                             balancer_id=self._balancer_id)
        ctx.log.debug('Balancer state generation: %d', balancer_state_pb.generation)

        self._validator.set_balancer_state_pb(balancer_state_pb)
        self._transport.set_balancer_state_pb(balancer_state_pb, ctx)
        self._discoverer.set_balancer_state_pb(balancer_state_pb, ctx)

        if self._is_large:
            ctx.log.debug('Balancer is considered as large, polling snapshots...')
            try:
                # it's useful to poll snapshots, find some of them activated and clear the balancer state
                # by removing in_progress statuses and old active revisions before we start validation
                self._transport.poll_snapshots(ctx)
            except ctlmanager.UNEXPECTED_EXCEPTIONS as e:
                # Best-effort: a failed poll must not prevent the controller from starting.
                ctx.log.exception('Failed to poll snapshots on start: %s', e)

        self._process_pending_since = None
        self._started_at = self._last_process_namespace_update_at = monotonic.monotonic()
        self._force_revalidation = config.get_value('run.force_revalidation_on_balancer_ctl_start', True)

        if not config.get_value('run.disable_transport', default=False):
            self._transport.start()

        ctx.log.debug('Started')

    def _stop(self):
        """
        Unsubscribe from the cache, kill the threadpool (if any) and stop the transport (if any).

        Log messages are collected in a list and emitted only after all shutdown steps finish.
        NOTE(review): presumably this avoids logging between shutdown steps -- confirm.
        """
        op_id = rndstr()
        op_log = OpLoggerAdapter(log=self._log, op_id=op_id)
        messages = [('Stopping...', ())]
        self._cache.unbind(self._callback)
        messages.append(('Unbound callback', ()))
        if self._threadpool is not None:
            messages.append(('Stopping threadpool', ()))
            c = 1
            # kill() may be interrupted by GreenletExit; keep retrying until it completes.
            while 1:
                try:
                    self._threadpool.kill()
                except gevent.GreenletExit:
                    c += 1
                else:
                    break
            messages.append(('Stopped threadpool in %d attempts', (c,)))
        if self._transport:
            messages.append(('Stopping transport', ()))
            attempts = self._transport.stop(ignore_greenlet_exit=True, log=op_log)
            messages.append(('Stopped transport in %d attempts', (attempts,)))
        messages.append(('Stopped', ()))
        for msg, arg in messages:
            op_log.debug(msg, *arg)

    def _handle_balancer_state_update(self, pb, ctx):
        """
        Pass a balancer state update to the discoverer first.

        The validator and transport receive it only if the discoverer did not update
        the state itself.
        """
        is_state_updated = self._discoverer.handle_balancer_state_update(pb, ctx)
        if not is_state_updated:
            self._validator.handle_balancer_state_update(pb, ctx)
            self._transport.handle_balancer_state_update(pb, ctx)

    def _process(self, ctx):
        """
        Let the discoverer process namespace updates and record when that happened.

        :returns: Whether the balancer state has been updated.
        :rtype: bool
        """
        rv = self._discoverer.process_namespace_update(ctx)
        self._last_process_namespace_update_at = monotonic.monotonic()
        self._processed_counter.inc(1)
        return rv

    def _force_process(self, ctx):
        """
        Process namespace updates, then re-read the balancer state and validate it.

        The steps are: namespace update, a fresh balancer state read from ``self._zk``
        passed to the discoverer, and finally ``validate`` on the validator.
        """
        self._discoverer.process_namespace_update(ctx)
        balancer_state_pb = self._zk.must_get_balancer_state(namespace_id=self._namespace_id,
                                                             balancer_id=self._balancer_id)
        self._discoverer.handle_balancer_state_update(balancer_state_pb, ctx)
        self._validator.validate(ctx)
        self._last_process_namespace_update_at = monotonic.monotonic()
        self._force_processed_counter.inc(1)

    def _is_deleted(self):
        """
        :returns: True if the balancer is absent from the cache or its spec is marked deleted.
        :rtype: bool
        """
        pb = self._cache.get_balancer(self._namespace_id, self._balancer_id)
        # Explicit bool: previously this fell through and returned None for live balancers.
        return bool(not pb or pb.spec.deleted)

    def _event_should_force_revalidation(self, ctx, event):
        """
        Decide whether an EndpointSetUpdate event must trigger a forced revalidation.

        There is a flaw in the awacs models: the list of backend revisions for which an endpoint set
        is valid is stored in the EndpointSetMeta.backend_revisions field.
        Even if nothing changes (no specs changed, no revisions created),
        EndpointSetMeta.backend_revisions can still affect the validity of the current vector.

        Suppose someone changes a backend's spec so that its endpoint set remains the same.
        The balancer controller sees the change before the backend controller does and adds it
        to the balancer state.

        For the new backend version to become valid, the validator then needs to be notified when
        the backend controller sees the change and updates EndpointSetMeta.backend_revisions accordingly.

        We do this by setting self._force_revalidation to true.

        :returns: True if, in the discoverer, the latest and latest-valid versions differ for either
            the backend or the endpoint set with the event's id.
        :rtype: bool
        """
        endpoint_set_pb = event.pb
        endpoint_set_version = EndpointSetVersion.from_pb(endpoint_set_pb)
        endpoint_set_id = endpoint_set_version.endpoint_set_id

        endpoint_set_latest_versions = self._discoverer.endpoint_set_latest_versions
        endpoint_set_valid_versions = self._discoverer.endpoint_set_latest_valid_versions
        backend_latest_versions = self._discoverer.backend_latest_versions
        backend_valid_versions = self._discoverer.backend_latest_valid_versions

        curr_endpoint_set_version = endpoint_set_latest_versions.get(endpoint_set_id)
        valid_endpoint_set_version = endpoint_set_valid_versions.get(endpoint_set_id)
        # Backend versions are looked up by the endpoint set id, i.e. this relies on an
        # endpoint set and its backend sharing the same id.
        curr_backend_version = backend_latest_versions.get(endpoint_set_id)
        valid_backend_version = backend_valid_versions.get(endpoint_set_id)

        backend_is_not_valid = curr_backend_version != valid_backend_version
        endpoint_set_is_not_valid = curr_endpoint_set_version != valid_endpoint_set_version
        if backend_is_not_valid or endpoint_set_is_not_valid:
            ctx.log.debug(
                '%s (version %s, gen %s) is present in balancer state but not valid, '
                'setting force_revalidation to true '
                '(backend_is_not_valid: %s, endpoint_set_is_not_valid: %s)...',
                flatten_full_id(self._namespace_id, endpoint_set_version.endpoint_set_id),
                endpoint_set_version.version[:10],
                endpoint_set_pb.meta.generation,
                backend_is_not_valid,
                endpoint_set_is_not_valid
            )
            return True
        return False

    def _is_backend_used(self, full_backend_id):
        """
        :returns: True if the full id appears in any of the discoverer's backend or
            endpoint set version maps (latest or latest-valid).
        :rtype: bool
        """
        return (full_backend_id in self._discoverer.backend_latest_versions or
                full_backend_id in self._discoverer.backend_latest_valid_versions or
                full_backend_id in self._discoverer.endpoint_set_latest_versions or
                full_backend_id in self._discoverer.endpoint_set_latest_valid_versions)

    def _process_event(self, ctx, event):
        """
        Handle one accepted event.

        * BalancerStateUpdate events are handled immediately.
        * Other events mark processing as pending, unless it is already pending.
          Updates of global endpoint sets/backends schedule processing only if this balancer
          uses them.
        * EndpointSetUpdate events may additionally set ``_force_revalidation``.
        """
        if isinstance(event, events.BalancerStateUpdate):
            return self._handle_balancer_state_update(event.pb, ctx)

        if self._process_pending_since is None:
            if isinstance(event, (events.EndpointSetUpdate, events.BackendUpdate)) and event.pb.spec.is_global.value:
                # If we receive an update of a global backend or endpoint set, we schedule processing
                # if and only if we use this backend ("we use" == it's present in the balancer state)
                full_id = (event.pb.meta.namespace_id, event.pb.meta.id)
                if self._is_backend_used(full_id):
                    self._process_pending_since = monotonic.monotonic()
            else:
                self._process_pending_since = monotonic.monotonic()

        if isinstance(event, events.EndpointSetUpdate):
            self._force_revalidation = self._force_revalidation or self._event_should_force_revalidation(ctx, event)

    def _run_forced_validation(self, ctx):
        """
        Run ``_force_process`` and always clear ``_force_revalidation`` afterwards.

        The flag is cleared even if ``_force_process`` raises; the exception still propagates.
        """
        ctx.log.debug('starting forced validation')
        try:
            self._force_process(ctx)
        finally:
            ctx.log.debug('finished forced validation')
            self._force_revalidation = False

    def _process_empty_queue(self, ctx):
        """
        Periodic tick: do deferred processing when it is due.

        1. If processing has been pending for longer than a jittered PROCESS_INTERVAL, process it.
           Forced validation follows only if all of these hold: the state was not updated,
           the controller is past JUST_STARTED_PERIOD, and a revalidation was requested.
        2. Otherwise, if no namespace processing happened for a jittered FORCE_PROCESS_INTERVAL,
           run forced validation (if requested) or plain processing.
        3. Otherwise, only log why nothing was done.

        Does nothing if the balancer is deleted.
        """
        if self._is_deleted():
            return
        current_timer = monotonic.monotonic()

        process_interval = self.PROCESS_INTERVAL + random.randint(-self.PROCESS_INTERVAL_JITTER,
                                                                  self.PROCESS_INTERVAL_JITTER)
        pending_since = self._process_pending_since
        if pending_since is not None and current_timer - pending_since > process_interval:
            ctx.log.debug('processing events pending since %s, now is %s', pending_since, current_timer)
            self._process_pending_since = None
            is_state_updated = self._process(ctx)
            has_ctl_just_started = current_timer - self._started_at <= self.JUST_STARTED_PERIOD
            ctx.log.debug('processed events, is_state_updated: %s, has_ctl_just_started: %s',
                          is_state_updated, has_ctl_just_started)
            if not is_state_updated and not has_ctl_just_started and self._force_revalidation:
                self._run_forced_validation(ctx)
            return

        force_process_interval = self.FORCE_PROCESS_INTERVAL + random.randint(-self.FORCE_PROCESS_INTERVAL_JITTER,
                                                                              self.FORCE_PROCESS_INTERVAL_JITTER)
        last_update_at = self._last_process_namespace_update_at
        if last_update_at is not None and current_timer - last_update_at > force_process_interval:
            ctx.log.debug('force processing')
            if self._force_revalidation:
                self._run_forced_validation(ctx)
            else:
                self._process(ctx)
        else:
            ctx.log.debug('Not processing yet: '
                          '_process_pending_since=%s, current_timer=%s process_interval=%s, '
                          'force_process_interval=%s, _last_process_namespace_update_at=%s',
                          self._process_pending_since, current_timer,
                          process_interval, force_process_interval, self._last_process_namespace_update_at)
