# coding: utf-8
import collections

import gevent
import itertools
import six
from google.protobuf import descriptor as pbdescriptor
from sepelib.core import config
from typing import List, Tuple
from six.moves import map

from awacs.lib import context
from awacs.lib.strutils import flatten_full_id, quote_join_sorted
from awacs.model.components import iter_balancer_components
from awacs.yamlparser.core import Builder
from infra.awacs.proto import modules_pb2, model_pb2
from infra.swatlib.pbutil import enum_value_to_name
from . import defs
from .config import Config
from .errors import ValidationError, KnobDoesNotExist, CertDoesNotExist
from .util import append_field_name_to_validation_error, Value, validate, validate_func_name_one_of


class DomainsSet(object):
    """
    Container for domain configs, split by domain type: any number of COMMON domains,
    at most one YANDEX_TLD domain and at most one domain of any other type
    (the latter is stored in `wildcard`).
    """
    __slots__ = ('_common_list', 'yandex_tld', 'wildcard')

    def __init__(self):
        self._common_list = []  # type: list[(tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config)]
        self.yandex_tld = None  # type: (tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config)
        self.wildcard = None  # type: (tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config)

    def _sorted_common(self):
        """
        :rtype: list[(tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config)]
        :returns: a new list of COMMON domains ordered by domain id
        """
        # Sort by id only: comparing whole (id, config) tuples would fall through to comparing
        # protobuf messages whenever two ids are equal, which raises TypeError on Python 3.
        return sorted(self._common_list, key=lambda item: item[0])

    def list_common_domains(self):
        """
        :rtype: list[(tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config)]
        :returns: COMMON domains ordered by domain id
        """
        return self._sorted_common()

    def add(self, domain_id, domain_config_pb):
        """
        Stores the domain in the slot matching `domain_config_pb.type`.

        :type domain_id: tuple[six.text_type, six.text_type]
        :type domain_config_pb: model_pb2.DomainSpec.Config
        :raises AssertionError: if a second YANDEX_TLD domain or a second non-COMMON,
            non-YANDEX_TLD domain is added
        """
        if domain_config_pb.type == domain_config_pb.COMMON:
            self._common_list.append((domain_id, domain_config_pb))
        elif domain_config_pb.type == domain_config_pb.YANDEX_TLD:
            assert self.yandex_tld is None, 'yandex_tld domain is already set'
            self.yandex_tld = (domain_id, domain_config_pb)
        else:
            # every remaining type is treated as a wildcard domain
            assert self.wildcard is None, 'wildcard domain is already set'
            self.wildcard = (domain_id, domain_config_pb)

    def list_all_domains(self):
        """
        :rtype: List[Tuple[Tuple[six.text_type, six.text_type], model_pb2.DomainSpec.Config]]
        :returns: COMMON domains ordered by domain id, followed by the YANDEX_TLD domain (if any)
            and then by the wildcard domain (if any)
        """
        rv = self._sorted_common()
        if self.yandex_tld is not None:
            rv.append(self.yandex_tld)
        if self.wildcard is not None:
            rv.append(self.wildcard)
        return rv

    def empty(self):
        """
        :rtype: bool
        :returns: True if no domain of any type has been added
        """
        # no need to build and sort the full list just to check for emptiness
        return not self._common_list and self.yandex_tld is None and self.wildcard is None


# Maps a protobuf message full name to the wrapper class registered for it.
# Filled in by WrapperMeta when a wrapper class is created, read by wrap().
REGISTRY = {}

# Field kinds returned by get_field_type():
# non-message fields, and repeated non-ordered map entries whose value is not a message
TYPE_PLAIN = 1
# message listed in WKT_WRAPPER_MESSAGE_FULL_NAMES
TYPE_WKT_WRAPPER = 2
# message named AWACS_CALL_MESSAGE_FULL_NAME
TYPE_CALL = 3
# singular message field of any other message type
TYPE_MESSAGE = 4
# repeated non-ordered map entry whose value is a message
TYPE_MESSAGE_MAP = 5
TYPE_CUSTOM_SCALAR_MAP = 6  # for map-like fields like `repeated HeaderMapEntry` or `repeated CookieMapEntry`
# repeated message field that is not a map entry
TYPE_MESSAGE_SEQ = 7
# message named AWACS_KNOB_REF_MESSAGE_FULL_NAME
TYPE_KNOB_REF = 8
# repeated ordered map entry whose value is a message
TYPE_MESSAGE_ORDERED_MAP = 9
# message named AWACS_CERT_REF_MESSAGE_FULL_NAME
TYPE_CERT_REF = 10

# Field kinds that is_composite_field_type() reports as composite.
COMPOSITE_FIELD_TYPES = (TYPE_MESSAGE, TYPE_MESSAGE_MAP, TYPE_MESSAGE_ORDERED_MAP,
                         TYPE_CUSTOM_SCALAR_MAP, TYPE_MESSAGE_SEQ)

# Full names of protobuf well-known wrapper messages; fields of these types get TYPE_WKT_WRAPPER.
# See https://github.com/google/protobuf/blob/master/src/google/protobuf/wrappers.proto
WKT_WRAPPER_MESSAGE_FULL_NAMES = frozenset((
    'google.protobuf.DoubleValue',
    'google.protobuf.FloatValue',
    'google.protobuf.Int64Value',
    'google.protobuf.UInt64Value',
    'google.protobuf.Int32Value',
    'google.protobuf.UInt32Value',
    'google.protobuf.BoolValue',
    'google.protobuf.StringValue',
    'google.protobuf.BytesValue',
))

# Full names of the awacs messages that get_field_type() maps to
# TYPE_CALL, TYPE_KNOB_REF and TYPE_CERT_REF respectively.
AWACS_CALL_MESSAGE_FULL_NAME = 'awacs.modules.Call'
AWACS_KNOB_REF_MESSAGE_FULL_NAME = 'awacs.modules.KnobRef'
AWACS_CERT_REF_MESSAGE_FULL_NAME = 'awacs.modules.CertRef'

# Full name of the modules_pb2.Holder message, taken from its descriptor.
AWACS_HOLDER_MESSAGE_FULL_NAME = modules_pb2.Holder.DESCRIPTOR.full_name


def get_holder_pb_module(holder_pb):
    """
    Returns the message stored in whichever field of the `module` oneof is currently set.

    :type holder_pb: modules_pb2.Holder
    """
    module_field_name = holder_pb.WhichOneof('module')
    return getattr(holder_pb, module_field_name)


def get_field_type(field_desc):
    """
    Classifies a protobuf field into one of the TYPE_* kinds.

    :type field_desc: pbdescriptor.FieldDescriptor
    :returns: TYPE_*
    :rtype: int
    """
    if field_desc.type != field_desc.TYPE_MESSAGE:
        return TYPE_PLAIN

    message_desc = field_desc.message_type  # type: pbdescriptor.Descriptor
    full_name = message_desc.full_name

    # messages with a dedicated meaning are recognised by their full name
    if full_name in WKT_WRAPPER_MESSAGE_FULL_NAMES:
        return TYPE_WKT_WRAPPER
    if full_name == AWACS_CALL_MESSAGE_FULL_NAME:
        return TYPE_CALL
    if full_name == AWACS_KNOB_REF_MESSAGE_FULL_NAME:
        return TYPE_KNOB_REF
    if full_name == AWACS_CERT_REF_MESSAGE_FULL_NAME:
        return TYPE_CERT_REF

    if field_desc.label != field_desc.LABEL_REPEATED:
        return TYPE_MESSAGE

    # repeated message field: either a plain sequence or one of the two map flavours
    ordered = Builder.is_ordered_map_entry(message_desc)
    if not ordered and not Builder.is_map_entry(message_desc):
        return TYPE_MESSAGE_SEQ

    value_desc = message_desc.fields_by_name['value']  # type: pbdescriptor.FieldDescriptor
    if value_desc.type == value_desc.TYPE_MESSAGE:
        return TYPE_MESSAGE_ORDERED_MAP if ordered else TYPE_MESSAGE_MAP
    return TYPE_CUSTOM_SCALAR_MAP if ordered else TYPE_PLAIN


def get_field_types(message_desc):
    """
    Maps every field name of the message to its TYPE_* kind.

    :type message_desc: pbdescriptor.Descriptor
    :rtype: dict[str, int]
    """
    return {name: get_field_type(desc)
            for name, desc in six.iteritems(message_desc.fields_by_name)}


def get_dynamic_field_names(message_desc):
    """
    Finds fields that have `f_`-, `k_`- or `c_`-prefixed companion fields of call, knob ref or
    cert ref kind, and reports which of those kinds are available for each such field.

    :type message_desc: pbdescriptor.Descriptor
    :rtype: dict[str, set[int]]
    :returns: field name -> set of TYPE_CALL / TYPE_KNOB_REF / TYPE_CERT_REF
    """
    dynamic_types = (TYPE_CALL, TYPE_KNOB_REF, TYPE_CERT_REF)
    fields_by_name = message_desc.fields_by_name
    rv = {}
    for field_name in fields_by_name:
        for prefix in ('f_', 'k_', 'c_'):
            companion_name = prefix + field_name
            if companion_name in fields_by_name:
                companion_type = get_field_type(fields_by_name[companion_name])
                if companion_type in dynamic_types:
                    rv.setdefault(field_name, set()).add(companion_type)
    return rv


def list_composite_attrs(pb_field_to_class_attr_mapping, message_desc):
    """
    Lists wrapper attributes backed by message-valued fields (single message, message map,
    ordered message map or message sequence).

    :type pb_field_to_class_attr_mapping: dict[str, str]
    :param pb_field_to_class_attr_mapping: overrides for attribute names; fields missing from it
        keep their protobuf name
    :type message_desc: pbdescriptor.Descriptor
    :rtype: dict[str, int]
    :returns: attribute name -> TYPE_*
    """
    wrapped_types = (TYPE_MESSAGE, TYPE_MESSAGE_MAP, TYPE_MESSAGE_ORDERED_MAP, TYPE_MESSAGE_SEQ)
    rv = {}
    for field_name, field_desc in six.iteritems(message_desc.fields_by_name):
        field_type = get_field_type(field_desc)
        if field_type not in wrapped_types:
            continue
        rv[pb_field_to_class_attr_mapping.get(field_name, field_name)] = field_type
    return rv


