from collections.abc import Mapping

from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import DecodeError, EncodeError, Message
from google.protobuf.pyext._message import (
    MessageMapContainer,
    RepeatedCompositeContainer,
    RepeatedScalarContainer,
    ScalarMapContainer,
)
from marshmallow import MarshalResult, Schema, ValidationError
from marshmallow.schema import SchemaMeta


class ProtobufSchemaMeta(SchemaMeta):
    """Schema metaclass that exposes ``Meta.pb_message_class`` as ``_pb_message_class``.

    The attribute is only (re)assigned when the class being created declares a
    ``Meta`` that has ``pb_message_class``. Otherwise it is inherited from the
    bases, so a subclass that adds fields without restating ``Meta`` keeps its
    parent's message class. At the root of the hierarchy (no base defines it)
    it defaults to ``None``.
    """

    def __new__(mcs, name, bases, attrs):
        meta = attrs.get("Meta")
        if meta is not None and hasattr(meta, "pb_message_class"):
            attrs["_pb_message_class"] = meta.pb_message_class
        elif not any(hasattr(base, "_pb_message_class") for base in bases):
            # Nothing to inherit from: keep the attribute defined so lookups on
            # the root class see ``None`` rather than raising AttributeError.
            attrs.setdefault("_pb_message_class", None)
        return super().__new__(mcs, name, bases, attrs)


class StrictSchema(Schema):
    """Schema that is always constructed with ``strict=True``.

    Callers cannot opt out: any ``strict`` keyword they pass is replaced
    with ``True`` before delegating to ``Schema.__init__``.
    """

    def __init__(self, *args, **kwargs):
        forced_options = dict(kwargs, strict=True)
        super().__init__(*args, **forced_options)


