# coding: utf-8
import contextlib

import six
import yaml
from google.protobuf import descriptor as protobufdescriptor, message as protobufmessage

from awacs.lib import OrderedDict
from awacs.lib.yamlparser import parser, builder, errors
from awacs.wrappers import funcs, luaparser
from awacs.wrappers.errors import ValidationError
from . import util


FD = protobufdescriptor.FieldDescriptor
# Fully-qualified names of the protobuf well-known wrapper types that are written in YAML
# as the bare value of their single ``value`` field (see Builder.get_embedded_field_name).
VALUE_FIELD_NAMES = frozenset(
    'google.protobuf.{}'.format(wrapper_name)
    for wrapper_name in ('BoolValue', 'FloatValue', 'Int32Value', 'StringValue')
)


class Builder(builder.ProtobufBuilder):
    """
    YAML -> protobuf builder that understands awacs-specific schema extensions.

    On top of the generic protobuf builder it supports:

    * tagged scalars: ``!f``/``!func`` (function call), ``!k``/``!knob`` (knob reference)
      and ``!c``/``!cert`` (certificate reference). A tagged value written for field ``x``
      is stored in the sibling field ``f_x`` / ``k_x`` / ``c_x`` respectively;
    * "embedded" messages (the protobuf wrappers listed in VALUE_FIELD_NAMES and messages
      with the ``embed`` awacs option) that are written in YAML as the bare value of the
      embedded field;
    * "squashed" fields, whose sub-fields are looked up as if they belonged to the
      enclosing message;
    * "ordered map" entries (messages with the ``map_entry`` awacs option).
    """
    DEFAULT_FUNC_YAML_TAG = '!f'
    FUNC_YAML_TAGS = (DEFAULT_FUNC_YAML_TAG, '!func')

    DEFAULT_KNOB_YAML_TAG = '!k'
    KNOB_YAML_TAGS = (DEFAULT_KNOB_YAML_TAG, '!knob')

    DEFAULT_CERT_YAML_TAG = '!c'
    CERT_YAML_TAGS = (DEFAULT_CERT_YAML_TAG, '!cert')

    # message type name -> name of the handler that consumes the tagged scalar replacing it
    _TAGGED_SCALAR_HANDLERS = {
        'Call': '_on_func_call',
        'KnobRef': '_on_knob',
        'CertRef': '_on_cert',
    }

    def _on_func_call(self, event):
        """
        Parse a ``!f``-tagged scalar into the call message on top of the pb stack.

        :type event: yaml.ScalarEvent
        :raises errors.SchemaError: if the tag is not a function tag or the expression is invalid
        """
        if event.tag not in self.FUNC_YAML_TAGS:
            raise errors.SchemaError('expected function call')

        call_pb, _ = self._pb_stack.pop()
        try:
            call_obj = luaparser.parse_call(event.value)
            call_pb.SetInParent()
            funcs.raw_call_to_call_pb(call_obj, call_pb)
        except (ValueError, ValidationError) as e:
            # prefer the exception's ``message`` attribute when it has one, else str(e)
            raise errors.SchemaError('invalid function call expression: {}'.format(getattr(e, 'message', None) or str(e)))

    def _set_ref_id(self, event, allowed_tags, error_message):
        """
        Pop the reference message off the pb stack and store the scalar as its ``id``.

        :type event: yaml.ScalarEvent
        :raises errors.SchemaError: with ``error_message`` if the tag is not in ``allowed_tags``
        """
        if event.tag not in allowed_tags:
            raise errors.SchemaError(error_message)

        ref_pb, _ = self._pb_stack.pop()
        ref_pb.SetInParent()
        ref_pb.id = event.value

    def _on_knob(self, event):
        """Consume a ``!k``-tagged scalar as a knob reference id."""
        self._set_ref_id(event, self.KNOB_YAML_TAGS, 'expected knob')

    def _on_cert(self, event):
        """Consume a ``!c``-tagged scalar as a certificate reference id."""
        self._set_ref_id(event, self.CERT_YAML_TAGS, 'expected cert')

    def _on_scalar_mapping_value(self, event):
        """
        Handle a scalar value of a map entry.

        ``!f``-tagged values are only allowed for ordered-map entries whose entry type has
        an ``f_value`` field; knob and cert tags are always rejected.

        :type event: yaml.ScalarEvent
        """
        if event.tag in self.FUNC_YAML_TAGS:
            key, key_desc = self._pb_stack.pop()
            mapping_pb, mapping_desc = self._pb_stack[-1]
            if not self.is_ordered_map_entry(mapping_desc):
                raise errors.SchemaError('map value does not accept a function call')
            entry_pb = mapping_pb.entries_pb.add(key=key)
            field_value, field_desc = self.get_func_field(entry_pb, 'value')
            self._pb_stack.append((field_value, field_desc.message_type))
            self._on_func_call(event)
        elif event.tag in self.KNOB_YAML_TAGS:
            raise errors.SchemaError('map value does not allow knobs')
        elif event.tag in self.CERT_YAML_TAGS:
            raise errors.SchemaError('map value does not allow certs')
        else:
            return super(Builder, self)._on_scalar_mapping_value(event)

    def _on_tagged_field_value(self, event, get_field, handler):
        """
        Redirect a tagged scalar written for the current field to its prefixed sibling field.

        :param get_field: one of get_func_field / get_knob_field / get_cert_field
        :param handler: bound method that parses the scalar into the pushed message
        """
        field_name, _ = self._pb_stack.pop()  # pop current field
        message_pb, _ = self._pb_stack[-1]  # get current message
        field_value, field_desc = get_field(message_pb, field_name)
        self._pb_stack.append((field_value, field_desc.message_type))
        handler(event)

    def _on_scalar_message_field_value(self, event):
        """
        Handle a scalar value of a message field, dispatching tagged scalars to the
        corresponding ``f_``/``k_``/``c_`` sibling field.

        :type event: yaml.ScalarEvent
        """
        if event.tag in self.FUNC_YAML_TAGS:
            self._on_tagged_field_value(event, self.get_func_field, self._on_func_call)
        elif event.tag in self.KNOB_YAML_TAGS:
            self._on_tagged_field_value(event, self.get_knob_field, self._on_knob)
        elif event.tag in self.CERT_YAML_TAGS:
            self._on_tagged_field_value(event, self.get_cert_field, self._on_cert)
        else:
            return super(Builder, self)._on_scalar_message_field_value(event)

    @classmethod
    def get_embedded_field_name(cls, message_desc):
        """
        Return the name of the field a message is written as in YAML, or a falsy value.

        Wrapper types from VALUE_FIELD_NAMES embed ``value``; other messages use the
        ``embed`` awacs message option, if present.
        """
        if message_desc.full_name in VALUE_FIELD_NAMES:
            return 'value'
        awacs_options = cls.get_awacs_message_options(message_desc)
        return awacs_options.embed if awacs_options else None

    def _handle_message(self, message_pb):
        """
        Start building ``message_pb``.

        Call/KnobRef/CertRef messages expect a single tagged scalar. Embedded messages
        are opened and their embedded field is entered right away, so the next YAML
        value is parsed as that field.
        """
        message_desc = message_pb.DESCRIPTOR

        tagged_handler = self._TAGGED_SCALAR_HANDLERS.get(message_desc.name)
        if tagged_handler is not None:
            # pretend that we parsed a message
            self._pb_stack.append((message_pb, message_desc))
            # and now we expect a tagged scalar (!f / !k / !c)
            self._stack.extend(reversed([yaml.ScalarEvent, tagged_handler]))
            return

        embedded_field_name = self.get_embedded_field_name(message_desc)
        if embedded_field_name:
            # pretend that we:
            # 1. parsed a message:
            message_pb.SetInParent()
            self._pb_stack.append((message_pb, message_desc))
            self._stack.append('_on_message_end')
            self._seen_fields_stack.append(set())
            # 2. parsed an embedded field:
            embedded_field_desc = message_desc.fields_by_name[embedded_field_name]
            self._handle_message_field(embedded_field_name,
                                       getattr(message_pb, embedded_field_name),
                                       embedded_field_desc)
        else:
            super(Builder, self)._handle_message(message_pb)

    @classmethod
    def get_field_value(cls, message_pb, field_name):
        """
        Return the value of ``field_name``, looking into squashed sub-messages recursively.

        :returns: the field value, or None if no such field exists
        """
        message_desc = message_pb.DESCRIPTOR
        fields = message_desc.fields_by_name
        if field_name in fields:
            return getattr(message_pb, field_name)
        for name in cls.get_squashed_fields(message_desc):
            value = cls.get_field_value(getattr(message_pb, name), field_name)
            if value is not None:
                return value
        return None

    @classmethod
    def get_field_desc(cls, message_pb, field_name):
        """
        Return the descriptor of ``field_name``, looking into squashed sub-messages recursively.

        :returns: a FieldDescriptor, or None if no such field exists
        """
        message_desc = message_pb.DESCRIPTOR
        fields = message_desc.fields_by_name
        if field_name in fields:
            return fields[field_name]
        for name in cls.get_squashed_fields(message_desc):
            desc = cls.get_field_desc(getattr(message_pb, name), field_name)
            if desc is not None:
                return desc
        return None

    @classmethod
    def get_which_oneof(cls, message_pb, oneof_name):
        """
        Return the name of the field set in oneof ``oneof_name``, or None.

        If the oneof is declared in ``message_pb`` itself, only that message is inspected;
        otherwise the oneof is searched for in squashed sub-messages, recursively.
        """
        message_desc = message_pb.DESCRIPTOR
        if oneof_name in message_desc.oneofs_by_name:
            # the oneof belongs to this very message
            return super(Builder, cls).get_which_oneof(message_pb, oneof_name)
        # the oneof is not declared here: try to locate it in squashed sub-messages
        for name in cls.get_squashed_fields(message_desc):
            oneof_key = cls.get_which_oneof(getattr(message_pb, name), oneof_name)
            if oneof_key is not None:
                return oneof_key
        return None

    @classmethod
    def is_ordered_map_entry(cls, message_desc):
        """Return a truthy value if the message has the ``map_entry`` awacs option set."""
        awacs_options = cls.get_awacs_message_options(message_desc)
        return awacs_options and awacs_options.map_entry

    @classmethod
    def get_awacs_message_options(cls, desc):
        """Return the value of the ``awacs_message`` option set on ``desc``, or None."""
        for option, value in desc.GetOptions().ListFields():
            if option.name == 'awacs_message':
                return value
        return None

    @classmethod
    def get_awacs_field_options(cls, field_desc):
        """Return the value of the ``awacs_field`` option set on ``field_desc``, or None."""
        for option, value in field_desc.GetOptions().ListFields():
            if option.name == 'awacs_field':
                return value
        return None

    @classmethod
    def get_squashed_fields(cls, message_desc):
        """Yield names of the message's fields that have the ``squashed`` awacs option set."""
        for name, field_desc in six.iteritems(message_desc.fields_by_name):
            field_awacs_options = cls.get_awacs_field_options(field_desc)
            if field_awacs_options and field_awacs_options.squashed:
                yield name

    @classmethod
    def _get_prefixed_field(cls, message_pb, prefix, field_name, error_message):
        """
        Return ``(value, descriptor)`` of the field ``<prefix><field_name>`` of ``message_pb``.

        :raises errors.SchemaError: with ``error_message`` if the message has no such field
        """
        message_fields = message_pb.DESCRIPTOR.fields_by_name
        prefixed_field_name = prefix + field_name
        if prefixed_field_name not in message_fields:
            raise errors.SchemaError(error_message)
        return getattr(message_pb, prefixed_field_name), message_fields[prefixed_field_name]

    @classmethod
    def get_knob_field(cls, message_pb, field_name):
        """Return ``(value, descriptor)`` of the ``k_<field_name>`` knob field."""
        return cls._get_prefixed_field(message_pb, 'k_', field_name,
                                       'field does not accept a knob value')

    @classmethod
    def get_cert_field(cls, message_pb, field_name):
        """Return ``(value, descriptor)`` of the ``c_<field_name>`` certificate field."""
        return cls._get_prefixed_field(message_pb, 'c_', field_name,
                                       'field does not accept a certificate value')

    @classmethod
    def get_func_field(cls, message_pb, field_name):
        """Return ``(value, descriptor)`` of the ``f_<field_name>`` function call field."""
        return cls._get_prefixed_field(message_pb, 'f_', field_name,
                                       'field does not accept a function call')