def generate_sd_client_name(full_balancer_id):
    """
    Builds the client name string from the two parts of the full balancer id.

    :type full_balancer_id: (six.text_type, six.text_type)
    :rtype: six.text_type
    """
    name_template = 'awacs-l7-balancer({}:{})'
    return name_template.format(*full_balancer_id)


class ValidationCtx(context.CancellableCtx):
    """
    Everything a validation run needs to know about its surroundings: knobs, certs, domains,
    upstreams, balancer components and the kind of config being validated.

    The amount of checking depends on `config_type`:

    * CONFIG_TYPE_UPSTREAM -- knobs are assumed to be allowed; cert and domain checks are skipped;
    * CONFIG_TYPE_BALANCER -- knobs are checked against `knobs_mode` only; cert and domain checks
      are skipped;
    * CONFIG_TYPE_FULL -- knobs, certs and domains are checked against the data stored in the context.
    """
    CONFIG_TYPE_FULL = 0
    CONFIG_TYPE_BALANCER = 1
    CONFIG_TYPE_UPSTREAM = 2

    # number of `maybe_gevent_idle` calls between two `gevent.idle()` yields
    DEFAULT_IDLE_PERIOD = 700

    def __init__(self,
                 config_type=CONFIG_TYPE_FULL,
                 knobs=None,
                 knob_version_hints=None,
                 knobs_mode=model_pb2.BalancerValidatorSettings.KNOBS_DISABLED,
                 certs=None,
                 domain_config_pbs=None,
                 upstream_spec_pbs=None,
                 weight_section_spec_pbs=None,
                 sd_client_name=None,
                 namespace_id=None,
                 modules_blacklist=None,
                 rps_limiter_allowed_installations=None,
                 gevent_idle_period=DEFAULT_IDLE_PERIOD,
                 components_pb=None,
                 balancer_env_type=None,
                 upstream_type=None):
        """
        :param int config_type: the type of config being validated (ValidationCtx.CONFIG_TYPE_*)
        :type knobs: dict[six.text_type, model_pb2.KnobSpec] | None
        :type knob_version_hints: dict[six.text_type, KnobVersion] | None
        :type knobs_mode: model_pb2.BalancerValidatorSettings.*
        :type certs: dict[six.text_type, model_pb2.CertificateSpec] | None
        :type domain_config_pbs: dict[(six.text_type, six.text_type), model_pb2.DomainSpec.Config] | None
        :type upstream_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.UpstreamSpec] | None
        :type weight_section_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.WeightSectionSpec] | None
        :type modules_blacklist: Iterable[six.text_type]
        :type rps_limiter_allowed_installations: Iterable[six.text_type]
        :param gevent_idle_period: yield to gevent on every N-th `maybe_gevent_idle` call;
            a falsy value disables yielding
        :type gevent_idle_period: int
        :type components_pb: model_pb2.BalancerSpec.ComponentsSpec | None
        :type balancer_env_type: model_pb2.BalancerSpec.EnvType | None
        :type upstream_type: model_pb2.YandexBalancerUpstreamSpec.Type | None
        """
        self.config_type = config_type
        self.knobs = knobs or {}
        self.knobs_mode = knobs_mode
        self.knob_version_hints = knob_version_hints or {}
        self.certs = certs or {}
        self.sd_client_name = sd_client_name
        # stored as given: unlike the other collections, None is not replaced with an empty dict
        self.domain_config_pbs = domain_config_pbs
        self.weight_section_spec_pbs = weight_section_spec_pbs or {}
        self.upstream_spec_pbs = upstream_spec_pbs or {}
        self.namespace_id = namespace_id
        self.modules_blacklist = modules_blacklist or ()
        self.rps_limiter_allowed_installations = rps_limiter_allowed_installations or ()
        self.gevent_idle_period = gevent_idle_period
        self.components_pb = components_pb if components_pb else model_pb2.BalancerSpec.ComponentsSpec()
        self.balancer_env_type = balancer_env_type
        self.upstream_type = upstream_type

        self._tick_counter = 0

        super(ValidationCtx, self).__init__()

    def maybe_gevent_idle(self):
        """
        Counts a tick and calls `gevent.idle()` on every `gevent_idle_period`-th one.
        Does nothing if `gevent_idle_period` is falsy.
        """
        if not self.gevent_idle_period:
            return
        self._tick_counter += 1
        if self._tick_counter % self.gevent_idle_period == 0:
            gevent.idle()

    @staticmethod
    def _strip_namespace_id_from_dict_keys(namespace_id, spec_pbs):
        """
        Re-keys a dict from full ids `(namespace_id, id)` to bare ids.

        :type namespace_id: six.text_type
        :type spec_pbs: dict[(six.text_type, six.text_type), T]
        :rtype: dict[six.text_type, T]
        :raises AssertionError: if some key belongs to a different namespace
        """
        rv = {}
        for full_id, pb in six.iteritems(spec_pbs):
            assert full_id[0] == namespace_id, 'unexpected namespace in id {!r}'.format(full_id)
            rv[full_id[1]] = pb
        return rv

    def _list_sorted_domain_configs(self):
        """
        :rtype: list[((six.text_type, six.text_type), model_pb2.DomainSpec.Config)]
        :returns: items of `domain_config_pbs` ordered by full domain id;
            an empty list if no domain configs were provided
        """
        return sorted(six.iteritems(self.domain_config_pbs or {}), key=lambda item: item[0])

    @classmethod
    def create_ctx_with_config_type_upstream(cls, namespace_pb, full_upstream_id, upstream_spec_pb):
        """
        :type namespace_pb: model_pb2.Namespace
        :type full_upstream_id: (six.text_type, six.text_type)
        :type upstream_spec_pb: model_pb2.UpstreamSpec
        :rtype: ValidationCtx
        """
        return cls(
            config_type=cls.CONFIG_TYPE_UPSTREAM,
            namespace_id=namespace_pb.meta.id,
            rps_limiter_allowed_installations=set(namespace_pb.spec.rps_limiter_allowed_installations.installations),
            upstream_type=upstream_spec_pb.yandex_balancer.type,
        )

    @classmethod
    def create_ctx_with_config_type_balancer(cls, namespace_pb, full_balancer_id, balancer_spec_pb):
        """
        :type namespace_pb: model_pb2.Namespace
        :type full_balancer_id: (six.text_type, six.text_type)
        :type balancer_spec_pb: model_pb2.BalancerSpec
        :rtype: ValidationCtx
        """
        return cls(
            config_type=cls.CONFIG_TYPE_BALANCER,
            namespace_id=namespace_pb.meta.id,
            knobs_mode=balancer_spec_pb.validator_settings.knobs_mode,
            rps_limiter_allowed_installations=set(namespace_pb.spec.rps_limiter_allowed_installations.installations),
            sd_client_name=generate_sd_client_name(full_balancer_id),
            components_pb=balancer_spec_pb.components,
            balancer_env_type=balancer_spec_pb.env_type
        )

    @classmethod
    def create_ctx_with_config_type_full(cls, namespace_pb, full_balancer_id, balancer_spec_pb,
                                         knob_spec_pbs, knob_version_hints, cert_spec_pbs, domain_spec_pbs,
                                         weight_section_spec_pbs, upstream_spec_pbs, disable_gevent_idle=False):
        """
        :type namespace_pb: model_pb2.Namespace
        :type full_balancer_id: (six.text_type, six.text_type)
        :type balancer_spec_pb: model_pb2.BalancerSpec
        :type knob_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.KnobSpec]
        :type knob_version_hints: dict[(six.text_type, six.text_type), KnobVersion]
        :type cert_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.CertSpec]
        :type domain_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.DomainSpec]
        :type upstream_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.UpstreamSpec] | None
        :type weight_section_spec_pbs: dict[(six.text_type, six.text_type), model_pb2.WeightSectionSpec]
        :type disable_gevent_idle: bool
        :rtype: ValidationCtx
        """
        knobs_mode = balancer_spec_pb.validator_settings.knobs_mode
        if knobs_mode == model_pb2.BalancerValidatorSettings.KNOBS_DISABLED:
            # knobs are disabled for this balancer, so none of them is made visible to validation
            knob_spec_pbs = {}

        # globally blacklisted modules, minus the ones explicitly whitelisted for the namespace
        whitelisted_modules = namespace_pb.spec.modules_whitelist.modules
        modules_blacklist = set()
        for module in config.get_value('run.modules_blacklist', []):
            if module['name'] not in whitelisted_modules:
                modules_blacklist.add(module['name'])

        rps_limiter_allowed_installations = set(namespace_pb.spec.rps_limiter_allowed_installations.installations)
        if disable_gevent_idle:
            gevent_idle_period = None
        else:
            gevent_idle_period = config.get_value('run.l7_validation_gevent_idle_period',
                                                  default=cls.DEFAULT_IDLE_PERIOD)

        domain_config_pbs = {full_id: spec_pb.yandex_balancer.config
                             for full_id, spec_pb in six.iteritems(domain_spec_pbs)}

        namespace_id = namespace_pb.meta.id
        return cls(namespace_id=namespace_id,
                   config_type=cls.CONFIG_TYPE_FULL,
                   certs=cls._strip_namespace_id_from_dict_keys(namespace_id, cert_spec_pbs),
                   knobs=cls._strip_namespace_id_from_dict_keys(namespace_id, knob_spec_pbs),
                   knob_version_hints=cls._strip_namespace_id_from_dict_keys(namespace_id, knob_version_hints),
                   knobs_mode=knobs_mode,
                   modules_blacklist=modules_blacklist,
                   rps_limiter_allowed_installations=rps_limiter_allowed_installations,
                   domain_config_pbs=domain_config_pbs,
                   weight_section_spec_pbs=weight_section_spec_pbs,
                   upstream_spec_pbs=upstream_spec_pbs,
                   components_pb=balancer_spec_pb.components,
                   balancer_env_type=balancer_spec_pb.env_type,
                   sd_client_name=generate_sd_client_name(full_balancer_id),
                   gevent_idle_period=gevent_idle_period)

    def are_knobs_allowed(self):
        """
        :rtype: bool
        """
        if self.config_type == self.CONFIG_TYPE_UPSTREAM:
            # when validating upstreams, we assume that knobs are allowed,
            # and do not actually validate anything
            return True
        return self.knobs_mode != model_pb2.BalancerValidatorSettings.KNOBS_DISABLED

    def are_knobs_enabled(self):
        """
        :rtype: bool
        """
        return self.knobs_mode == model_pb2.BalancerValidatorSettings.KNOBS_ENABLED

    def validate_knob(self, knob_id, required_type):
        """
        :type knob_id: str
        :type required_type: model_pb2.KnobSpec.Type
        :raises ValidationError: if knobs are not allowed or the knob has an incompatible type
        :raises KnobDoesNotExist: if the knob is unknown or marked as deleted (full validation only)
        :raises AssertionError: if `config_type` is not one of CONFIG_TYPE_*
        """
        if self.config_type == self.CONFIG_TYPE_UPSTREAM:
            # when validating upstreams, we assume that knobs are allowed,
            # and do not actually validate anything
            return
        if self.config_type not in (self.CONFIG_TYPE_BALANCER, self.CONFIG_TYPE_FULL):
            raise AssertionError('Unknown config type {}'.format(self.config_type))

        if not self.are_knobs_allowed():
            raise ValidationError('knobs are not allowed for this balancer')
        if self.config_type == self.CONFIG_TYPE_BALANCER:
            # when validating balancer, we only know whether knobs are allowed or not
            return

        knob_spec_pb = self.knobs.get(knob_id)
        if knob_spec_pb is None or knob_spec_pb.deleted:
            raise KnobDoesNotExist('knob "{}" is missing'.format(knob_id))
        knob_type = knob_spec_pb.type
        if knob_spec_pb.mode == knob_spec_pb.WATCHED and knob_type == knob_spec_pb.ANY:
            # if the knob is watched (synced from ITS), its type can be unknown.
            # hence we consider it compatible with any type
            return
        if knob_type != required_type:
            type_enum_desc = model_pb2.KnobSpec.DESCRIPTOR.enum_types_by_name['Type']
            knob_type_str = enum_value_to_name(type_enum_desc, knob_type)
            required_type_str = enum_value_to_name(type_enum_desc, required_type)
            raise ValidationError('expected type is {}; actual type of knob "{}" is {}'.format(
                required_type_str, knob_id, knob_type_str), hint=self.knob_version_hints.get(knob_id))

    def validate_cert(self, cert_id):
        """
        :type cert_id: six.text_type
        :returns: True if the check is skipped (upstream and balancer configs), None otherwise
        :raises CertDoesNotExist: if the cert is not known to the context
        """
        if self.config_type in (self.CONFIG_TYPE_UPSTREAM, self.CONFIG_TYPE_BALANCER):
            # skip during synchronous validation
            return True
        if cert_id not in self.certs:
            raise CertDoesNotExist('cert "{}" is missing'.format(cert_id))

    def get_separate_http_and_https_domains(self):
        """
        Splits known domains by protocol; a domain serving both protocols lands in both sets.

        :rtype: DomainsSet, DomainsSet
        :returns: (HTTP domains, HTTPS domains); both are empty if no domain configs were provided
        """
        domain_config_cls = model_pb2.DomainSpec.Config
        http_protocols = (domain_config_cls.HTTP_ONLY, domain_config_cls.HTTP_AND_HTTPS)
        https_protocols = (domain_config_cls.HTTPS_ONLY, domain_config_cls.HTTP_AND_HTTPS)

        http_domains = DomainsSet()
        https_domains = DomainsSet()
        # `domain_config_pbs` defaults to None, treat it as "no domains"
        for domain_id, domain_config_pb in six.iteritems(self.domain_config_pbs or {}):
            protocol = domain_config_pb.protocol
            if protocol in http_protocols:
                http_domains.add(domain_id, domain_config_pb)
            if protocol in https_protocols:
                https_domains.add(domain_id, domain_config_pb)
        return http_domains, https_domains

    def validate_domain_certs(self):
        """
        Checks that certs referenced by domains are known to the context.

        :returns: True if the check is skipped (upstream and balancer configs), None otherwise
        :raises ValidationError: if a primary or secondary cert is not found
        """
        if self.config_type in (self.CONFIG_TYPE_UPSTREAM, self.CONFIG_TYPE_BALANCER):
            # skip during synchronous validation
            return True
        for (_, domain_id), domain_config_pb in self._list_sorted_domain_configs():
            cert_id = domain_config_pb.cert.id
            if cert_id and cert_id not in self.certs:
                # historically awacs API allowed HTTP domains to have empty "cert.id" values,
                # so we don't want to reject them here
                raise ValidationError('cert "{}" not found for domain "{}"'.format(cert_id, domain_id))
            if domain_config_pb.HasField('secondary_cert'):
                secondary_cert_id = domain_config_pb.secondary_cert.id
                if secondary_cert_id not in self.certs:
                    raise ValidationError('secondary cert "{}" not found for domain "{}"'.format(secondary_cert_id,
                                                                                                 domain_id))

    def validate_domains(self):
        """
        Checks that no FQDN (regular or shadow) is claimed by more than one domain.

        :returns: True if the check is skipped (upstream and balancer configs), None otherwise
        :raises ValidationError: for the first (in sorted order) FQDN used by several domains
        """
        if self.config_type in (self.CONFIG_TYPE_UPSTREAM, self.CONFIG_TYPE_BALANCER):
            # skip during synchronous validation
            return True
        fqdn_to_domain_ids = collections.defaultdict(list)
        # domains and FQDNs are visited in sorted order to keep the error message reproducible
        for (_, domain_id), domain_config_pb in self._list_sorted_domain_configs():
            for fqdn in itertools.chain(domain_config_pb.fqdns, domain_config_pb.shadow_fqdns):
                fqdn_to_domain_ids[fqdn].append(domain_id)
        for fqdn in sorted(fqdn_to_domain_ids):
            domain_ids = fqdn_to_domain_ids[fqdn]
            if len(domain_ids) > 1:
                raise ValidationError('FQDN "{}" cannot be used in multiple domains: "{}"'.format(
                    fqdn, '", "'.join(domain_ids)))

    def ensure_component_version(self, component_type, min=None, max=None):
        """
        Checks that the balancer component has a version within [min, max].

        :param component_type: type of the component to look up in `components_pb`
        :param min: minimal allowed version (unparsed), optional
        :param max: maximal allowed version (unparsed), optional
        :returns: True if the component state is UNKNOWN (nothing to check), None otherwise
        :raises RuntimeError: if neither bound is given or min is greater than max
        :raises ValidationError: if the component version is out of bounds
        :raises AssertionError: if there is no component of such type
        """
        component_config, component_pb = self._get_component_config_and_pb(component_type)
        # separate local names, so that the `min`/`max` builtins are not rebound any further
        min_version = component_config.parse_version(min) if min else None
        max_version = component_config.parse_version(max) if max else None
        self._validate_min_max_versions(min_version, max_version)
        if component_pb.state == component_pb.UNKNOWN:
            return True

        component_version = component_config.parse_version(component_pb.version)

        # every given bound is mentioned in the message, even if only one of them is violated
        requirements = []
        violated = False

        if min_version:
            requirements.append('>= {}'.format(min_version))
            if component_version < min_version:
                violated = True

        if max_version:
            requirements.append('<= {}'.format(max_version))
            if component_version > max_version:
                violated = True

        if violated:
            raise ValidationError(
                'requires component {} of version {}, not {}'.format(
                    component_config.pb_field_name, ' and '.join(requirements), component_pb.version))

    def _get_component_config_and_pb(self, component_type):
        """
        :returns: (component config, component pb) of the first component of the given type
        :raises AssertionError: if there is no component of such type
        """
        for component_config, component_pb in iter_balancer_components(self.components_pb):
            if component_config.type == component_type:
                return component_config, component_pb
        raise AssertionError('Unknown component_type: {}'.format(component_type))

    def _validate_min_max_versions(self, min, max):
        """
        :param min: parsed minimal version or None
        :param max: parsed maximal version or None
        :raises RuntimeError: if neither bound is given or min is greater than max
        """
        if not min and not max:
            raise RuntimeError('At least one of `min` and `max` should be specified')

        if min and max and min > max:
            raise RuntimeError('Minimum version must be less than or equal to max version')

    @property
    def ecc_certs_are_used(self):
        """
        :rtype: bool
        :returns: True if at least one cert known to the context has the `ec` public key algorithm
        """
        return any(cert_spec_pb.fields.public_key_info.algorithm_id == u'ec'
                   for cert_spec_pb in six.itervalues(self.certs))

    def is_ecc_cert(self, cert_id):
        """
        :type cert_id: six.text_type
        :rtype: bool
        :returns: True if the cert is known and has the `ec` public key algorithm
        :raises CertDoesNotExist: if the cert is unknown (full validation only)
        """
        if not cert_id:
            return False
        self.validate_cert(cert_id)
        cert_spec_pb = self.certs.get(cert_id)
        if cert_spec_pb is None:
            # `validate_cert` skips its check for upstream and balancer configs,
            # so an unknown cert can get here; nothing can be said about it
            return False
        return cert_spec_pb.fields.public_key_info.algorithm_id == u'ec'


