import inject
import six
from sepelib.core import config as appconfig

from awacs.lib import l3mgrclient
from awacs.lib.vectors import cacheutil
from awacs.model import cache
from awacs.model.l3_balancer import l3_balancer, events, l3mgr
from infra.awacs.proto import model_pb2
from infra.swatlib import orly_client


# Orly rule consulted right before a new L3mgr config is created (see TransportV2._maybe_apply_orly_brake);
# its metrics are reported to the L3 controller metrics registry.
orly_brake = orly_client.OrlyBrake(metrics_registry=events.L3_CTL_REGISTRY,
                                   rule=u'awacs-push-l3-config')

# Fallback for the `run.max_l3_config_skip_count` setting read by TransportV2.skip_stuck.
DEFAULT_MAX_SKIP_COUNT = 3


class TransportV2(object):
    """
    Pushes the valid vector of one L3 balancer to L3 Manager (L3mgr) and tracks the resulting config.

    One transport cycle consists of:
      * ``transport``: create (or reuse) an L3mgr config matching the valid vector and mark the vector
        as in progress;
      * ``poll_configs``: once L3mgr reports that config as active, mark the in-progress vector as active;
      * ``skip_stuck``: drop an in-progress vector whose config disappeared from L3mgr or failed
        (TEST_FAIL / VCS_FAIL), so that a fresh config can be created.
    """
    __slots__ = (u'_namespace_id', u'_l3_balancer_id')

    _l3mgr_client = inject.attr(l3mgrclient.IL3MgrClient)  # type: l3mgrclient.L3MgrClient
    _cache = inject.attr(cache.IAwacsCache)  # type: cache.AwacsCache

    def __init__(self, namespace_id, l3_balancer_id):
        """
        :param namespace_id: id of the namespace the L3 balancer belongs to
        :param l3_balancer_id: id of the L3 balancer within that namespace
        """
        self._namespace_id = namespace_id
        self._l3_balancer_id = l3_balancer_id

    def _is_ready_to_transport(self, ctx, state_handler, vectors):
        """
        Decide whether ``transport`` may push the valid vector to L3mgr right now.

        Returns False (and logs the reason) when L3 transport is disabled via ``run.disable_l3_transport``,
        transport of this L3 balancer is paused, the valid vector has no L3 balancer version, or a
        previously pushed L3mgr config is still in progress. Setting ``ignore_existing_l3mgr_config`` on
        the state handler bypasses the in-progress check.

        :type ctx: context.OpCtx
        :type state_handler: l3_balancer.L3BalancerStateHandler
        :param vectors: L3 balancer vectors; only ``vectors.valid`` is inspected
        :rtype: bool
        """
        if appconfig.get_value(u'run.disable_l3_transport', default=False):
            ctx.log.warning(u'not transporting: l3 transport is disabled in config')
            return False

        l3_balancer_pb = self._cache.must_get_l3_balancer(self._namespace_id, self._l3_balancer_id)
        if l3_balancer_pb.meta.transport_paused.value:
            ctx.log.debug(u'not transporting: l3 transport is paused')
            return False

        if vectors.valid.l3_balancer_version is None:
            ctx.log.debug(u'not transporting: no valid L3 version in vector')
            return False

        if state_handler.ignore_existing_l3mgr_config:
            ctx.log.debug(u'"ignore_existing_l3mgr_config" is set, will create new config')
            return True

        l3mgr_config_pb = state_handler.get_in_progress_l3mgr_config_pb()
        if l3mgr_config_pb is not None:
            ctx.log.debug(u'not transporting: previous config is already in progress: "%s:%s"',
                          l3mgr_config_pb.service_id, l3mgr_config_pb.config_id)
            return False

        return True

    def transport(self, ctx, state_handler, l3_vectors):
        """
        Update virtual servers and real servers in L3 Manager.

        Ways to update RS:
          1) /vs - create a new VS
          2) /editrs - update real servers of an *active* config
          3) /vs/{vs.id} - update existing VS; discouraged by TT: https://st.yandex-team.ru/TRAFFIC-12199
        Ways to update VS:
          1) /vs - create a new VS
          2) /vs/{vs.id} - update existing VS; discouraged by TT: https://st.yandex-team.ru/TRAFFIC-12199
        Additional consideration about /editrs: behind the scenes it creates new virtual servers
        with updated RS, and then creates a new config with these VS. So it's absolutely identical to
        manually calling multiple "POST /vs" + one "POST /config".
        The common denominator here is /vs, so we'll be using it to update both VS and RS.
        The transporting flow:
        - Compare virtual servers in awacs L3 spec to the latest config in L3 Manager (including their real servers)
        - For each VS with differences, create a new VS
        - Create a new config that:
          1) Includes unchanged VS
          2) Includes new VS that we created earlier
          3) Does not include VS that are present in L3mgr, but not present in awacs L3 spec
        - Process that config
        - Periodically check if that config is active in L3mgr. While we wait, we don't transport anything new.

        This can be simplified when L3mgr implements a new API:
        https://st.yandex-team.ru/TRAFFIC-12209

        :type ctx: context.OpCtx
        :type state_handler: l3_balancer.L3BalancerStateHandler
        :param l3_vectors: vectors of this L3 balancer; ``active`` and ``valid`` are compared, and
            ``valid`` is what gets transported and then marked as in progress
        """
        ctx = ctx.with_op(op_id=u'transport')

        if l3_vectors.active == l3_vectors.valid:
            ctx.log.debug(u'active_vector == valid_vector, nothing to transport')
            return

        if not self._is_ready_to_transport(ctx, state_handler, l3_vectors):
            return
        ctx.log.debug(u'transporting to L3mgr: %s', l3_vectors.valid)

        l3_balancer_spec_pb = cacheutil.must_get_l3_balancer_revision_spec_with_cache(
            self._cache, self._namespace_id, self._l3_balancer_id, l3_vectors.valid.l3_balancer_version.version)
        l3mgr_service_id = l3_balancer_spec_pb.l3mgr_service_id
        with events.reporter.report_error(ctx, u'failed to get service from L3mgr'):
            l3mgr_service = l3mgr.Service.from_api(self._l3mgr_client, l3mgr_service_id)
        # Report failures of this L3mgr read too, like every other L3mgr API call in this class.
        with events.reporter.report_error(ctx, u'failed to get latest config from L3mgr'):
            latest_config = l3mgr.ServiceConfig.latest_from_api_if_exists(self._l3mgr_client, l3mgr_service_id)

        rs_group = self._make_rs_group(l3_vectors.valid, l3_balancer_spec_pb)
        if latest_config is None or state_handler.ignore_existing_l3mgr_config:
            # Start from an empty VS set, so every VS from the spec is added by _update_virtual_servers.
            l3mgr_vs = l3mgr.VirtualServers(l3mgr_service.id, {}, {})
        else:
            l3mgr_vs = self._get_l3mgr_virtual_servers(ctx, l3mgr_service, latest_config,
                                                       allow_l3mgr_configs_mismatch=l3_balancer_spec_pb.enforce_configs)
        config_needs_update = self._update_virtual_servers(l3_balancer_spec_pb, l3mgr_vs, rs_group)

        # Force processing when the service has no active config yet, or when the spec sets skip_tests.
        force_process = not l3mgr_service.has_active_config or l3_balancer_spec_pb.skip_tests.value
        if config_needs_update or latest_config is None:
            ctx.log.debug(u'need to update config in L3mgr')
            in_progress_config = self._update_vs_in_config(ctx, l3_vectors, l3mgr_service, latest_config, l3mgr_vs,
                                                           force_process=force_process)
        elif latest_config.is_new:
            # latest_config is not None here: the branch above handles the None case.
            ctx.log.debug(u'valid vector already matches non-active L3mgr config "%s:%s", will process it',
                          l3mgr_service.id, latest_config.id)
            latest_config.process(self._l3mgr_client, force=force_process)
            in_progress_config = latest_config
        else:
            ctx.log.debug(u'valid vector already matches L3mgr config "%s:%s"', l3mgr_service.id, latest_config.id)
            in_progress_config = latest_config

        state_handler.mark_versions_as_in_progress(versions=list(l3_vectors.valid),
                                                   l3mgr_config_pb=in_progress_config.to_pb())

    def poll_configs(self, ctx, state_handler, l3_vectors):
        """
        Check the in-progress L3mgr config and mark the in-progress vector as active once the config is active.

        Does nothing if no config is in progress. If the config is missing from L3mgr, only a warning is
        logged; ``skip_stuck`` is what resets such a vector.

        :type ctx: context.OpCtx
        :type state_handler: l3_balancer.L3BalancerStateHandler
        :param l3_vectors: vectors of this L3 balancer; ``in_progress`` is passed on activation
        """
        ctx = ctx.with_op(op_id=u'poll_configs')

        l3mgr_config_pb = state_handler.get_in_progress_l3mgr_config_pb()
        if l3mgr_config_pb is None:
            return

        ctx.log.debug(u'polling config "%s:%s"...', l3mgr_config_pb.service_id, l3mgr_config_pb.config_id)
        with events.reporter.report_error(ctx, u'failed to get config from L3mgr'):
            config = l3mgr.ServiceConfig.from_api_if_exists(self._l3mgr_client,
                                                            l3mgr_config_pb.service_id,
                                                            l3mgr_config_pb.config_id)
        if config is None:
            ctx.log.warning(u'config is missing from l3mgr: "%s:%s"',
                            l3mgr_config_pb.service_id,
                            l3mgr_config_pb.config_id)
            return

        if config.is_active:
            ctx.log.debug(u'config is active in L3mgr: "%s:%s"', l3mgr_config_pb.service_id, l3mgr_config_pb.config_id)
            state_handler.handle_l3mgr_config_activation(l3_vectors.in_progress)
            ctx.log.debug(u'marked vector as active')

    def skip_stuck(self, ctx, state_handler, l3_vectors):
        """
        Reset the in-progress vector when its L3mgr config is missing or has failed.

        A config counts as stuck when L3mgr no longer has it, or its state is TEST_FAIL or VCS_FAIL.
        Resetting the in-progress vector lets the next ``transport`` call create a fresh config. Skipping
        is refused (and CONFIG_SKIP_IS_NOT_ALLOWED is reported) once the skip count stored for the
        vector's weak hash exceeds ``run.max_l3_config_skip_count``.

        :type ctx: context.OpCtx
        :type state_handler: l3_balancer.L3BalancerStateHandler
        :param l3_vectors: vectors of this L3 balancer; only ``in_progress`` is used
        """
        ctx = ctx.with_op(op_id=u'skip_stuck_configs')

        if l3_vectors.in_progress.is_empty():
            return
        in_progress_vector_hash = l3_vectors.in_progress.get_weak_hash_str()

        ctx.log.debug(u'in_progress_vector: %s', l3_vectors.in_progress)

        max_skip_count = appconfig.get_value(u'run.max_l3_config_skip_count', default=DEFAULT_MAX_SKIP_COUNT)
        skip_count = state_handler.get_skip_stuck_count(in_progress_vector_hash)
        if skip_count > max_skip_count:
            ctx.log.warning(u'skip count for vector %s is %s and over allowed limit of %s, NOT skipping',
                            in_progress_vector_hash, skip_count, max_skip_count)
            events.reporter.report(events.L3MgrEvent.CONFIG_SKIP_IS_NOT_ALLOWED, ctx=ctx)
            return

        l3mgr_config_pb = state_handler.get_in_progress_l3mgr_config_pb()
        if l3mgr_config_pb is None:
            return
        with events.reporter.report_error(ctx, u'failed to get config from L3mgr'):
            config = l3mgr.ServiceConfig.from_api_if_exists(self._l3mgr_client,
                                                            l3mgr_config_pb.service_id,
                                                            l3mgr_config_pb.config_id)

        if config is None:
            ctx.log.warning(u'Config is missing from l3mgr: "%s:%s", skipping',
                            l3mgr_config_pb.service_id,
                            l3mgr_config_pb.config_id)
            skip_reason = u'NOT_FOUND'
        elif config.state in (u'TEST_FAIL', u'VCS_FAIL'):
            ctx.log.debug(u'Forcing creation of fresh config, because config "%s:%s" has bad state "%s"',
                          l3mgr_config_pb.service_id, l3mgr_config_pb.config_id, config.state)
            skip_reason = config.state
        else:
            # Config exists and is not failed: nothing to skip.
            return

        events.reporter.report(events.L3MgrEvent.CONFIG_SKIPPED, ctx=ctx)
        if skip_reason == u'TEST_FAIL':
            events.reporter.report(events.L3MgrEvent.CONFIG_SKIPPED_DUE_TO_TEST_FAIL, ctx=ctx)
        elif skip_reason == u'VCS_FAIL':
            events.reporter.report(events.L3MgrEvent.CONFIG_SKIPPED_DUE_TO_VCS_FAIL, ctx=ctx)
        elif skip_reason == u'NOT_FOUND':
            events.reporter.report(events.L3MgrEvent.CONFIG_SKIPPED_DUE_TO_NOT_FOUND, ctx=ctx)

        state_handler.reset_in_progress_vector(l3_vectors.in_progress,
                                               author=u'awacs',
                                               comment=u'Forcing creation of fresh config in L3Mgr')

    def _update_vs_in_config(self, ctx, l3_vectors, l3mgr_service, latest_config, l3mgr_vs, force_process):
        """
        Create a new L3mgr config containing ``l3mgr_vs.vs_ids`` and process it.

        The Orly brake is applied first, keyed by the weak hash of the valid vector. Without a latest
        config, a config is created from scratch and always force-processed; otherwise it is derived from
        ``latest_config`` and processed with ``force_process``.

        :type ctx: context.OpCtx
        :param l3_vectors: vectors of this L3 balancer; ``valid`` provides the Orly operation id
        :type l3mgr_service: l3mgr.Service
        :type latest_config: l3mgr.ServiceConfig | None
        :type l3mgr_vs: l3mgr.VirtualServers
        :type force_process: bool
        :return: the config returned by ``l3mgr.ServiceConfig.create_and_process`` /
            ``update_vs_and_process``
        """
        vs_ids = l3mgr_vs.vs_ids
        ctx.log.info(u'Saving config: svc_id="%s", vs_ids="%s"', l3mgr_service.id, vs_ids)
        self._maybe_apply_orly_brake(ctx, op_id=l3_vectors.valid.get_weak_hash_str())
        with events.reporter.report_error(ctx, u'failed to create config'):
            if latest_config is None:
                new_config = l3mgr.ServiceConfig.create_and_process(self._l3mgr_client,
                                                                    l3mgr_service.id,
                                                                    vs_ids,
                                                                    force_process=True)
            else:
                new_config = l3mgr.ServiceConfig.update_vs_and_process(self._l3mgr_client,
                                                                       l3mgr_service.id,
                                                                       latest_config,
                                                                       vs_ids,
                                                                       force_process=force_process)
        ctx.log.debug(u'created l3mgr config: "%s:%s"', l3mgr_service.id, new_config.id)
        events.reporter.report(events.L3MgrEvent.CONFIG_CREATED, ctx=ctx)
        return new_config

    def _make_rs_group(self, vector, l3_balancer_spec_pb):
        """
        Collect real servers for ``vector`` from the endpoint sets the L3 balancer uses.

        An endpoint set contributes its instances only if it is not deleted, the vector has a non-deleted
        backend with the same (namespace_id, id), and that id is among the backends included by
        ``l3_balancer_spec_pb``. Each instance is added with its IPv6 address (IPv4 as a fallback) and the
        weight computed by ``l3_balancer.get_instance_weight``.

        :param vector: L3 balancer vector with ``endpoint_set_versions`` and ``backend_versions``
        :param l3_balancer_spec_pb: L3 balancer spec of the transported revision
        :rtype: l3mgr.RSGroup
        """
        rs_group = l3mgr.RSGroup()
        included_backend_ids = l3_balancer.get_included_backend_ids(self._namespace_id, l3_balancer_spec_pb)
        for full_es_id, endpoint_set_version in six.iteritems(vector.endpoint_set_versions):
            if endpoint_set_version.deleted:
                continue
            # An endpoint set is only used together with a live backend of the same full id.
            if full_es_id not in vector.backend_versions:
                continue
            if vector.backend_versions[full_es_id].deleted:
                continue
            if full_es_id not in included_backend_ids:
                continue
            ns_id, es_id = full_es_id
            es_spec_pb = cacheutil.must_get_endpoint_set_revision_with_cache(
                self._cache,
                namespace_id=ns_id,
                endpoint_set_id=es_id,
                version=endpoint_set_version.version).spec
            for instance_pb in es_spec_pb.instances:  # type: model_pb2.EndpointSetSpec.Instance
                fqdn = instance_pb.host
                ip = instance_pb.ipv6_addr or instance_pb.ipv4_addr
                weight = l3_balancer.get_instance_weight(l3_balancer_spec_pb, instance_pb)
                rs_group.add(fqdn, ip, weight)
        return rs_group

    def _get_l3mgr_virtual_servers(self, ctx, l3mgr_service, latest_config, allow_l3mgr_configs_mismatch):
        """
        Load the virtual servers of ``latest_config`` from L3mgr.

        If ``latest_config`` is the service's current config, the VS already embedded in ``l3mgr_service``
        are used; otherwise the VS are fetched from the API by ``latest_config.vs_ids``.

        :type ctx: context.OpCtx
        :type l3mgr_service: l3mgr.Service
        :type latest_config: l3mgr.ServiceConfig
        :param allow_l3mgr_configs_mismatch: passed through to ``l3mgr.VirtualServers`` constructors
        :rtype: l3mgr.VirtualServers
        """
        if latest_config.id == l3mgr_service.config_id:
            # current config is active, so we can use VS directly from the service
            with events.reporter.report_error(ctx):
                return l3mgr.VirtualServers.from_l3mgr_raw_virtual_servers(
                    l3mgr_service.id,
                    l3mgr_service.virtual_servers,
                    allow_l3mgr_configs_mismatch=allow_l3mgr_configs_mismatch)
        ctx.log.debug(u'latest config "%s" is not active yet, fetching VS individually by IDs: %s',
                      latest_config.id, latest_config.vs_ids)
        with events.reporter.report_error(ctx):
            return l3mgr.VirtualServers.from_api(self._l3mgr_client, l3mgr_service.id, latest_config.vs_ids,
                                                 allow_l3mgr_configs_mismatch=allow_l3mgr_configs_mismatch)

    def _update_virtual_servers(self, l3_balancer_spec_pb, l3mgr_vs, rs_group):
        """
        Bring ``l3mgr_vs`` (mutated in place) in line with the virtual servers of the awacs spec.

        Every spec VS is added or modified via ``l3mgr_vs.add_or_modify_vs``; VS keyed by an (ip, port)
        that the spec does not mention are removed.

        :param l3_balancer_spec_pb: L3 balancer spec of the transported revision
        :type l3mgr_vs: l3mgr.VirtualServers
        :type rs_group: l3mgr.RSGroup
        :return: True if any VS was added, modified or removed
        :rtype: bool
        """
        config_needs_update = False

        # replace updated VS
        used_vs = set()
        for vs_pb in l3_balancer_spec_pb.virtual_servers:
            config_needs_update |= l3mgr_vs.add_or_modify_vs(self._l3mgr_client, vs_pb, rs_group)
            used_vs.add((vs_pb.ip, vs_pb.port))

        # remove unused VS from L3mgr; iterate over a copy of the keys because the dict is mutated
        for (ip, port) in list(l3mgr_vs.virtual_servers.keys()):
            if (ip, port) not in used_vs:
                l3mgr_vs.virtual_servers.pop((ip, port))
                config_needs_update = True

        return config_needs_update

    def _maybe_apply_orly_brake(self, ctx, op_id):
        """
        Consult the module-level Orly brake for operation ``op_id``, labelled with this L3 balancer's ids.

        :type ctx: context.OpCtx
        :param op_id: operation id passed to ``orly_brake.maybe_apply``
        """
        ctx.log.debug(u'orly operation id: %s', op_id)
        orly_brake.maybe_apply(op_id=op_id, op_log=ctx.log, op_labels=[
            (u'namespace-id', self._namespace_id),
            (u'l3-balancer-id', self._l3_balancer_id),
        ])
