# coding: utf-8
import enum
import yaml
from google.protobuf import descriptor

from . import util, errors


FD = descriptor.FieldDescriptor


# Non-terminal symbols of the grammar that drives ProtobufBuilder.
# Built with the functional API; the explicit (name, value) pairs keep the
# numbering STREAM=1 ... MAPPING_ITEMS=7 in declaration order.
NonTerminals = enum.IntEnum('NonTerminals', [
    ('STREAM', 1),
    ('MESSAGE', 2),
    ('MESSAGE_FIELDS', 3),
    ('SEQUENCE', 4),
    ('SEQUENCE_ITEMS', 5),
    ('MAPPING', 6),
    ('MAPPING_ITEMS', 7),
])


# Grammar driving ProtobufBuilder. Each rule is the right-hand side of a
# production, read left to right. Its elements are:
#   * yaml event classes   -- terminals; the next event must be an instance,
#   * NonTerminals members -- symbols expanded through SELECT,
#   * strings              -- names of ProtobufBuilder methods run as actions.
#
#   STREAM         := StreamStart DocumentStart _on_document_start MESSAGE
#                     DocumentEnd StreamEnd
#   MESSAGE        := MappingStart MESSAGE_FIELDS
#   MESSAGE_FIELDS := Scalar _on_message_field_name MESSAGE_FIELDS
#                   | MappingEnd _on_message_end
#   MAPPING        := MappingStart MAPPING_ITEMS
#   MAPPING_ITEMS  := Scalar _on_mapping_key MAPPING_ITEMS
#                   | MappingEnd _on_mapping_end
#   SEQUENCE       := SequenceStart SEQUENCE_ITEMS
#   SEQUENCE_ITEMS := _on_sequence_item SEQUENCE_ITEMS   (on Scalar or MappingStart)
#                   | SequenceEnd _on_sequence_end
#
# The field/key/item actions push the symbols for the value that follows
# themselves: a non-terminal, or one of the SCALAR_RULE_* bodies. For that
# reason the value productions have no entry in SELECT.

STREAM_RULE = [
    yaml.StreamStartEvent,
    yaml.DocumentStartEvent,
    '_on_document_start',
    NonTerminals.MESSAGE,
    yaml.DocumentEndEvent,
    yaml.StreamEndEvent,
]

# MESSAGE and its field list
MESSAGE_RULE = [yaml.MappingStartEvent, NonTerminals.MESSAGE_FIELDS]
FIELDS_RULE_1 = [
    yaml.ScalarEvent,
    '_on_message_field_name',
    NonTerminals.MESSAGE_FIELDS,
]
FIELDS_RULE_2 = [yaml.MappingEndEvent, '_on_message_end']

# MAPPING (protobuf map fields) and its entry list
MAPPING_RULE = [yaml.MappingStartEvent, NonTerminals.MAPPING_ITEMS]
ENTRIES_RULE_1 = [
    yaml.ScalarEvent,
    '_on_mapping_key',
    NonTerminals.MAPPING_ITEMS,
]
ENTRIES_RULE_2 = [yaml.MappingEndEvent, '_on_mapping_end']

# SEQUENCE (repeated fields) and its item list
SEQUENCE_RULE = [yaml.SequenceStartEvent, NonTerminals.SEQUENCE_ITEMS]
ITEMS_RULE_1 = ['_on_sequence_item', NonTerminals.SEQUENCE_ITEMS]
ITEMS_RULE_2 = [yaml.SequenceEndEvent, '_on_sequence_end']

# Scalar leaf values; pushed directly by the builder's actions.
SCALAR_RULE_1 = [yaml.ScalarEvent, '_on_scalar_sequence_item']        # item of a repeated scalar field
SCALAR_RULE_2 = [yaml.ScalarEvent, '_on_scalar_message_field_value']  # singular scalar field
SCALAR_RULE_3 = [yaml.ScalarEvent, '_on_scalar_mapping_value']        # scalar value in a map field