# Module-level context built with every constructor argument left at its default
# (CONFIG_TYPE_FULL, knobs disabled, no knobs, certs or domain configs).
DEFAULT_CTX = ValidationCtx()


class WrapperMeta(type):
    """
    Metaclass of config wrappers.

    For every class that sets `__protobuf__` to a protobuf message class it:

    * registers the class in REGISTRY under the full name of the message;
    * extends `__slots__` with the attributes that hold wrapped composite fields;
    * precomputes field-type lookup tables and stores them as class attributes.

    Classes without `__protobuf__` (or with `__protobuf__ = None`) are created untouched.
    """

    def __new__(mcs, name, bases, attrs):
        """
        :raises RuntimeError: if a wrapper for the same protobuf message is already registered
        """
        message_cls = attrs.get('__protobuf__')
        if message_cls is None:
            return type.__new__(mcs, name, bases, attrs)

        message_desc = message_cls.DESCRIPTOR  # type: pbdescriptor.Descriptor
        full_name = message_desc.full_name
        # REGISTRY is keyed by the full message name, so that is what has to be looked up
        # (looking up the message class itself never matches anything)
        if full_name in REGISTRY:
            raise RuntimeError('Wrapper for {} already exists: {!r}'.format(
                message_cls.__name__, REGISTRY[full_name]))

        pb_field_to_cls_attr_mapping = attrs.get('PB_FIELD_TO_CLS_ATTR_MAPPING', {})
        composite_attrs = list_composite_attrs(pb_field_to_cls_attr_mapping, message_desc)

        slots = set(composite_attrs)  # we need slots for wrapped composite fields
        if '__slots__' in attrs:
            slots.update(attrs['__slots__'])  # we also want to preserve slots defined by users
        if bases == (object,):
            slots.add('__dict__')  # and we want our instances to be mockable (very useful in tests)
        attrs['__slots__'] = sorted(slots)

        # let's remove "default" class values of composite fields from class definition
        for attr in composite_attrs:
            if attr in attrs:
                del attrs[attr]

        klass = REGISTRY[full_name] = type.__new__(mcs, name, bases, attrs)

        # precompute some things:
        field_types = get_field_types(message_desc)
        call_fields = {}  # type: dict[str, pbdescriptor.FieldDescriptor]
        knob_fields = {}  # type: dict[str, pbdescriptor.FieldDescriptor]
        cert_fields = {}  # type: dict[str, pbdescriptor.FieldDescriptor]
        for field_name, field_type in six.iteritems(field_types):
            if field_type == TYPE_CALL:
                call_fields[field_name] = message_desc.fields_by_name[field_name]
            elif field_type == TYPE_KNOB_REF:
                knob_fields[field_name] = message_desc.fields_by_name[field_name]
            elif field_type == TYPE_CERT_REF:
                cert_fields[field_name] = message_desc.fields_by_name[field_name]
        klass.__field_types__ = field_types
        klass.__composite_field_types__ = {field_name: field_type
                                           for field_name, field_type in six.iteritems(field_types)
                                           if is_composite_field_type(field_type)}
        # `*_items__` attributes are list snapshots of the dicts above
        klass.__field_type_items__ = list(klass.__field_types__.items())
        klass.__composite_field_type_items__ = list(klass.__composite_field_types__.items())
        klass.__dynamic_field_names__ = get_dynamic_field_names(message_desc)
        klass.__dynamic_field_name_items__ = list(klass.__dynamic_field_names__.items())
        klass.__composite_attrs__ = composite_attrs
        klass.__composite_attr_items__ = list(klass.__composite_attrs__.items())

        klass.__call_fields__ = call_fields
        klass.__knob_fields__ = knob_fields
        klass.__cert_fields__ = cert_fields
        return klass