def parse(cls, document, ensure_ascii=False):
    """
    Parse a YAML ``document`` into a protobuf message of type ``cls``.

    The YAML event stream is always closed, even if building fails.
    """
    pb_builder = Builder(cls)
    events = parser.parse(document, ensure_ascii=ensure_ascii)
    with contextlib.closing(events) as event_stream:
        parser.feed_events(pb_builder, event_stream)
    return pb_builder.get_result()


def iter_pb_map_items(pb_field_value, pb_field_desc):
    """
    Yield ``(key, value)`` pairs of a protobuf map field.

    Two representations are supported:

    * "ordered maps": repeated entry messages with the
      ``(awacs_message_schema.awacs).map_entry = true`` option. Entries are yielded in
      stored order. If the entry type has an ``f_value`` field and it is set, the value
      is yielded as an ``!f``-tagged call expression instead of ``value``;
    * native protobuf maps, yielded sorted by key.

    :raises RuntimeError: (on iteration) if the field is not a map
    """
    entry_desc = pb_field_desc.message_type
    if Builder.is_ordered_map_entry(entry_desc):
        has_call_values = 'f_value' in entry_desc.fields_by_name
        for entry in pb_field_value:
            if has_call_values and entry.HasField('f_value'):
                raw_call = funcs.call_pb_to_raw_call(entry.f_value)
                yield entry.key, util.FTag(luaparser.dump_call(raw_call))
            else:
                yield entry.key, entry.value
    elif Builder.is_map_entry(entry_desc):
        for map_key, map_value in sorted(pb_field_value.items()):
            yield map_key, map_value
    else:
        raise RuntimeError('pb_field_value is not a protobuf map')


