from functools import wraps
from typing import Optional, Type

from marshmallow import Schema

from maps_adv.common.protomallow import ProtobufSchema

# Public API of this module (names exported by ``from ... import *``).
__all__ = ["with_schemas", "StrictSchema"]


def with_schemas(
    input_schema: Optional[Type["ProtobufSchema"]] = None,
    output_schema: Optional[Type["ProtobufSchema"]] = None,
):
    """Decorate an async method to deserialize its input and serialize its result.

    The decorated coroutine is invoked as ``await method(self, data=None, **kwargs)``:

    * If ``input_schema`` is given, ``data`` is deserialized with
      ``input_schema().from_bytes(data)``. The result is merged into ``kwargs``
      before the original method is awaited.
    * If ``output_schema`` is given, the method's return value is serialized
      with ``output_schema().to_bytes(...)`` and those bytes are returned.

    Either schema may be omitted. That direction is then passed through unchanged.

    Args:
        input_schema: Schema *class* (instantiated on every call) used to decode
            the incoming ``data`` bytes, or ``None`` to skip decoding.
        output_schema: Schema *class* (instantiated on every call) used to
            encode the return value, or ``None`` to return it as-is.

    Returns:
        A decorator for ``async def`` methods.

    Raises:
        RuntimeError: raised by the decorated call when ``input_schema`` is set
            but ``data`` was not supplied.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(s, data=None, **kwargs):
            if input_schema is not None:
                # Identity check on purpose: b"" is a valid (empty) serialized
                # message and must not be treated as "missing".
                if data is None:
                    raise RuntimeError(
                        "This api_provider requires protobuf message input"
                    )
                loaded = input_schema().from_bytes(data)
                # Decoded fields override any same-named keyword arguments.
                kwargs.update(loaded)
            # NOTE: without an input_schema, ``data`` is accepted but discarded.
            result = await func(s, **kwargs)
            if output_schema is not None:
                return output_schema().to_bytes(result)

            return result

        return wrapper

    return decorator


class StrictSchema(Schema):
    """Marshmallow ``Schema`` that is always constructed with ``strict=True``.

    A ``strict`` value passed by the caller is overridden. All other positional
    and keyword arguments are forwarded to ``Schema`` unchanged.

    NOTE: the ``strict`` constructor argument exists in marshmallow 2.x. There it
    makes ``load``/``dump`` raise ``ValidationError`` instead of returning the
    errors. Marshmallow 3 removed the argument, so this class assumes 2.x.
    """

    def __init__(self, *args, **kwargs):
        forced_options = dict(kwargs, strict=True)
        super().__init__(*args, **forced_options)