def wrap(pb):
    """
    Looks up the wrapper class registered for the message type of `pb` and instantiates it.

    :type pb: awacs.proto.modules_pb2.*
    :rtype: ConfigWrapperBase
    :raises KeyError: if no wrapper is registered for the message
    """
    wrapper_cls = REGISTRY[pb.DESCRIPTOR.full_name]
    return wrapper_cls(pb)


# Unique sentinel object. Its consumers are outside this part of the file;
# presumably it stands for "any module" -- TODO confirm against the callers.
ANY_MODULE = object()


def find_module(modules, cls):
    """
    :param modules: iterable of modules to search
    :param cls: class (or tuple of classes) to look for
    :returns: the first item that is an instance of `cls`, or None if there is none
    """
    return next((module for module in modules if isinstance(module, cls)), None)


def find_last_module(modules, cls):
    """
    :param modules: reversible sequence of modules to search
    :param cls: class (or tuple of classes) to look for
    :returns: the last item that is an instance of `cls`, or None if there is none
    """
    return next((module for module in reversed(modules) if isinstance(module, cls)), None)


def add_module(modules, module):
    """
    :param modules: iterable of modules; left unmodified
    :param module: module to put at the end
    :rtype: list
    :returns: a new list with the items of `modules` followed by `module`
    """
    return list(modules) + [module]


def is_composite_field_type(field_type):
    """
    :param int field_type: one of TYPE_*
    :rtype: bool
    :returns: True if `field_type` is listed in COMPOSITE_FIELD_TYPES
    """
    return field_type in COMPOSITE_FIELD_TYPES


