from functools import wraps
from typing import Optional, Type

from marshmallow import ValidationError

from maps_adv.common.protomallow import ProtobufSchema


class OutputValidationError(Exception):
    """Error wrapping a ``ValidationError`` raised while validating output.

    The wrapped exception is kept on ``inner_exception`` so handlers can
    inspect the original validation failure.
    """

    def __init__(self, inner_exception: ValidationError):
        # Explicitly initialise the base class; ``args`` ends up holding the
        # wrapped exception, same as the implicit behaviour of ``Exception``.
        super().__init__(inner_exception)
        self.inner_exception = inner_exception


def with_schemas(
    input_schema: Optional[Type[ProtobufSchema]] = None,
    output_schema: Optional[Type[ProtobufSchema]] = None,
):
    """Decorate an async callable to decode its input and encode its result.

    When ``input_schema`` is given, the wrapper requires a ``data`` keyword
    argument, passes it to ``input_schema().from_bytes`` and merges the
    result into the keyword arguments of the wrapped callable (decoded
    values override same-named keyword arguments supplied by the caller).

    When ``output_schema`` is given, the awaited result of the wrapped
    callable is passed to ``output_schema().to_bytes`` and that value is
    returned instead of the raw result.

    NOTE: the ``data`` keyword is always consumed by the wrapper and is never
    forwarded to the wrapped callable, even when ``input_schema`` is ``None``.

    Args:
        input_schema: Schema class used to decode ``data``; ``None`` disables
            input decoding.
        output_schema: Schema class used to encode the result; ``None``
            returns the result unchanged.

    Returns:
        A decorator producing an async wrapper around the given callable.

    Raises (from the wrapper):
        RuntimeError: ``input_schema`` is set but ``data`` was not provided.
        OutputValidationError: ``to_bytes`` raised ``ValidationError``; the
            original error is available as ``inner_exception`` and as
            ``__cause__``.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, data: Optional[bytes] = None, **kwargs):
            if input_schema is not None:
                # Only a missing payload is rejected; empty bytes are passed
                # on to the schema.
                if data is None:
                    raise RuntimeError("Provide protobuf message input")
                # presumably from_bytes returns a mapping of field names to
                # values -- confirm against ProtobufSchema.
                kwargs.update(input_schema().from_bytes(data))

            result = await func(*args, **kwargs)
            if output_schema is None:
                return result

            try:
                return output_schema().to_bytes(result)
            except ValidationError as exc:
                # Chain explicitly so the traceback shows the validation
                # failure as the direct cause of the wrapper error.
                raise OutputValidationError(exc) from exc

        return wrapper

    return decorator