# SELECT[non_terminal][type(event)] is the rule that expands non_terminal
# when the current event has exactly that class. A missing entry means the
# event is rejected.
SELECT = {
    NonTerminals.STREAM: {yaml.StreamStartEvent: STREAM_RULE},
    NonTerminals.MESSAGE: {yaml.MappingStartEvent: MESSAGE_RULE},
    NonTerminals.MESSAGE_FIELDS: {
        yaml.ScalarEvent: FIELDS_RULE_1,
        yaml.MappingEndEvent: FIELDS_RULE_2,
    },
    NonTerminals.SEQUENCE: {yaml.SequenceStartEvent: SEQUENCE_RULE},
    NonTerminals.SEQUENCE_ITEMS: {
        # Scalar items and message items share one rule; _on_sequence_item
        # branches on the repeated field's type.
        yaml.ScalarEvent: ITEMS_RULE_1,
        yaml.MappingStartEvent: ITEMS_RULE_1,
        yaml.SequenceEndEvent: ITEMS_RULE_2,
    },
    NonTerminals.MAPPING: {yaml.MappingStartEvent: MAPPING_RULE},
    NonTerminals.MAPPING_ITEMS: {
        yaml.ScalarEvent: ENTRIES_RULE_1,
        yaml.MappingEndEvent: ENTRIES_RULE_2,
    },
}