class WrapperBase(six.with_metaclass(WrapperMeta, object)):
    """
    Base class of the wrappers around protobuf messages.

    A wrapper keeps the wrapped message in `pb` and derived state next to it:

    * `_calls`, `_knobs`, `_certs` -- wrappers of the set "f_"-, "k_"- and "c_"-prefixed fields,
      keyed by the field name without the prefix (see `wrap_calls_and_knobs_and_certs`);
    * one attribute per composite field holding the wrapped value(s) (see `wrap_composite_fields`).

    The derived state is rebuilt by `update_pb`.
    """
    __slots__ = ('pb', '_calls', '_knobs', '_certs')

    __protobuf__ = None

    PB_FIELD_TO_CLS_ATTR_MAPPING = {}  # dict[str, str]

    # attr names:
    REQUIRED = []  # type: list[str]
    REQUIRED_ONEOFS = []  # type: list[list[str]]
    REQUIRED_ANYOFS = []  # type: list[list[str]]
    REQUIRED_ALLOFS = []  # type: list[list[str]]
    REQUIRED_PB_ONEOFS = []  # type: list[str]
    ALLOWED_KNOBS = {}
    ALLOWED_CERTS = ()
    DEFAULT_KNOB_IDS = {}

    MODULE_NAME = ''

    def __init__(self, pb):
        self.pb = None
        self.update_pb(pb=pb)

    def __eq__(self, other):
        """Wrappers are equal when both are WrapperBase instances wrapping equal messages."""
        return isinstance(other, WrapperBase) and self.pb == other.pb

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != would compare identities.
        return not self == other

    def _is_field_present(self, field_name):
        """
        Tell whether the field holds a truthy value (or, for well-known-type wrapper fields, is set).

        :type field_name: six.text_type
        :rtype: bool
        :raises RuntimeError: if the field is a call, knob or cert reference, or its type is unknown
        :raises KeyError: if `field_name` is not present in `__field_types__`
        """
        field_type = self.__field_types__[field_name]
        if field_type == TYPE_PLAIN:
            if field_name in self.__dynamic_field_names__:
                val = self.get(field_name).value  # one of Call, Knob, Cert, or scalar value
            else:
                val = getattr(self.pb, field_name)
        elif is_composite_field_type(field_type):
            # composite fields are checked through their wrapped attribute, not through the raw pb
            attr_name = self.PB_FIELD_TO_CLS_ATTR_MAPPING.get(field_name, field_name)
            val = getattr(self, attr_name)
        elif field_type == TYPE_WKT_WRAPPER:
            val = self.pb.HasField(field_name)
        elif field_type == TYPE_CALL:
            raise RuntimeError('call field "{}" can not be required'.format(field_name))
        elif field_type == TYPE_KNOB_REF:
            raise RuntimeError('knob field "{}" can not be required'.format(field_name))
        elif field_type == TYPE_CERT_REF:
            raise RuntimeError('cert field "{}" can not be required'.format(field_name))
        else:
            raise RuntimeError('unknown field type: {}'.format(field_type))
        return bool(val)

    @classmethod
    def get_required_oneofs(cls):
        """
        Collect the groups of field names of which exactly one must be specified.

        The groups come from the protobuf oneofs named in REQUIRED_PB_ONEOFS (expanded to the names
        of their member fields) followed by the explicit groups listed in REQUIRED_ONEOFS.

        :rtype: list[list[str]]
        """
        required_oneofs = []
        for oneof_name in cls.REQUIRED_PB_ONEOFS:
            desc = cls.__protobuf__.DESCRIPTOR  # type: pbdescriptor.Descriptor
            oneof_desc = desc.oneofs_by_name[oneof_name]  # type: pbdescriptor.OneofDescriptor
            required_oneofs.append([field_desc.name for field_desc in oneof_desc.fields])
        for field_names in cls.REQUIRED_ONEOFS:
            required_oneofs.append(field_names)
        return required_oneofs

    def auto_validate_required(self):
        """
        Check the presence constraints declared on the class.

        * REQUIRED: every listed field must be present;
        * REQUIRED_ANYOFS: at least one field of every group must be present;
        * REQUIRED_ALLOFS: either all or none of the fields of every group must be present;
        * required oneofs (see `get_required_oneofs`): exactly one field of every group must be present.

        :raises ValidationError: on the first violated constraint
        """
        for field_name in self.REQUIRED:
            is_field_present = self._is_field_present(field_name)
            if not is_field_present:
                raise ValidationError('is required', field_name=field_name)

        for field_names in self.REQUIRED_ANYOFS:
            if not any(map(self._is_field_present, field_names)):
                raise ValidationError('at least one of the "{}" '
                                      'must be specified'.format(quote_join_sorted(field_names)))

        for field_names in self.REQUIRED_ALLOFS:
            fields_presence = list(map(self._is_field_present, field_names))
            if any(fields_presence) and not all(fields_presence):
                raise ValidationError('either all or none of the "{}" '
                                      'must be specified'.format(quote_join_sorted(field_names)))

        required_oneofs = self.get_required_oneofs()
        for field_names in required_oneofs:
            n = list(map(self._is_field_present, field_names)).count(True)
            if n == 0:
                raise ValidationError('at least one of the "{}" '
                                      'must be specified'.format(quote_join_sorted(field_names)))
            elif n > 1:
                raise ValidationError('at most one of the "{}" '
                                      'must be specified'.format(quote_join_sorted(field_names)))

    def wrap_composite_fields(self):
        """
        Wrap the values of all composite fields and store them as attributes of the wrapper.

        The attribute name is the field name, unless remapped by PB_FIELD_TO_CLS_ATTR_MAPPING.
        The stored value depends on the field type:

        * (ordered) message map -- list of (key, wrapper) pairs;
        * custom scalar map and message sequence -- list of wrappers;
        * message -- a wrapper, or None if the field is not set.
        """
        for field_name, field_type in self.__composite_field_type_items__:
            attr_name = self.PB_FIELD_TO_CLS_ATTR_MAPPING.get(field_name, field_name)
            if field_type == TYPE_MESSAGE_MAP or field_type == TYPE_MESSAGE_ORDERED_MAP:
                field_value = getattr(self.pb, field_name)
                attr_value = [(entry_pb.key, wrap(entry_pb.value)) for entry_pb in field_value]
            elif field_type == TYPE_CUSTOM_SCALAR_MAP:
                field_value = getattr(self.pb, field_name)
                attr_value = [wrap(entry_pb) for entry_pb in field_value]
            elif field_type == TYPE_MESSAGE_SEQ:
                field_value = getattr(self.pb, field_name)
                attr_value = [wrap(m) for m in field_value]
            elif field_type == TYPE_MESSAGE:
                if self.pb.HasField(field_name):
                    field_value_pb = getattr(self.pb, field_name)
                    attr_value = wrap(field_value_pb)
                else:
                    attr_value = None
            else:
                raise AssertionError()
            setattr(self, attr_name, attr_value)

    def auto_raise_if_blacklisted(self, ctx=DEFAULT_CTX):
        """
        :raises ValidationError: if MODULE_NAME is listed in `ctx.modules_blacklist`
        """
        if self.MODULE_NAME in ctx.modules_blacklist:
            raise ValidationError('using {} is not allowed. Please contact support if you absolutely must do it.'
                                  .format(self.MODULE_NAME))

    def wrap_calls_and_knobs_and_certs(self):
        """
        Rebuild `_calls`, `_knobs` and `_certs` from the currently set "f_", "k_" and "c_" fields of `pb`.

        The dicts are keyed by the field name with the two-character prefix stripped.
        """
        self._calls = {}  # type: dict[str, ConfigWrapperBase]
        self._knobs = {}  # type: dict[str, ConfigWrapperBase]
        self._certs = {}  # type: dict[str, ConfigWrapperBase]

        for f_field_name in self.__call_fields__:
            if self.pb.HasField(f_field_name):
                f_field_value = getattr(self.pb, f_field_name)
                field_name = f_field_name[2:]  # strip "f_" prefix
                self._calls[field_name] = wrap(f_field_value)

        for k_field_name in self.__knob_fields__:
            if self.pb.HasField(k_field_name):
                k_field_value = getattr(self.pb, k_field_name)
                field_name = k_field_name[2:]  # strip "k_" prefix
                self._knobs[field_name] = wrap(k_field_value)

        for c_field_name in self.__cert_fields__:
            if self.pb.HasField(c_field_name):
                c_field_value = getattr(self.pb, c_field_name)
                field_name = c_field_name[2:]  # strip "c_" prefix
                self._certs[field_name] = wrap(c_field_value)

    def _check_pb_cls(self, pb):
        """
        Make sure `pb` is a message of the type declared in `__protobuf__` (compared by descriptor full name).

        Nothing is checked if `__protobuf__` is not set.

        :raises AssertionError: on type mismatch
        """
        if self.__protobuf__:
            if pb.DESCRIPTOR.full_name != self.__protobuf__.DESCRIPTOR.full_name:
                raise AssertionError('{} is not {}'.format(type(pb), self.__protobuf__))

    def update_pb(self, pb=None):
        """
        Optionally replace the wrapped message and rebuild all the derived state.

        :param pb: a new message to wrap; if None, the current `self.pb` is kept and only re-wrapped
        """
        if pb is not None:
            self._check_pb_cls(pb)
            self.pb = pb
        self.wrap_calls_and_knobs_and_certs()
        self.wrap_composite_fields()

    @staticmethod
    def require_value(value, name=None):
        """
        :raises ValidationError: if `value` is falsy
        """
        if not value:
            raise ValidationError('is required', field_name=name)

    def get(self, field_name, default_value=None):
        """
        Resolve the effective value of a field that can be set directly or via its
        f_- (call), k_- (knob) or c_- (cert) prefixed counterpart.

        A call, knob or cert wrapper is validated before being returned.

        :param str field_name: name of the field to look up for
        :param default_value: a value to return if required value is not found
                              neither in its field nor in its f_, k_ or c_-prefixed fields
        :rtype: awacs.wrappers.util.Value
        :raises: ValidationError if the field is assigned in more than one way at once
                 (any combination except knob together with cert is rejected)
        """
        has_field = bool(getattr(self.pb, field_name))
        has_func_field = field_name in self._calls
        has_knob_field = field_name in self._knobs
        has_cert_field = field_name in self._certs
        if has_field and has_func_field:
            raise ValidationError('field {} assigned both a value and a function call'.format(field_name))
        if has_field and has_knob_field:
            raise ValidationError('field {} assigned both a regular and a knob value'.format(field_name))
        if has_field and has_cert_field:
            raise ValidationError('field {} assigned both a regular and a cert value'.format(field_name))
        if has_func_field and has_knob_field:
            raise ValidationError('field {} assigned both a func and a knob value'.format(field_name))
        if has_func_field and has_cert_field:
            raise ValidationError('field {} assigned both a func and a cert value'.format(field_name))
        if has_func_field:
            call = self._calls[field_name]
            call.validate()
            return Value(Value.CALL, call)
        elif has_knob_field:
            knob = self._knobs[field_name]
            knob.validate()
            return Value(Value.KNOB, knob)
        elif has_cert_field:
            cert = self._certs[field_name]
            cert.validate()
            return Value(Value.CERT, cert)
        else:
            value = getattr(self.pb, field_name)
            if not value and default_value is not None:
                return Value(Value.VALUE, default_value)
            else:
                return Value(Value.VALUE, value)

    def validate_knob(self, field_name, ctx=DEFAULT_CTX):
        """
        Validate a field that may be assigned a knob.

        A knob is checked by `ctx.validate_knob` against the type declared in ALLOWED_KNOBS;
        a function call must be `get_its_control_path`; a plain value is not checked here.

        :raises KeyError: if a knob is assigned to a field missing from ALLOWED_KNOBS
        """
        with validate(field_name):
            value = self.get(field_name)
            if value.type == value.KNOB:
                required_type = self.ALLOWED_KNOBS[field_name]
                ctx.validate_knob(value.value.id, required_type)
            elif value.type == value.CALL:
                validate_func_name_one_of(value.value, defs.get_its_control_path.name)

    def validate_cert(self, field_name, ctx=DEFAULT_CTX):
        """
        Validate a field that may be assigned a certificate.

        A cert is checked by `ctx.validate_cert`; a function call must be `get_public_cert_path`
        or `get_private_cert_path`; a plain value is not checked here.

        :raises AssertionError: if a cert is assigned to a field missing from ALLOWED_CERTS,
                                or the value is of an unsupported type
        """
        with validate(field_name):
            value = self.get(field_name)
            if value.type == value.CERT:
                if field_name not in self.ALLOWED_CERTS:
                    raise AssertionError('Certificates not allowed inside field "{}"'.format(field_name))
                ctx.validate_cert(value.value.id)
            elif value.type == value.CALL:
                validate_func_name_one_of(value.value, (defs.get_public_cert_path.name,
                                                        defs.get_private_cert_path.name))
            elif value.type != value.VALUE:
                raise AssertionError('Unsupported value type inside certificate field {}: {}'.format(field_name,
                                                                                                     value.type))

    def walk_composite_fields(self):
        """
        Recursively iterate over the ConfigWrapperBase instances stored in composite attributes.

        Message fields of the Holder message type, Holder instances and values that are not
        ConfigWrapperBase instances are skipped (and not descended into).

        :rtype: collections.Iterable[ConfigWrapperBase]
        """
        for attr_name, field_type in self.__composite_attr_items__:
            # attr_name may be a remapped name that does not exist on the pb, hence the None default
            value_pb = getattr(self.pb, attr_name, None)
            if (field_type == TYPE_MESSAGE and value_pb is not None and
                    value_pb.DESCRIPTOR.full_name == AWACS_HOLDER_MESSAGE_FULL_NAME):
                continue

            value = getattr(self, attr_name)
            if value is None:
                continue

            if field_type == TYPE_MESSAGE:
                wrappers = [value]
            elif field_type == TYPE_MESSAGE_MAP:
                if isinstance(value, list):
                    # wrap_composite_fields of this class stores maps as lists of (key, wrapper) pairs
                    wrappers = [v for _, v in value]
                else:
                    wrappers = six.itervalues(value)
            elif field_type == TYPE_MESSAGE_ORDERED_MAP:
                wrappers = [v for _, v in value]
            elif field_type == TYPE_MESSAGE_SEQ:
                wrappers = list(value)
            else:
                raise AssertionError()

            for w in wrappers:
                if isinstance(w, Holder):
                    continue
                if not isinstance(w, ConfigWrapperBase):
                    continue
                yield w
                for item in w.walk_composite_fields():
                    yield item