def pb_map_to_dict(pb_field_value, pb_field_desc):
    """
    Convert a protobuf map field into an OrderedDict, in iter_pb_map_items order.

    Message values are converted recursively with pb_to_dict; other values are kept as-is.
    """
    result = OrderedDict()
    for map_key, map_value in iter_pb_map_items(pb_field_value, pb_field_desc):
        is_message = isinstance(map_value, protobufmessage.Message)
        result[map_key] = pb_to_dict(map_value) if is_message else map_value
    return result


# Field name prefixes of the sibling fields holding !f (call), !k (knob) and !c (cert) values.
F_PREFIX, K_PREFIX, C_PREFIX = 'f_', 'k_', 'c_'


def _call_pb_to_ftag(call_pb):
    """Render a call message as an ``!f``-tagged call expression for YAML output."""
    return util.FTag(luaparser.dump_call(funcs.call_pb_to_raw_call(call_pb)))


def pb_to_dict(message_pb):
    """
    Returns an ordered dict that can be dumped to YAML to produce a "canonical" YAML representation.

    Messages with an embedded field (see Builder.get_embedded_field_name) are not dumped
    as dicts. Instead, the set ``f_<embedded>`` call is returned as an ``!f`` tag, or else
    the bare value of the embedded field.

    Other messages are dumped field by field:

    * repeated message fields become a dict for (ordered) maps and a list otherwise;
      empty ones are skipped;
    * set ``f_<name>`` / ``k_<name>`` / ``c_<name>`` fields of type Call / KnobRef / CertRef
      are dumped under ``<name>`` as ``!f`` / ``!k`` / ``!c`` tagged values;
    * other set message fields are dumped recursively; squashed ones are merged into
      the enclosing dict; unset ones are skipped;
    * scalar fields with a falsy value are skipped. Enums are dumped by value name, and
      repeated scalars as plain lists.

    :param message_pb: a protobuf message from infra.awacs.proto.modules_pb2
    :rtype: OrderedDict (or the embedded value for "wrapper" messages)
    """
    message_desc = message_pb.DESCRIPTOR
    embedded_field_name = Builder.get_embedded_field_name(message_desc)
    if embedded_field_name:
        # embedded field names are set for "wrapper" messages, such as Port
        f_embedded_field_name = '{}{}'.format(F_PREFIX, embedded_field_name)
        f_embedded_field_desc = message_desc.fields_by_name.get(f_embedded_field_name)
        if (f_embedded_field_desc and
                f_embedded_field_desc.message_type and
                f_embedded_field_desc.message_type.name == 'Call' and
                message_pb.HasField(f_embedded_field_name)):
            return _call_pb_to_ftag(getattr(message_pb, f_embedded_field_name))
        return getattr(message_pb, embedded_field_name)

    rv = OrderedDict()
    for name, field_desc in six.iteritems(message_desc.fields_by_name):
        if field_desc.type == FD.TYPE_MESSAGE and field_desc.label == FD.LABEL_REPEATED:
            field_value = getattr(message_pb, name)
            if not field_value:
                continue
            # repeated message can be either a list of messages, or a map (list of entries)
            entry_desc = field_desc.message_type
            if Builder.is_map_entry(entry_desc) or Builder.is_ordered_map_entry(entry_desc):
                rv[name] = pb_map_to_dict(field_value, field_desc)
            else:
                rv[name] = [pb_to_dict(item) for item in field_value]
        elif field_desc.type == FD.TYPE_MESSAGE:
            if not message_pb.HasField(name):
                continue
            field_value = getattr(message_pb, name)
            type_name = field_desc.message_type.name
            if name.startswith(F_PREFIX) and type_name == 'Call':
                rv[name[len(F_PREFIX):]] = _call_pb_to_ftag(field_value)
            elif name.startswith(K_PREFIX) and type_name == 'KnobRef':
                rv[name[len(K_PREFIX):]] = util.KTag(field_value.id)
            elif name.startswith(C_PREFIX) and type_name == 'CertRef':
                rv[name[len(C_PREFIX):]] = util.CTag(field_value.id)
            else:
                field_awacs_options = Builder.get_awacs_field_options(field_desc)
                dumped_value = pb_to_dict(field_value)
                if field_awacs_options and field_awacs_options.squashed:
                    # squashed sub-message fields are lifted into this message's dict
                    rv.update(dumped_value)
                else:
                    rv[name] = dumped_value
        else:
            field_value = getattr(message_pb, name)
            if not field_value:
                continue
            if field_desc.label == FD.LABEL_REPEATED:
                # cast protobuf's scalar containers to a regular list so that pyyaml can process it
                field_value = list(field_value)
            elif field_desc.enum_type:
                field_value = field_desc.enum_type.values_by_number[field_value].name
            rv[name] = field_value

    return rv


def dump(pb):
    """Serialize ``pb`` to canonical block-style YAML text using the awacs YAML dumper."""
    canonical = pb_to_dict(pb)
    return yaml.dump(canonical, Dumper=util.AwacsYamlDumper, default_flow_style=False)