class ProtobufBuilder(object):
    """
    Builds a protobuf message of class ``pb_cls`` from YAML parser events.

    This is a table-driven LL(1) parser:
    - ``_stack`` holds grammar symbols: yaml event classes, ``NonTerminals``
      members, and names of action methods.
    - ``SELECT`` picks the rule that expands a non-terminal, based on the
      class of the current event.

    Events are fed one at a time to :meth:`process`. :meth:`get_result`
    returns the message being built. Only a single document is accepted;
    anchors and aliases are rejected.
    """
    __slots__ = ('_pb_cls', '_result_pb', '_stack', '_pb_stack', '_seen_fields_stack')

    def __init__(self, pb_cls):
        """
        :param pb_cls: protobuf message class instantiated when the document starts.
        """
        self._pb_cls = pb_cls

        self._result_pb = None
        # Pending grammar symbols; the top of the stack is the last element.
        self._stack = [NonTerminals.STREAM]
        # (value, descriptor) pairs for the protobuf values currently being filled.
        self._pb_stack = []
        # One set per open message/mapping, used to detect duplicate keys.
        self._seen_fields_stack = []

    def get_result(self):
        """
        :returns: the message built so far, or ``None`` before the document starts.
        """
        return self._result_pb

    @classmethod
    def _get_human_readable_name_by_desc(cls, desc):
        """
        Describe what a descriptor accepts, for use in error messages.

        :type desc: descriptor.Descriptor | descriptor.FieldDescriptor
        :returns: e.g. ``'int32'``, ``'sequence of string'``, ``'mapping'`` or a message name.
        :raises: errors.InternalError on an unknown field type.
        """
        if isinstance(desc, descriptor.Descriptor):
            if cls.is_map_entry(desc) or cls.is_ordered_map_entry(desc):
                rv = 'mapping'
            else:
                rv = desc.name
        else:
            if desc.type in util._INT32_TYPES:
                rv = 'int32'
            elif desc.type in util._INT64_TYPES:
                rv = 'int64'
            elif desc.type in util._UINT32_TYPES:
                rv = 'uint32'
            elif desc.type in util._UINT64_TYPES:
                rv = 'uint64'
            elif desc.type in util._FLOAT_TYPES:
                rv = 'float'
            elif desc.type == FD.TYPE_BOOL:
                rv = 'bool'
            elif desc.type == FD.TYPE_STRING:
                rv = 'string'
            elif desc.type == FD.TYPE_BYTES:
                rv = 'sequence of bytes'
            elif desc.type == FD.TYPE_ENUM:
                rv = 'one of the following values: "{}"'.format('", "'.join(desc.enum_type.values_by_name.keys()))
            elif desc.type == FD.TYPE_MESSAGE:
                rv = cls._get_human_readable_name_by_desc(desc.message_type)
            else:
                raise errors.InternalError('unknown field type {}'.format(desc.type))

            if desc.label == FD.LABEL_REPEATED:
                rv = 'sequence of ' + rv
        return rv

    @classmethod
    def _get_human_readable_name_by_event(cls, event):
        """
        Describe a YAML event for use in error messages.

        :type event: yaml.Event
        :raises: errors.InternalError for event types that are never expected to be rejected.
        """
        if isinstance(event, yaml.DocumentStartEvent):
            rv = 'document'
        elif isinstance(event, yaml.MappingStartEvent):
            rv = 'mapping'
        elif isinstance(event, yaml.SequenceStartEvent):
            rv = 'sequence'
        elif isinstance(event, yaml.ScalarEvent):
            rv = 'scalar'
        elif isinstance(event, yaml.DocumentEndEvent):
            rv = 'end of document'
        elif isinstance(event, yaml.StreamEndEvent):
            # Reached e.g. on empty input: the stream ends where a document is expected.
            rv = 'end of stream'
        else:
            raise errors.InternalError('unexpected event: {}'.format(event.__class__.__name__))
        return rv

    def _reject(self, event):
        """
        Raise a schema error describing what was found and what was expected.

        :type event: yaml.Event
        :raises: errors.SchemaError
        """
        if self._pb_stack:
            _, pb_desc = self._pb_stack[-1]
            expected = self._get_human_readable_name_by_desc(pb_desc)
        elif self._result_pb is None:
            # No document has been started yet (e.g. the input is empty).
            expected = 'document'
        else:
            expected = 'no further data'
        got = self._get_human_readable_name_by_event(event)

        msg = '{} is not accepted here'.format(got)
        if expected:
            msg += ', {} expected'.format(expected)
        raise errors.SchemaError(msg)

    def _peek(self):
        """Return the top grammar symbol, or ``None`` if the stack is empty."""
        return self._stack[-1] if self._stack else None

    @staticmethod
    def _is_non_terminal(item):
        # isinstance instead of ``item in NonTerminals``: the latter raises
        # TypeError for non-members (e.g. None) on Python 3.8-3.11.
        return isinstance(item, NonTerminals)

    @staticmethod
    def _is_terminal(item):
        # Terminals are the yaml event classes themselves.
        return isinstance(item, type)

    @staticmethod
    def _is_action(item):
        # Actions are names of methods on this class.
        return isinstance(item, str)

    def _expand_non_terminal(self, event):
        """
        Replace the non-terminal on top of the stack with the rule that
        ``SELECT`` chooses for the exact class of ``event``.

        Does nothing if the top of the stack is not a non-terminal.

        :raises: errors.SchemaError if no rule accepts ``event`` here.
        """
        top = self._peek()
        if not self._is_non_terminal(top):
            return
        rule = SELECT[top].get(type(event))
        if rule is None:
            self._reject(event)
        self._stack.pop()
        # Push right-to-left so the rule's first symbol ends up on top.
        self._stack.extend(reversed(rule))

    def _run_action(self, event):
        """Pop and run the action on top of the stack, if there is one."""
        top = self._peek()
        if not self._is_action(top):
            return
        self._stack.pop()
        action = getattr(self, top)
        action(event)

    @staticmethod
    def _handle_aliases(event):
        """
        :raises: errors.YamlSyntaxError if the event is an alias or carries an anchor.
        """
        if isinstance(event, yaml.AliasEvent) or getattr(event, 'anchor', None) is not None:
            raise errors.YamlSyntaxError('anchors and aliases are not supported')

    def process(self, event):
        """
        Consume a single YAML parser event.

        :type event: yaml.Event
        :raises: errors.YamlSyntaxError on anchors/aliases;
                 errors.SchemaError when the event does not fit the schema.
        """
        self._handle_aliases(event)

        # Expand non-terminals and run actions until an event class is on top.
        while not self._is_terminal(self._peek()):
            if not self._stack:
                # The grammar is fully consumed (the stream already ended).
                # Without this check the loop could never make progress.
                self._reject(event)
            self._expand_non_terminal(event)
            self._run_action(event)

        expected_event_cls = self._stack.pop()
        if not isinstance(event, expected_event_cls):
            self._reject(event)

        # Run the actions that follow the matched terminal.
        while self._is_action(self._peek()):
            self._run_action(event)

    @classmethod
    def is_map_entry(cls, message_type):
        """Return whether ``message_type`` is a protobuf map entry (``map_entry`` option)."""
        return message_type.has_options and message_type.GetOptions().map_entry

    @classmethod
    def is_ordered_map_entry(cls, field_desc):
        """
        Hook: return whether a message descriptor should be treated as an
        ordered map entry (wrapped in ``util.OrderedMapProxy``).

        Despite its name, ``field_desc`` receives a message descriptor here.
        Always False in this class; presumably overridden by subclasses.
        """
        return False

    @classmethod
    def get_field_value(cls, message_pb, field_name):
        """Return ``message_pb.<field_name>``, or ``None`` if there is no such field."""
        fields = message_pb.DESCRIPTOR.fields_by_name
        if field_name in fields:
            return getattr(message_pb, field_name)

    @classmethod
    def get_field_desc(cls, message_pb, field_name):
        """Return the FieldDescriptor for ``field_name``, or ``None`` if there is no such field."""
        fields = message_pb.DESCRIPTOR.fields_by_name
        if field_name in fields:
            return fields[field_name]

    @classmethod
    def get_which_oneof(cls, message_pb, oneof_name):
        """Return the name of the field currently set in oneof ``oneof_name``, or ``None``."""
        return message_pb.WhichOneof(oneof_name)

    # ACTIONS:

    def _on_document_start(self, event):
        """
        Create the result message and make it the current message.

        :type event: yaml.DocumentStartEvent
        """
        self._result_pb = message_pb = self._pb_cls()
        self._pb_stack = [(message_pb, message_pb.DESCRIPTOR)]
        self._seen_fields_stack = [set()]

    def _handle_message(self, message_pb):
        """Expect a YAML mapping that fills the (sub-)message ``message_pb``."""
        # Mark the sub-message as present even if the mapping turns out empty.
        message_pb.SetInParent()
        self._stack.append(NonTerminals.MESSAGE)
        self._pb_stack.append((message_pb, message_pb.DESCRIPTOR))
        self._seen_fields_stack.append(set())

    def _handle_mapping(self, value, desc):
        """Expect a YAML mapping that fills the map container ``value`` (entry type ``desc``)."""
        self._stack.append(NonTerminals.MAPPING)
        self._pb_stack.append((value, desc))
        self._seen_fields_stack.append(set())

    def _handle_message_field(self, field_name, field_value, field_desc):
        """
        Push the grammar symbols for the value of a message field, according
        to its kind: map, repeated field, sub-message or scalar.
        """
        if field_desc.label == FD.LABEL_REPEATED and field_desc.type == FD.TYPE_MESSAGE:
            message_desc = field_desc.message_type
            if self.is_map_entry(message_desc):
                self._handle_mapping(field_value, message_desc)
            elif self.is_ordered_map_entry(message_desc):
                self._handle_mapping(util.OrderedMapProxy(message_desc, field_value), message_desc)
            else:
                self._stack.append(NonTerminals.SEQUENCE)
                self._pb_stack.append((field_value, field_desc))

        elif field_desc.label == FD.LABEL_REPEATED:
            self._stack.append(NonTerminals.SEQUENCE)
            self._pb_stack.append((field_value, field_desc))

        elif field_desc.type == FD.TYPE_MESSAGE:
            self._handle_message(field_value)

        else:
            self._pb_stack.append((field_name, field_desc))
            self._stack.extend(reversed(SCALAR_RULE_2))

    def _on_message_field_name(self, event):
        """
        Validate a field name and prepare to read its value.

        :type event: yaml.ScalarEvent
        :raises: errors.SchemaError on an unknown or duplicate field, or on a
                 second field of an already-set oneof.
        """
        message_pb, message_desc = self._pb_stack[-1]
        seen_fields = self._seen_fields_stack[-1]

        field_name = event.value
        field_desc = self.get_field_desc(message_pb, field_name)
        if not field_desc:
            raise errors.SchemaError('{} does not have field "{}"'.format(message_desc.name, field_name))

        if field_name in seen_fields:
            raise errors.SchemaError('duplicate field "{}"'.format(field_name))
        seen_fields.add(field_name)

        if field_desc.containing_oneof:
            conflicting_oneof_key = self.get_which_oneof(message_pb, field_desc.containing_oneof.name)
            if conflicting_oneof_key is not None:
                raise errors.SchemaError(
                    '"{}" and "{}" cannot both be present in {}'.format(
                        field_name, conflicting_oneof_key, message_desc.name))

        field_value = self.get_field_value(message_pb, field_name)
        self._handle_message_field(field_name, field_value, field_desc)

    def _on_mapping_key(self, event):
        """
        Parse a map key and prepare to read its value.

        :type event: yaml.ScalarEvent
        :raises: errors.SchemaError on an unparsable or duplicate key.
        """
        mapping_pb, mapping_desc = self._pb_stack[-1]
        seen_fields = self._seen_fields_stack[-1]

        key_desc = mapping_desc.fields_by_name['key']
        try:
            key = util.get_scalar_field_value(key_desc, event.value)
        except ValueError as e:
            raise errors.SchemaError('failed to parse key: {}'.format(e))

        if key in seen_fields:
            raise errors.SchemaError('duplicate key "{}"'.format(key))
        seen_fields.add(key)

        value_desc = mapping_desc.fields_by_name['value']
        if value_desc.type == FD.TYPE_MESSAGE:
            # Indexing the map container yields the message stored under ``key``.
            message_pb = mapping_pb[key]
            self._handle_message(message_pb)
        else:
            self._stack.extend(reversed(SCALAR_RULE_3))
            self._pb_stack.append((key, value_desc))

    def _on_sequence_item(self, event):
        """
        Start a new item of a repeated field.

        :type event: yaml.ScalarEvent | yaml.MappingStartEvent
        """
        sequence_pb, sequence_desc = self._pb_stack[-1]
        assert sequence_desc.label == FD.LABEL_REPEATED

        if sequence_desc.type == FD.TYPE_MESSAGE:
            message_pb = sequence_pb.add()
            self._handle_message(message_pb)
        else:
            self._stack.extend(reversed(SCALAR_RULE_1))

    # scalar handlers:

    def _on_scalar_sequence_item(self, event):
        """
        Append a scalar item to the current repeated field.

        :type event: yaml.ScalarEvent
        :raises: errors.SchemaError if the item cannot be parsed.
        """
        sequence_pb, sequence_desc = self._pb_stack[-1]
        assert sequence_desc.label == FD.LABEL_REPEATED

        try:
            value = util.get_scalar_field_value(sequence_desc, event.value)
        except ValueError as e:
            raise errors.SchemaError('failed to parse sequence item: {}'.format(e))
        else:
            sequence_pb.append(value)

    def _on_scalar_mapping_value(self, event):
        """
        Store a scalar value under the key pushed by ``_on_mapping_key``.

        :type event: yaml.ScalarEvent
        :raises: errors.SchemaError if the value cannot be parsed.
        """
        key, value_desc = self._pb_stack.pop()
        mapping_pb, mapping_desc = self._pb_stack[-1]
        assert self.is_map_entry(mapping_desc) or self.is_ordered_map_entry(mapping_desc)

        try:
            value = util.get_scalar_field_value(value_desc, event.value)
        except ValueError as e:
            # This parses the map *value*, not the key.
            raise errors.SchemaError('failed to parse value for key "{}": {}'.format(key, e))
        else:
            mapping_pb[key] = value

    def _on_scalar_message_field_value(self, event):
        """
        Set a singular scalar field of the current message.

        :type event: yaml.ScalarEvent
        :raises: errors.SchemaError if the value cannot be parsed.
        """
        field_name, field_desc = self._pb_stack.pop()  # type: str, descriptor.FieldDescriptor
        message_pb, message_desc = self._pb_stack[-1]
        assert isinstance(message_desc, descriptor.Descriptor)

        try:
            value = util.get_scalar_field_value(field_desc, event.value)
        except ValueError as e:
            raise errors.SchemaError('failed to parse "{}": {}'.format(field_name, e))
        else:
            if field_desc.containing_oneof and value == field_desc.default_value:
                # Please see
                # https://github.com/google/protobuf/blob/0400cca3236de1ca303af38bf81eab332d042b7c/python/google/protobuf/internal/message_test.py#L662-L676
                # https://github.com/google/protobuf/issues/491
                # We don't want to consider scalar fields equal to their defaults as "set" due to this quirk.
                # So let's just not set them and pass:
                pass
            else:
                setattr(message_pb, field_name, value)

    def _on_mapping_end(self, event):
        """
        Close the current map field.

        :type event: yaml.MappingEndEvent
        """
        _, mapping_desc = self._pb_stack.pop()
        self._seen_fields_stack.pop()
        assert self.is_map_entry(mapping_desc) or self.is_ordered_map_entry(mapping_desc)

    def _on_message_end(self, event):
        """
        Close the current message.

        :type event: yaml.MappingEndEvent
        """
        _, message_desc = self._pb_stack.pop()
        self._seen_fields_stack.pop()
        assert isinstance(message_desc, descriptor.Descriptor)

    def _on_sequence_end(self, event):
        """
        Close the current repeated field.

        :type event: yaml.SequenceEndEvent
        """
        _, sequence_desc = self._pb_stack.pop()
        assert sequence_desc.label == FD.LABEL_REPEATED