class ConfigWrapperBase(WrapperBase):
    """
    Wrapper base adding knob/cert bookkeeping and the validate/to_config interface.
    """

    def get_included_knob_ids(self, namespace_id, ctx):
        """
        Collect full ids of the knobs referenced by this wrapper and by the wrappers nested in it.

        If knobs are enabled in `ctx`, default knobs (DEFAULT_KNOB_IDS) of the allowed knob fields
        that have no value are included as well.

        :type namespace_id: str
        :type ctx: ValidationCtx
        :rtype: set[(str, str)]
        """
        included = set()
        for nested_wrapper in self.walk_composite_fields():
            included.update(nested_wrapper.get_included_knob_ids(namespace_id, ctx))
        for own_knob in six.itervalues(self._knobs):
            included.add((namespace_id, own_knob.id))

        if not ctx.are_knobs_enabled():
            return included

        for field_name in self.ALLOWED_KNOBS:
            if self.get(field_name).value:
                continue
            default_knob_id = self.DEFAULT_KNOB_IDS.get(field_name)
            if default_knob_id:
                included.add((namespace_id, default_knob_id))
        return included

    def get_included_cert_ids(self, namespace_id, ctx):
        """
        Collect full ids of the certs referenced by this wrapper and by the wrappers nested in it.

        :type namespace_id: str
        :type ctx: ValidationCtx
        :rtype: set[(str, str)]
        """
        included = set()
        for nested_wrapper in self.walk_composite_fields():
            included.update(nested_wrapper.get_included_cert_ids(namespace_id, ctx))
        for own_cert in six.itervalues(self._certs):
            included.add((namespace_id, own_cert.id))
        return included

    def get_would_be_included_full_knob_ids(self, namespace_id, ctx):
        """
        Same as `get_included_knob_ids` in this base implementation.

        :type namespace_id: str
        :type ctx: ValidationCtx
        :rtype: set[(str, str)]
        """
        knob_ids = self.get_included_knob_ids(namespace_id, ctx)
        return knob_ids

    def get_would_be_included_full_cert_ids(self, namespace_id, ctx):
        """
        Same as `get_included_cert_ids` in this base implementation.

        :type namespace_id: str
        :type ctx: ValidationCtx
        :rtype: set[(str, str)]
        """
        cert_ids = self.get_included_cert_ids(namespace_id, ctx)
        return cert_ids

    def fill_default_knobs(self, ctx):
        """
        Point every allowed knob field that has no value to its default knob (marked optional),
        then re-wrap calls, knobs and certs.

        Knobs must be enabled in `ctx`.

        :type ctx: ValidationCtx
        """
        assert ctx.are_knobs_enabled()
        for field_name in self.ALLOWED_KNOBS:
            if self.get(field_name).value:
                continue
            default_knob_id = self.DEFAULT_KNOB_IDS.get(field_name)
            if not default_knob_id:
                continue
            knob_ref_pb = getattr(self.pb, 'k_' + field_name)
            knob_ref_pb.id = default_knob_id
            knob_ref_pb.optional = True
        self.wrap_calls_and_knobs_and_certs()

    def include_knobs(self, namespace_id, balancer_id, knob_spec_pbs, ctx):
        """
        Replace knob references with `get_its_control_path` calls.

        For every knob found in `knob_spec_pbs` the "f_" field is filled and the "k_" field is cleared.
        Missing optional knobs are just cleared; a missing non-optional knob is an error.

        :type namespace_id: str
        :type balancer_id: str
        :type knob_spec_pbs: dict[(str, str), model_pb2.KnobSpec]
        :type ctx: ValidationCtx
        :rtype: set[(str, str)]
        :return: full ids of the knobs that have been included
        :raises KnobDoesNotExist: if a non-optional knob is missing or a watched knob
                                  does not list `balancer_id` among its location paths
        """
        if ctx.are_knobs_enabled():
            self.fill_default_knobs(ctx=ctx)

        included = set()
        for field_name, knob in six.iteritems(self._knobs):
            knob_field_name = 'k_' + field_name
            full_knob_id = (namespace_id, knob.id)
            flat_knob_id = flatten_full_id(namespace_id, full_knob_id)

            if full_knob_id not in knob_spec_pbs:
                if not knob.optional:
                    raise KnobDoesNotExist('knob "{}" is missing'.format(flat_knob_id))
                self.pb.ClearField(knob_field_name)
                continue

            knob_spec_pb = knob_spec_pbs[full_knob_id]
            self.validate_knob(field_name, ctx=ctx)
            if knob_spec_pb.mode != knob_spec_pb.WATCHED:
                filename = knob.id
            else:
                if balancer_id not in knob_spec_pb.its_watched_state.its_location_paths:
                    raise KnobDoesNotExist('knob "{}" does not match balancer "{}"'.format(
                        flat_knob_id, balancer_id))
                filename = knob_spec_pb.its_watched_state.filename

            call_pb = getattr(self.pb, 'f_' + field_name)
            call_pb.type = call_pb.GET_ITS_CONTROL_PATH
            call_pb.get_its_control_path_params.filename = filename
            self.pb.ClearField(knob_field_name)
            included.add(full_knob_id)

        # pb fields have been modified, so _calls/_knobs/_certs must be rebuilt
        self.wrap_calls_and_knobs_and_certs()
        return included

    def includes_knobs(self, ctx):
        """
        :type ctx: ValidationCtx
        :rtype: bool
        """
        knob_ids = self.get_included_knob_ids('', ctx)
        return bool(knob_ids)

    def would_include_knobs(self, ctx):
        """
        Same as `includes_knobs` in this base implementation.

        :type ctx: ValidationCtx
        :rtype: bool
        """
        return self.includes_knobs(ctx)

    def includes_certs(self, ctx):
        """
        :type ctx: ValidationCtx
        :rtype: bool
        """
        cert_ids = self.get_included_cert_ids('', ctx)
        return bool(cert_ids)

    def validate_composite_fields(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Base implementation validates nothing; it only calls `ctx.maybe_gevent_idle()`.
        """
        ctx.maybe_gevent_idle()

    def validate(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Base implementation only validates composite fields.
        """
        self.validate_composite_fields(ctx=ctx, preceding_modules=preceding_modules)

    def includes_domains(self):
        """
        :rtype: bool
        """
        return False

    def walk_chain(self, visit_branches=False):
        """
        Base implementation yields only the wrapper itself.
        """
        yield self

    def to_config(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Not implemented in the base class.
        """
        raise NotImplementedError

    def expand_immediate_contained_macro(self):
        """
        Not implemented in the base class.
        """
        raise NotImplementedError


class ModuleWrapperBase(ConfigWrapperBase):
    """
    Base class of module wrappers: a config wrapper that may have branches.

    All the `includes_*` / `would_include_*` predicates answer False here,
    and there are no branches, unless overridden.
    """

    def walk_chain(self, visit_branches=False):
        """
        Yield whatever the parent class yields and then, if `visit_branches` is set,
        the chains of all branches (in `get_branches()` order).
        """
        inherited = super(ModuleWrapperBase, self).walk_chain(visit_branches=visit_branches)
        for item in inherited:
            yield item
        if not visit_branches:
            return
        for branch in self.get_branches():
            for item in branch.walk_chain(visit_branches=visit_branches):
                yield item

    def get_branches(self):
        """
        Branches of the module; none in the base class.
        """
        return ()

    def get_named_branches(self):
        """
        Branches of the module keyed by name; none in the base class.
        """
        return {}

    def includes_upstreams(self):
        """
        :rtype: bool
        """
        return False

    def includes_backends(self):
        """
        :rtype: bool
        """
        return False

    def would_include_backends(self):
        """
        Same as `includes_backends` in this base implementation.

        :rtype: bool
        """
        return self.includes_backends()

    def would_include_weight_sections(self):
        """
        :rtype: bool
        """
        return False

    def would_include_internal_upstreams(self):
        """
        :rtype: bool
        """
        return False

    def is_chainable(self):
        """
        :rtype: bool
        """
        return False

    def to_normal_form_XXX(self):
        """
        See https://st.yandex-team.ru/AWACS-1044 for details.
        This code is not production grade and is intended to be used by tooling around awacs (such as awacsemtool).

        The base implementation just delegates to every branch.
        """
        for branch in self.get_branches():
            branch.to_normal_form_XXX()

    def expand_macroses(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Expand macroses in every branch, passing `preceding_modules` extended with this module.
        """
        modules_with_self = add_module(preceding_modules, self)
        for branch in self.get_branches():
            branch.expand_macroses(ctx=ctx, preceding_modules=modules_with_self)


class ChainableModuleWrapperBase(ModuleWrapperBase):
    """
    Base class of modules that can be followed by another module.

    The following module is either stored in `nested` or supplied from the outside
    as `chained_modules` (the rest of the chain the module is a part of).
    """
    __slots__ = ('nested',)

    # chainable modules are generally shareable.
    # exceptions include such classes as IpdispatchSection, RegexpSection, Balancer2Backend,
    # which aren't really modules
    SHAREABLE = True

    def to_normal_form_XXX(self):
        """
        See https://st.yandex-team.ru/AWACS-1044 for details.
        This code is not production grade and is intended to be used by tooling around awacs (such as awacsemtool).
        """
        if self.nested:
            self.nested.to_normal_form_XXX()
        for b in self.get_branches():
            b.to_normal_form_XXX()

    def expand_macroses(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Expand macroses in the nested module (if any) and then in every branch,
        passing `preceding_modules` extended with this module.
        """
        new_preceding_modules = add_module(preceding_modules, self)
        if self.nested:
            self.nested.expand_macroses(ctx=ctx, preceding_modules=new_preceding_modules)
        for b in self.get_branches():
            b.expand_macroses(ctx=ctx, preceding_modules=new_preceding_modules)

    def validate_composite_fields(self, ctx=DEFAULT_CTX, preceding_modules=(), chained_modules=()):
        """
        Run the parent class validation, then validate the nested module (if any)
        with this module appended to `preceding_modules`.

        `chained_modules` is accepted for interface compatibility and is not used here.
        """
        super(ChainableModuleWrapperBase, self).validate_composite_fields(ctx=ctx, preceding_modules=preceding_modules)
        if self.nested:
            self.nested.validate(ctx=ctx, preceding_modules=add_module(preceding_modules, self))

    def require_nested(self, chained_modules):
        """
        :raises ValidationError: if there is neither a nested module nor chained modules
        """
        if not chained_modules and not self.nested:
            raise ValidationError('must have nested module')

    def _get_first_chained_module(self, chained_modules):
        """
        Return the module that immediately follows this one: the first item of the nested chain
        or, if there is no nested module, the module of the first holder of `chained_modules`.

        :type chained_modules: list[Holder]
        :return: the module, or None if nothing follows
        """
        if self.nested:
            return next(iter(self.nested.walk_chain()), None)
        elif chained_modules:
            first_chained_module_holder = chained_modules[0]
            assert isinstance(first_chained_module_holder, Holder)
            return first_chained_module_holder.module
        else:
            return None

    def _get_last_chained_module(self, chained_modules):
        """
        Return the module that ends the chain started by this one: the last item of the nested chain
        or, if there is no nested module, the module of the last holder of `chained_modules`.

        :type chained_modules: list[Holder]
        :return: the module, or None if nothing follows
        """
        # bound up front so that a nested chain yielding nothing results in None
        # (as in _get_first_chained_module) rather than an UnboundLocalError
        last_chained_module = None
        if self.nested:
            for module in self.nested.walk_chain():
                last_chained_module = module
        elif chained_modules:
            last_chained_module_holder = chained_modules[-1]
            assert isinstance(last_chained_module_holder, Holder)
            last_chained_module = last_chained_module_holder.module
        return last_chained_module

    def validate(self, ctx=DEFAULT_CTX, preceding_modules=(), chained_modules=()):
        """
        Require a following module (nested or chained) and validate composite fields.

        :raises ValidationError:
        """
        self.require_nested(chained_modules)
        self.validate_composite_fields(ctx=ctx,
                                       preceding_modules=preceding_modules,
                                       chained_modules=chained_modules)

    def _add_nested_module_to_config(self, config, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Extend `config` with the config of the nested module (if any) and take over its outlets.
        """
        if self.nested:
            nested_config = self.nested.to_config(ctx=ctx, preceding_modules=add_module(preceding_modules, self))
            config.extend(nested_config)
            config.outlets = nested_config.outlets

    def _to_params_config(self, preceding_modules=(), *args, **kwargs):
        """
        Build the config of the module's own parameters; an empty Config in the base class.
        """
        return Config()

    def to_config(self, ctx=DEFAULT_CTX, preceding_modules=(), *args, **kwargs):
        """
        Build the module config: own parameters (`_to_params_config`) extended with
        the nested module config and marked with the class-level SHAREABLE flag.
        """
        config = self._to_params_config(ctx=ctx, preceding_modules=preceding_modules, *args, **kwargs)
        self._add_nested_module_to_config(config, ctx=ctx, preceding_modules=preceding_modules)
        config.shareable = self.SHAREABLE
        return config

    def walk_chain(self, visit_branches=False):
        """
        Yield whatever the parent class yields, followed by the chain of the nested module (if any).
        """
        for m in super(ChainableModuleWrapperBase, self).walk_chain(visit_branches=visit_branches):
            yield m
        if self.nested:
            for m in self.nested.walk_chain(visit_branches=visit_branches):
                yield m

    def is_chainable(self):
        """
        :rtype: bool
        """
        return True


class Chain(object):
    """
    An ordered list of module holders, each module being followed by the next one.

    :type modules: list[Holder]
    """
    __slots__ = ('modules',)

    def __init__(self, modules=None):
        self.modules = modules or []  # type: list[Holder]

    def validate(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Validate every holder of the chain in order.

        Every module except the last one must be chainable and must not have nested modules.
        A non-last holder is validated with the rest of the chain passed as `chained_modules`.
        Modules already validated are appended to the `preceding_modules` seen by the following ones.

        :raises ValidationError: with the "modules[i]" field name appended
        """
        preceding_modules = list(preceding_modules)
        for i, module_holder in enumerate(self.modules):
            is_last = (i == len(self.modules) - 1)
            if module_holder.is_empty():
                raise ValidationError('chain can not contain empty holders', 'modules[{}]'.format(i))
            it = module_holder.walk_chain()
            module = next(it)  # skip module itself
            try:
                if not is_last and not module.is_chainable():
                    raise ValidationError(
                        'non-chainable module {} can not be chained'.format(module.__class__.__name__))
                # anything walk_chain yields after the module itself means the module has something nested
                try:
                    next(it)
                except StopIteration:
                    pass
                else:
                    if not is_last:
                        raise ValidationError('only last module in chain can have nested modules')
                kwargs = {
                    'ctx': ctx,
                    'preceding_modules': preceding_modules,
                }
                if not is_last:
                    kwargs['chained_modules'] = self.modules[i + 1:]
                module_holder.validate(**kwargs)
            except ValidationError as e:
                append_field_name_to_validation_error(e, 'modules[{}]'.format(i))
                raise
            preceding_modules.append(module_holder.module)

    def walk_chain(self, visit_branches=False):
        """
        Yield the items of the chains of all holders, in order.

        An empty chain (e.g. the one returned by `shift()` for a single-module chain) yields nothing.
        """
        for holder in self.modules:
            for m in holder.walk_chain(visit_branches=visit_branches):
                yield m

    @classmethod
    def expand_modules(cls, modules, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Return a copy of `modules` in which every holder whose module has an `expand` method
        is replaced (repeatedly, until nothing expandable is left at that position)
        by holders wrapping the result of `expand()`.

        :type modules: list[Holder]
        :rtype: list[Holder]
        """
        modules = list(modules)
        i = 0
        preceding_modules = list(preceding_modules)
        while i < len(modules):
            holder = modules[i]
            if hasattr(holder.module, 'expand'):
                # stay at the same index: the expansion result is to be examined as well
                modules[i:i + 1] = list(map(Holder, holder.module.expand(ctx=ctx, preceding_modules=preceding_modules)))
            else:
                i += 1
            preceding_modules.append(holder.module)
        return modules

    def expand_macroses(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Expand macroses in every holder; a holder that contains a chain afterwards
        is replaced by the holders of that chain (which are then processed too).

        `self.modules` is not modified, the resulting list is returned.

        :rtype: list[Holder]
        """
        modules = list(self.modules)
        i = 0
        preceding_modules = list(preceding_modules)
        while i < len(modules):
            holder = modules[i]
            holder.expand_macroses(ctx=ctx, preceding_modules=preceding_modules)
            if holder.chain:
                # stay at the same index to process the holders spliced in
                modules[i:i + 1] = holder.chain.modules
            else:
                i += 1
            if holder.module:
                preceding_modules.append(holder.module)
        return modules

    def to_config(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Build the config of the whole chain.

        The config of every following module is put (under the module name) into the tables of
        all outlets of the preceding module config. Outlets of the resulting config are the outlets
        of the last module config.

        :return: (name of the first module, its config) pair
        """
        preceding_modules = list(preceding_modules)
        expanded_modules = self.expand_modules(self.modules, ctx=ctx, preceding_modules=preceding_modules)
        first_module = expanded_modules[0]
        rv = first_module.module.to_config(ctx=ctx, preceding_modules=preceding_modules)
        if not isinstance(rv, tuple):
            rv = (first_module.module_name, rv)
        _, config = rv
        # the module (not its holder) is appended, just like in the loop below and in validate(),
        # so that isinstance-based lookups in preceding_modules can find it
        preceding_modules.append(first_module.module)
        for holder in itertools.islice(expanded_modules, 1, None):
            module_config = holder.module.to_config(ctx=ctx, preceding_modules=preceding_modules)
            if isinstance(module_config, tuple):
                module_name, module_config = module_config
            else:
                module_name = holder.module_name
            for outlet in config.get_outlets():
                outlet.table[module_name] = module_config
            config = module_config
            preceding_modules.append(holder.module)
        rv[1].outlets = config.get_outlets()
        return rv

    def shift(self):
        """
        :return: (module of the first holder, Chain of the remaining holders) pair
        """
        return self.modules[0].module, Chain(self.modules[1:])

    def __repr__(self):
        return 'Chain({})'.format([holder.module_name for holder in self.modules])


class MacroBase(object):
    """
    Base class of macroses: objects providing `expand()`.

    All the defaults below are constants: a macro is not chainable, reports no backends,
    and reports that it would include weight sections and internal upstreams
    while returning empty sets of their ids.
    """

    def expand(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Not implemented in the base class.
        """
        raise NotImplementedError

    def is_chainable(self):
        """
        :rtype: bool
        """
        return False

    def would_include_backends(self):
        """
        :rtype: bool
        """
        return False

    def get_would_be_included_full_backend_ids(self, current_namespace_id):
        """
        :rtype: set
        """
        return set()

    def would_include_weight_sections(self):
        """
        :rtype: bool
        """
        return True

    def would_include_internal_upstreams(self):
        """
        :rtype: bool
        """
        return True

    def get_would_be_included_full_weight_section_ids(self, namespace_id):
        """
        :rtype: set
        """
        return set()

    def get_would_be_included_full_internal_upstream_ids(self, namespace_id):
        """
        :rtype: set
        """
        return set()


class Holder(ConfigWrapperBase):
    """
    :type module_name: Optional[six.text_type]
    :type module: Optional[Union[ModuleWrapperBase, ChainableModuleWrapperBase]]
    :type chain_name: Optional[six.text_type]
    :type chain: Optional[Chain]
    """
    __slots__ = ('module_name', 'module', 'chain_name', 'chain')

    __protobuf__ = modules_pb2.Holder

    def __init__(self, pb=None):
        """
        :param pb: holder protobuf message to wrap
        :type pb: Optional[modules_pb2.Holder]
        """
        # slots are initialized before delegating to the base class, so that they exist
        # no matter what the base initializer does with `pb`
        self.module_name = None
        self.module = None
        self.chain_name = None
        self.chain = None
        # super() has to be anchored at Holder itself: super(ConfigWrapperBase, self) starts the MRO
        # lookup *after* ConfigWrapperBase and therefore bypasses any __init__ it defines
        super(Holder, self).__init__(pb=pb)

    def wrap_composite_fields(self):
        """
        Rebuilds `module_name`/`module` and `chain_name`/`chain` from `self.pb`.

        The module is taken from the "module" oneof of the pb (if any field of it is set),
        the chain is built from the pb's `modules` list (if it is not empty) and is named
        after its first holder's module.
        """
        self.module_name = self.module = None
        self.chain_name = self.chain = None

        oneof_field = self.pb.WhichOneof('module')
        if oneof_field:
            self.module_name = oneof_field
            self.module = wrap(getattr(self.pb, oneof_field))

        if self.pb.modules:
            holders = [Holder(holder_pb) for holder_pb in self.pb.modules]
            self.chain_name = holders[0].module_name
            self.chain = Chain(holders)

    def validate(self, ctx=DEFAULT_CTX, preceding_modules=(), chained_modules=()):
        """
        Validates the holder and whatever it contains.

        :param ctx: validation context
        :param preceding_modules: modules located before this holder
        :param chained_modules: modules following this holder; only passed to chainable modules
        :raises ValidationError: if both or none of module and chain are set, or if the contained
                                 module/chain is invalid (module errors get the module name appended)
        """
        has_module = bool(self.module)
        has_chain = bool(self.chain)
        if has_module and has_chain:
            raise ValidationError('both "{}" module and "modules" list specified'.format(self.module_name))
        if not (has_module or has_chain):
            raise ValidationError('module is not specified')

        if not has_module:
            self.chain.validate(ctx=ctx, preceding_modules=preceding_modules)
            return

        try:
            kwargs = {'ctx': ctx, 'preceding_modules': preceding_modules}
            if isinstance(self.module, ChainableModuleWrapperBase):
                # only chainable modules know what to do with their successors
                kwargs['chained_modules'] = chained_modules
            elif not isinstance(self.module, ModuleWrapperBase):
                raise AssertionError('unexpected module type: {}'.format(self.module.__class__))
            self.module.validate(**kwargs)
        except ValidationError as e:
            append_field_name_to_validation_error(e, self.module_name)
            raise

    def to_config(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Renders the contained module (or chain, if there is no module) and wraps the result
        into a single-entry Config keyed by the module/chain name.

        :param ctx: context passed to the contained module or chain
        :param preceding_modules: modules located before this holder
        :returns: Config with one table entry and the outlets of the rendered config
        :rtype: Config
        """
        if self.module:
            name, source = self.module_name, self.module
        else:
            name, source = self.chain_name, self.chain
        rendered = source.to_config(ctx=ctx, preceding_modules=preceding_modules)
        # a (name, config) pair overrides the default name
        if isinstance(rendered, tuple):
            name, rendered = rendered
        return Config({name: rendered}, outlets=rendered.outlets)

    def walk_chain(self, visit_branches=False):
        """
        Yields everything the contained module (or chain, if there is no module) yields
        from its own `walk_chain`.

        A holder with neither a module nor a chain yields nothing, which is what `is_empty`,
        `attach` and `get_nested_module` rely on; previously this case failed with AttributeError.

        :param visit_branches: passed through to the contained module/chain
        """
        target = self.module or self.chain
        if target is None:
            return
        for m in target.walk_chain(visit_branches=visit_branches):
            yield m

    def is_empty(self):
        """
        Returns True if this holder contains neither a module nor a chain (or if walking it
        yields no first module) and False otherwise.

        :rtype: bool
        """
        # nothing to walk at all; checked explicitly so that an empty holder is reported
        # as empty instead of failing on a missing module/chain
        if (self.module or self.chain) is None:
            return True
        first_module = next(self.walk_chain(), False)
        return not first_module

    def __repr__(self):
        """
        :returns: 'Holder(-)' for an empty holder, otherwise 'Holder(<module name>)' or
                  'Holder(chain#<number of holders in the chain>)'
        :rtype: str
        """
        if self.is_empty():
            return 'Holder(-)'
        label = self.module_name
        if not label:
            label = 'chain#{}'.format(len(self.chain.modules))
        return 'Holder({})'.format(label)

    def attach(self, holder_pbs):
        """
        Appends the given holder pbs to the `nested.modules` list of the last module
        yielded by `walk_chain` (branches are not visited) and refreshes that module.

        :param holder_pbs: holder protobuf messages to append
        :raises RuntimeError: if walking the holder yields no module
        """
        tail = None
        # the loop variable keeps the last yielded module once the walk is over
        for tail in self.walk_chain(visit_branches=False):
            pass
        if not tail:
            raise RuntimeError('Called .attach() on an empty holder')
        tail.pb.nested.modules.extend(holder_pbs)
        tail.update_pb()

    def to_normal_form_XXX(self):
        """
        See https://st.yandex-team.ru/AWACS-1044 for details.
        This code is not production grade and is intended to be used by tooling around awacs (such as awacsemtool).

        Rewrites this holder in place. A holder that carries a chain is rebuilt so that every
        module of the chain is written into the `nested` holder of the module before it;
        a holder that carries a single module delegates to that module's `to_normal_form_XXX`.

        Chains inside the chain are not supported, and only the last module of the chain
        may be non-chainable (both are enforced with asserts).
        """
        if self.chain:
            # the nested representation is assembled in a scratch pb and copied over self.pb at the end
            root_holder_pb = modules_pb2.Holder()
            curr_holder_pb = root_holder_pb

            n = len(self.chain.modules)
            for i, module_holder in enumerate(self.chain.modules):
                # do not allow chains in chain
                assert module_holder.module
                assert not module_holder.chain

                # normalize the chained holder first, then copy its pb into the current slot
                module_holder.to_normal_form_XXX()
                curr_holder_pb.CopyFrom(module_holder.pb)

                if not isinstance(module_holder.module, ChainableModuleWrapperBase):
                    # the loop descends only through chainable modules, so a non-chainable one must be the last
                    assert i == n - 1
                    continue

                # assert not module_holder.module.nested
                # descend: the next module goes into the `nested` holder of the module just copied
                # (get_holder_pb_module presumably returns the module message set in the holder pb -- TODO confirm)
                curr_holder_pb = get_holder_pb_module(curr_holder_pb).nested

            self.pb.CopyFrom(root_holder_pb)
            # presumably re-wraps the composite fields from the new pb -- TODO confirm
            self.update_pb()
        else:
            self.module.to_normal_form_XXX()

    def expand_immediate_contained_macro(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Replaces the macro contained in this holder with its expansion, in place.
        Only this macro is expanded; macros that may appear in the expansion are left as is.

        The first holder pb of the expansion is copied over `self.pb`, the remaining ones
        (followed by whatever was nested into the macro, if it is chainable) are attached after it.

        :param ctx: context passed to the macro's `expand`
        :param preceding_modules: modules located before the macro, passed to its `expand`
        :return: None
        """
        # keep a reference to the macro: self.module is replaced once self.pb is overwritten below
        macro = self.module
        expanded_macro_module_pbs = macro.expand(ctx=ctx, preceding_modules=preceding_modules)

        # the expansion is expected to consist of holder messages only
        assert all(holder_pb.DESCRIPTOR.full_name == AWACS_HOLDER_MESSAGE_FULL_NAME
                   for holder_pb in expanded_macro_module_pbs)

        # modules nested into a chainable macro must follow its expansion
        if isinstance(macro, ChainableModuleWrapperBase) and macro.nested:
            nested = macro.nested
            if nested.module:
                expanded_macro_module_pbs.append(nested.pb)
            else:
                expanded_macro_module_pbs.extend([m.pb for m in nested.chain.modules])

        head_pb = expanded_macro_module_pbs[0]
        rest_pbs = expanded_macro_module_pbs[1:]
        self.pb.CopyFrom(head_pb)
        # presumably re-wraps self.module from the new pb -- TODO confirm
        self.update_pb()

        if rest_pbs:
            # from here on self.module is the head of the expansion, not the macro
            if isinstance(self.module, ChainableModuleWrapperBase):
                self.attach(rest_pbs)
            # NOTE(review): this is an independent `if`, so a chainable module that also has
            # `sections` gets the rest attached in both places -- confirm this is intended
            if hasattr(self.module, 'sections'):
                for section in six.itervalues(self.module.sections):
                    section.nested.attach(rest_pbs)

    def expand_contained_macro(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Expands the macro contained in this holder and then keeps expanding
        whatever the holder contains afterwards.

        :param ctx: expansion context
        :param preceding_modules: modules located before the macro
        """
        # remember the macro before the expansion replaces self.module
        original_macro = self.module
        self.expand_immediate_contained_macro(ctx=ctx, preceding_modules=preceding_modules)
        extended_preceding = add_module(preceding_modules, original_macro)
        self.expand_macroses(ctx=ctx, preceding_modules=extended_preceding)

    def expand_macroses(self, ctx=DEFAULT_CTX, preceding_modules=()):
        """
        Expands macros in the contained module or chain; does nothing for a holder that has neither.

        :param ctx: expansion context
        :param preceding_modules: modules located before this holder
        """
        contained = self.module
        if not contained:
            if self.chain:
                self.chain.expand_macroses(ctx=ctx, preceding_modules=preceding_modules)
            return
        if isinstance(contained, MacroBase):
            # the module is a macro itself: replace it with its expansion
            self.expand_contained_macro(ctx=ctx, preceding_modules=preceding_modules)
            return
        contained.expand_macroses(ctx=ctx, preceding_modules=preceding_modules)

    def validate_shared_and_report_refs(self, validation_ctx=DEFAULT_CTX):
        """
        Checks that "shared" and "report" modules reachable from this holder refer
        only to uuids that are actually defined, and that no shared uuid is defined twice.

        A "shared" module that is the last one in its chain is treated as a reference,
        any other "shared" module -- as an anchor (definition). For "report" modules
        `uuid` is an anchor and the comma-separated `refers` lists references.

        :param validation_ctx: context used to render the config when duplicates are suspected
        :raises ValidationError: on duplicate shared uuids or on dangling shared/report references
        """
        from awacs.wrappers.main import Shared, Report

        shared_anchors = set()
        shared_refs = set()

        report_anchors = set()
        report_refs = set()

        # work list of chain iterators; branch chains are appended to it while it is being iterated
        chains = [self.walk_chain()]

        duplicate_shared_uuids_are_possible = False

        for chain in chains:
            # stays None if the chain yields nothing
            module = None
            chain_shared_uuids = []
            for module in chain:
                if isinstance(module, Shared):
                    chain_shared_uuids.append(module.pb.uuid)
                if isinstance(module, Report):
                    if module.pb.uuid:
                        report_anchors.add(module.pb.uuid)
                    if module.pb.refers:
                        report_refs.update(module.pb.refers.split(','))
                for branch in module.get_branches():
                    chains.append(branch.walk_chain())
            last_module = module
            if isinstance(last_module, Shared):
                # a trailing "shared" is a reference: take its uuid out of the anchor candidates
                last_shared_uuid = chain_shared_uuids.pop()
                shared_refs.add(last_shared_uuid)

            for uuid in chain_shared_uuids:
                if uuid in shared_anchors:
                    # seen more than once during the walk; verified against the rendered config below
                    duplicate_shared_uuids_are_possible = True
                else:
                    shared_anchors.add(uuid)

        if duplicate_shared_uuids_are_possible:
            # render the config and count distinct nodes that anchor each shared uuid
            ctx = self.to_config(ctx=validation_ctx).compute_context()

            config_uuid_by_user_shared_uuids = collections.defaultdict(set)
            for node in ctx.tree.bfs():
                if node.config and node.config.is_shared() and node.config.is_shared_anchor():
                    config_uuid_by_user_shared_uuids[node.config.shared_uuid].add(node.id)

            duplicate_shared_uuids = set()
            for shared_uuid, config_uuids in six.iteritems(config_uuid_by_user_shared_uuids):
                # more than one distinct node id anchors the same uuid
                if len(config_uuids) > 1:
                    duplicate_shared_uuids.add(shared_uuid)
            if duplicate_shared_uuids:
                raise ValidationError('the following shared uuids defined more than once: "{}"'.format(
                    quote_join_sorted(duplicate_shared_uuids)))

        # references without a matching anchor
        dangling_shared_refs = sorted(shared_refs - shared_anchors)
        if dangling_shared_refs:
            raise ValidationError(
                'the following referenced "shared" uuids do not exist: "{}"'.format('", "'.join(dangling_shared_refs)))

        dangling_report_refs = sorted(report_refs - report_anchors)
        if dangling_report_refs:
            raise ValidationError(
                'the following referenced "report" uuids do not exist: "{}"'.format('", "'.join(dangling_report_refs)))

    def is_l7_macro(self):
        """
        :returns: True if the name of the contained module is 'l7_macro'
        :rtype: bool
        """
        contained_name = self.module_name
        return contained_name == 'l7_macro'

    def is_l7_upstream_macro(self):
        """
        :returns: True if the name of the contained module is 'l7_upstream_macro'
        :rtype: bool
        """
        contained_name = self.module_name
        return contained_name == 'l7_upstream_macro'

    def get_nested_module(self):
        """
        :returns: the first module yielded by `walk_chain`, or None if it yields nothing
        """
        for first in self.walk_chain():
            return first
        return None
