import collections
import itertools

import inject
import six
from sepelib.core import config as appconfig

from awacs.lib import l7heavy_client
from awacs.model import cache, objects, util
from awacs.model.l7heavy_config import consts, events
from infra.swatlib import orly_client


# Module-wide Orly brake for pushing L7heavy configs; consulted via
# L7HeavyConfigTransport._maybe_apply_orly_brake before sections are saved.
# Its metrics are reported into the L7heavy controller's registry.
orly_brake = orly_client.OrlyBrake(
    metrics_registry=events.L7HEAVY_CTL_REGISTRY,
    rule=u'awacs-push-l7heavy-config',
)


class L7HeavyConfigTransport(object):
    """Transports the valid L7heavy config vector to L7Heavy and pushes its weights to ITS.

    Every method is a classmethod or staticmethod. The only class-level state is the
    injected L7Heavy API client below.
    """
    # L7Heavy API client, resolved through `inject` at attribute access time.
    _l7heavy_client = inject.attr(l7heavy_client.IL7HeavyClient)  # type: l7heavy_client.L7HeavyClient

    @classmethod
    def _is_ready_to_transport(cls, ctx, vectors):
        """Return True if the valid vector may be transported to L7Heavy.

        Returns False (and logs why) when transport is disabled via the
        ``run.disable_l7heavy_transport`` config value, when the valid vector has no
        L7heavy config version, or when that config has ``meta.transport_paused`` set.
        """
        if appconfig.get_value(u'run.disable_l7heavy_transport', default=False):
            ctx.log.warning(u'not transporting: l7heavy transport is disabled in config')
            return False

        if vectors.valid.l7heavy_config_version is None:
            ctx.log.debug(u'not transporting: no valid L7heavy config version in vector')
            return False

        l7heavy_config_pb = objects.L7HeavyConfig.cache.must_get(*vectors.valid.l7heavy_config_version.id)
        if l7heavy_config_pb.meta.transport_paused.value:
            ctx.log.debug(u'not transporting: l7heavy transport is paused')
            return False

        return True

    @staticmethod
    def _maybe_apply_orly_brake(ctx, op_id, namespace_id, l7heavy_config_id):
        """Consult the module-level Orly brake for this push operation.

        :param op_id: operation id reported to Orly (``_sync_sections`` passes the
            valid vector's weak hash).
        :param namespace_id: awacs namespace id, attached as an Orly label.
        :param l7heavy_config_id: awacs L7HeavyConfig id, attached as an Orly label.

        NOTE(review): presumably ``maybe_apply`` raises when Orly denies the operation,
        which aborts the push. Confirm against ``orly_client.OrlyBrake``.
        """
        ctx.log.debug(u'orly operation id: %s', op_id)
        orly_brake.maybe_apply(op_id=op_id, op_log=ctx.log, op_labels=[
            (u'namespace-id', namespace_id),
            (u'l7heavy-config-id', l7heavy_config_id),
        ])

    @staticmethod
    def is_section_synced(current, target):
        """Return True if the ``current`` L7Heavy section already matches ``target``.

        Location and fallback-location lists are compared as whole dicts, ignoring
        order (both sides are sorted by ``id``). ``exclude_from_bulk_actions`` must
        match exactly. Managers are only checked for inclusion: every target login and
        every target group must be present in ``current``. Extra managers in
        ``current`` are tolerated.

        :param current: section dict as returned by L7Heavy. ``fallback_locations``,
            ``exclude_from_bulk_actions`` and ``managers`` may be absent.
        :param target: section dict as built by ``_sync_sections``.
        """
        def by_id(loc):
            return loc['id']

        if sorted(target['locations'], key=by_id) != sorted(current['locations'], key=by_id):
            return False
        # L7Heavy sections may lack 'fallback_locations' (_sync_sections reads it with .get too).
        if sorted(target['fallback_locations'], key=by_id) != sorted(current.get('fallback_locations', ()),
                                                                       key=by_id):
            return False
        if target['exclude_from_bulk_actions'] != current.get('exclude_from_bulk_actions'):
            return False

        current_managers = current.get('managers', {})
        current_logins = current_managers.get('logins', ())
        if any(login not in current_logins for login in target['managers']['logins']):
            return False
        current_groups = current_managers.get('groups', ())
        # Target *groups* must be checked here (previously target logins were checked twice).
        if any(group not in current_groups for group in target['managers']['groups']):
            return False
        return True

    @classmethod
    def _sync_config(cls, ctx, l7heavy_config_spec_pb):
        """Make the L7Heavy config's ``group_id`` match the spec, uploading only on change.

        The version returned by ``get_config`` is passed back to ``update_config``.
        NOTE(review): presumably this is an optimistic-concurrency check. Confirm in
        the L7Heavy client.
        """
        config_id = l7heavy_config_spec_pb.l7_heavy_config_id
        with events.reporter.report_error(ctx, u'failed to get config from L7Heavy'):
            version, config = cls._l7heavy_client.get_config(config_id)

        if config['group_id'] == l7heavy_config_spec_pb.group_id:
            return
        config['group_id'] = l7heavy_config_spec_pb.group_id
        with events.reporter.report_error(ctx, u'failed to upload config to L7Heavy'):
            cls._l7heavy_client.update_config(config_id, version, config)

    @classmethod
    def _sync_sections(cls, ctx, vectors, l7heavy_config_id, ns_id, l7hc_id):
        """Bring the L7Heavy config sections in line with the valid weight sections.

        One target section is built per weight-section revision in the valid vector.
        Locations already present in L7Heavy keep their current weight. New locations
        in an existing section get weight 0. Locations of a brand-new section start at
        their default weight. Section managers are the namespace's ITS ACL plus the
        Marty staff group.

        If any section is new, out of sync, or no longer present in the vector, the
        Orly brake is consulted, the full target list is saved, and the weights of the
        resulting version are pushed to ITS.

        :param l7heavy_config_id: config id on the L7Heavy side.
        :param ns_id: awacs namespace id (source of ACLs, Orly label).
        :param l7hc_id: awacs L7HeavyConfig id (Orly label).
        """
        with events.reporter.report_error(ctx, u'failed to get config sections from L7Heavy'):
            version, sections = cls._l7heavy_client.get_config_sections(l7heavy_config_id)

        with events.reporter.report_error(ctx, u'failed to get ITS version'):
            current_its_version = cls._l7heavy_client.get_its_version(l7heavy_config_id)

        current_weights_by_section_and_location = collections.defaultdict(dict)
        sections_by_id = {}
        for section in sections:
            # Register every section, even one without any locations. Otherwise it
            # would be mistaken for a new section and its weights reset to defaults.
            sections_by_id[section['id']] = section
            for location in itertools.chain(section['locations'], section.get('fallback_locations', ())):
                current_weights_by_section_and_location[section['id']][location['id']] = location['weight']

        namespace_pb = cache.IAwacsCache.instance().must_get_namespace(ns_id)
        target_sections = []
        need_update = False
        for full_ws_id in sorted(vectors.valid.weight_section_versions):
            spec_pb = objects.WeightSection.find_rev_spec(vectors.valid.weight_section_versions[full_ws_id])
            ws_id = full_ws_id[1]
            section_exists = ws_id in sections_by_id
            current_weights = current_weights_by_section_and_location[ws_id]

            locations = []
            fallback_locations = []
            for location_pb in sorted(spec_pb.locations, key=lambda loc: loc.name):
                if location_pb.name in current_weights:
                    # Do not change current weights for existing locations
                    weight = current_weights[location_pb.name]
                elif section_exists:
                    # Add new locations with zero weight to existing section
                    weight = 0
                else:
                    # Create new sections with current weights = defaults
                    weight = location_pb.default_weight

                loc = {
                    'id': location_pb.name,
                    'default_weight': location_pb.default_weight,
                    'weight': weight
                }
                if location_pb.is_fallback:
                    fallback_locations.append(loc)
                else:
                    locations.append(loc)

            target_section = {
                'id': ws_id,
                'exclude_from_bulk_actions': spec_pb.exclude_from_bulk_actions.value,
                'locations': locations,
                'fallback_locations': fallback_locations,
                'managers': {
                    'logins': list(namespace_pb.spec.its.acl.logins),
                    'groups': list(namespace_pb.spec.its.acl.staff_group_ids) + [consts.MARTY_STAFF_GROUP_ID]
                }
            }
            target_sections.append(target_section)

            if not section_exists or not cls.is_section_synced(sections_by_id[ws_id], target_section):
                need_update = True

        # Compare the section ids themselves. A length-only check misses a section
        # being removed from the vector while another one is added.
        current_ids = collections.Counter(section['id'] for section in sections)
        target_ids = collections.Counter(section['id'] for section in target_sections)
        if not need_update and current_ids == target_ids:
            return

        cls._maybe_apply_orly_brake(ctx, vectors.valid.get_weak_hash_str(), ns_id, l7hc_id)
        with events.reporter.report_error(ctx, u'failed to update config sections in L7Heavy'):
            new_version, _ = cls._l7heavy_client.save_sections(l7heavy_config_id, version, target_sections)

        with events.reporter.report_error(ctx, u'failed to push L7Heavy weights to ITS'):
            cls._l7heavy_client.push_weights_to_its(l7heavy_config_id, new_version, current_version=current_its_version)

    @classmethod
    def transport(cls, ctx, state_handler, vectors):
        """Push the valid vector to L7Heavy and mark it active.

        Does nothing (returns None) when the active vector already equals the valid
        one, or when ``_is_ready_to_transport`` refuses. Otherwise syncs the L7Heavy
        config and sections, then returns the result of
        ``state_handler.mark_version_as_active(vectors.valid)``.
        """
        ctx = ctx.with_op(op_id=u'transport')

        if vectors.active == vectors.valid:
            ctx.log.debug(u'active_vector == valid_vector, nothing to transport')
            return

        if not cls._is_ready_to_transport(ctx, vectors):
            return
        ctx.log.debug(u'transporting to l7heavy: %s', vectors.valid)

        l7heavy_config_spec_pb = objects.L7HeavyConfig.find_rev_spec(vectors.valid.l7heavy_config_version)
        ns_id, l7hc_id = vectors.valid.l7heavy_config_version.id

        # The sync order depends on the target group. NOTE(review): presumably sections
        # must be in place before the config joins the common group, and the group must
        # be switched before sections are rewritten otherwise. Confirm with L7Heavy.
        if l7heavy_config_spec_pb.group_id == util.COMMON_L7HEAVY_GROUP_ID:
            cls._sync_sections(ctx, vectors, l7heavy_config_spec_pb.l7_heavy_config_id, ns_id, l7hc_id)
            cls._sync_config(ctx, l7heavy_config_spec_pb)
        else:
            cls._sync_config(ctx, l7heavy_config_spec_pb)
            cls._sync_sections(ctx, vectors, l7heavy_config_spec_pb.l7_heavy_config_id, ns_id, l7hc_id)

        return state_handler.mark_version_as_active(vectors.valid)