class ProtobufSchema(StrictSchema, metaclass=ProtobufSchemaMeta):
    """Strict marshmallow schema that loads from and dumps to protobuf messages.

    The message type comes from ``Meta.pb_message_class`` (exposed as
    ``_pb_message_class`` by ``ProtobufSchemaMeta``). ``load`` accepts message
    instances and ``dump`` produces message instances. ``from_bytes`` and
    ``to_bytes`` additionally handle the serialized form of a single message.
    """

    def from_bytes(self, encoded_message):
        """Decode one serialized message and load it through the schema.

        :param encoded_message: serialized bytes of a single
            ``_pb_message_class`` message.
        :return: the loaded data (``load(...).data``).
        :raises ValidationError: if the bytes cannot be decoded, or if the
            decoded message is rejected by ``load``.
        """
        message_class = self._pb_message_class
        try:
            message = message_class.FromString(encoded_message)
        except DecodeError as err:
            raise ValidationError(
                "Failed to decode proto data as {}".format(
                    message_class.DESCRIPTOR.full_name
                )
            ) from err

        # The bytes always hold exactly one message, so never treat it as a
        # collection, even on a schema constructed with ``many=True``.
        return self.load(message, many=False).data

    def load(self, message, **kwargs):
        """Check protobuf message(s) and load them through marshmallow.

        :param message: a ``_pb_message_class`` instance, or an iterable of
            them when the effective ``many`` flag is true.
        :param kwargs: forwarded unchanged to ``Schema.load`` (e.g. ``many``,
            ``partial``). An explicit ``many`` overrides ``self.many``.
        :raises ValidationError: if a message has the wrong type or
            ``IsInitialized()`` is false for it.
        """
        if self._effective_many(kwargs.get("many")):
            result = [self._decode_message(item) for item in message]
        else:
            result = self._decode_message(message)

        return super().load(result, **kwargs)

    def _effective_many(self, many):
        """Return the ``many`` flag to use: an explicit value wins over ``self.many``."""
        return self.many if many is None else bool(many)

    def _decode_message(self, message):
        """Validate a single message and convert it to a field dict.

        :raises ValidationError: if ``message`` is not a ``_pb_message_class``
            instance or is not fully initialized.
        """
        if not isinstance(message, self._pb_message_class):
            raise ValidationError(
                'Schema expected "{}", not "{}"'.format(
                    self._pb_message_class, type(message)
                )
            )

        # In case of partial data (e.g. proto2 required fields left unset).
        if not message.IsInitialized():
            raise ValidationError(
                "Failed to decode all proto data as valid {}".format(
                    self._pb_message_class.DESCRIPTOR.full_name
                )
            )

        return self.deserialize_as_dict(message)

    def to_bytes(self, data):
        """Dump ``data`` into a single message and return its serialized bytes.

        :raises ValidationError: if dumping fails, a required field is missing,
            or ``SerializeToString`` raises ``EncodeError``.
        """
        # Exactly one message is encoded, mirroring ``from_bytes``.
        message = self.dump(data, many=False).data

        try:
            return message.SerializeToString()
        except EncodeError as err:
            raise ValidationError(
                "Failed to encode data as {}".format(
                    self._pb_message_class.DESCRIPTOR.full_name
                )
            ) from err

    def dump(self, data, **kwargs):
        """Dump ``data`` through marshmallow and build protobuf message(s).

        :param kwargs: forwarded unchanged to ``Schema.dump`` (e.g. ``many``,
            ``update_fields``). An explicit ``many`` overrides ``self.many``.
        :return: ``MarshalResult`` whose ``data`` is a message, or a list of
            messages when the effective ``many`` flag is true.
        """
        dumped_data = super().dump(data, **kwargs).data
        if self._effective_many(kwargs.get("many")):
            result = [self.serialize_from_dict(item) for item in dumped_data]
        else:
            result = self.serialize_from_dict(dumped_data)

        # Errors are reported as empty; StrictSchema forces ``strict=True``.
        return MarshalResult(result, {})

    @classmethod
    def deserialize_as_dict(cls, proto_message):
        """Return ``{field name: value}`` for the fields present on ``proto_message``.

        For proto2 messages, unset singular fields are omitted. Repeated/map
        fields and all fields of non-proto2 messages are always included
        (see ``_message_has_field``).
        """
        return {
            field.name: cls._get_proto_field(proto_message, field.name)
            for field in cls._iter_message_fields(proto_message)
            if cls._message_has_field(proto_message, field)
        }

    @classmethod
    def serialize_from_dict(cls, data):
        """Build a new ``_pb_message_class`` instance from a field dict.

        Keys whose value is missing or ``None`` are skipped.

        :raises ValidationError: if such a key belongs to a required field.
        """
        message = cls._pb_message_class()
        for field in cls._iter_message_fields(message):
            proto_value = data.get(field.name)
            if proto_value is None:
                if field.label != FieldDescriptor.LABEL_REQUIRED:
                    continue
                raise ValidationError(
                    'Required field "{}" not found in input'.format(field.name)
                )

            cls._set_proto_field(message, field.name, proto_value)

        return message

    @classmethod
    def _get_proto_field(cls, message, field_name):
        """Read a field, copying repeated fields into a plain ``list`` or ``dict``."""
        field = message.DESCRIPTOR.fields_by_name[field_name]
        proto_value = getattr(message, field_name)

        if field.label == FieldDescriptor.LABEL_REPEATED:
            # Map fields are also labelled repeated; return them as dicts.
            if isinstance(proto_value, (ScalarMapContainer, MessageMapContainer)):
                return dict(proto_value)
            return list(proto_value)

        return proto_value

    @classmethod
    def _set_proto_field(cls, message, field_name, value):
        """Assign ``value`` to ``message.<field_name>``, recursing into sub-messages.

        Sub-message values may be given as messages (copied) or as mappings of
        sub-field names to values (applied recursively). This also holds for
        the elements of repeated message fields and the values of message maps.
        """
        message_field = getattr(message, field_name)
        if isinstance(value, Message):
            message_field.CopyFrom(value)
        elif isinstance(message_field, ScalarMapContainer):
            message_field.update(value)
        elif isinstance(message_field, MessageMapContainer):
            for key, item in value.items():
                # Indexing a message map creates the entry if it is missing.
                cls._fill_message(message_field[key], item)
        elif isinstance(value, Mapping):
            if isinstance(message_field, Message):
                # Mark the sub-message as present even when ``value`` is empty,
                # so an explicit ``{}`` is not silently dropped.
                message_field.SetInParent()
            for subfield_name, subfield_value in value.items():
                cls._set_proto_field(message_field, subfield_name, subfield_value)
        elif isinstance(
            value, (list, tuple, RepeatedScalarContainer, RepeatedCompositeContainer)
        ):
            if isinstance(message_field, RepeatedCompositeContainer):
                # ``extend`` only accepts messages; go element by element so
                # mapping items work too.
                for item in value:
                    cls._fill_message(message_field.add(), item)
            else:
                message_field.extend(value)
        else:
            setattr(message, field_name, value)

    @classmethod
    def _fill_message(cls, target, value):
        """Populate the sub-message ``target`` from a mapping (recursively) or copy a message into it."""
        if isinstance(value, Mapping):
            for subfield_name, subfield_value in value.items():
                cls._set_proto_field(target, subfield_name, subfield_value)
        else:
            target.CopyFrom(value)

    @staticmethod
    def _iter_message_fields(message):
        """Yield the field descriptors of ``message``'s type."""
        yield from message.DESCRIPTOR.fields

    @staticmethod
    def _message_has_field(message, field):
        """Return whether ``field`` should be read from ``message``.

        Only proto2 singular fields are filtered (via ``HasField``). Repeated
        fields, and every field of a non-proto2 message, count as present.
        """
        return (
            message.DESCRIPTOR.syntax != "proto2"
            or field.label == FieldDescriptor.LABEL_REPEATED
            or message.HasField(field.name)
        )


# Only the schema base class is exported by ``from ... import *``;
# ProtobufSchemaMeta and StrictSchema are not listed.
__all__ = [
    "ProtobufSchema",
]
