import json
from functools import wraps
from typing import Optional, Type

from marshmallow import ValidationError, fields, post_load, pre_dump, validate

from maps_adv.common.protomallow import PbDateTimeField, ProtobufSchema
from maps_adv.warden.proto import tasks_pb2


def with_proto_schemas(
    input_schema: Optional[Type["ProtobufSchema"]] = None,
    output_schema: Optional[Type["ProtobufSchema"]] = None,
):
    """Wrap an async method so it takes and/or returns schema-encoded data.

    Both arguments are schema *classes*, not instances: a fresh schema object
    is created from them on every call of the wrapped method.

    Args:
        input_schema: when given, the wrapper requires its ``data`` argument,
            parses it with ``input_schema().from_bytes(data)`` and merges the
            result into the keyword arguments passed to the decorated
            function (parsed keys override caller-supplied keywords of the
            same name). When omitted, ``data`` is ignored.
        output_schema: when given, the awaited result of the decorated
            function is passed through ``output_schema().to_bytes`` and that
            value is returned.

    Returns:
        A decorator. The wrapper it produces returns the encoded result when
        ``output_schema`` is set and ``None`` otherwise; without an output
        schema the decorated function's own return value is discarded.

    Raises:
        RuntimeError: raised by the wrapper when ``input_schema`` is set but
            ``data`` is ``None``.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(s, data=None, **kwargs):
            if input_schema is not None:
                if data is None:
                    raise RuntimeError(
                        "This api_provider requires protobuf message input"
                    )
                loaded = input_schema().from_bytes(data)
                kwargs.update(loaded)

            got = await func(s, **kwargs)

            # No output schema means there is nothing to send back.
            if output_schema is None:
                return None
            return output_schema().to_bytes(got)

        return wrapper

    return decorator


class JsonMetadataMixin:
    """Schema mixin converting ``metadata`` between JSON text and Python data.

    The field is JSON text in its serialized form and decoded Python data
    after load. A missing or ``None`` ``metadata`` is left untouched in both
    directions.
    """

    @post_load
    def metadata_load(self, data):
        """Decode the ``metadata`` JSON string of freshly loaded data.

        Returns:
            ``data`` with ``metadata`` replaced by the decoded value.

        Raises:
            ValidationError: if ``metadata`` is not valid JSON.
        """
        raw = data.get("metadata")
        if raw is not None:
            try:
                data["metadata"] = json.loads(raw)
            except json.JSONDecodeError as exc:
                raise ValidationError(
                    "Invalid metadata field", field_names=["metadata"]
                ) from exc
        return data

    @pre_dump
    def metadata_dump(self, data):
        """Encode ``metadata`` as a JSON string before the data is dumped.

        Returns:
            ``data`` itself when there is no metadata to encode, otherwise a
            shallow copy whose ``metadata`` is the JSON-encoded string.
        """
        value = data.get("metadata")
        if value is None:
            return data
        # Work on a copy: encoding in place would alter the caller's object
        # and double-encode the value if the same object were dumped again.
        encoded = dict(data)
        encoded["metadata"] = json.dumps(value)
        return encoded


class CreateTaskInputProtoSchema(ProtobufSchema, JsonMetadataMixin):
    """Schema for the ``tasks_pb2.CreateTaskInput`` protobuf message."""

    class Meta:
        pb_message_class = tasks_pb2.CreateTaskInput

    # Both identifier strings must contain at least one character.
    type_name = fields.String(
        validate=validate.Length(
            min=1,
            error="Value should not be empty.",
        ),
    )
    executor_id = fields.String(
        validate=validate.Length(
            min=1,
            error="Value should not be empty.",
        ),
    )
    # Optional and nullable; declared as a string here, with the JSON
    # conversion left to the JsonMetadataMixin hooks.
    metadata = fields.String(
        allow_none=True,
        required=False,
    )


class TaskDetailsProtoSchema(ProtobufSchema, JsonMetadataMixin):
    """Schema for the ``tasks_pb2.TaskDetails`` protobuf message."""

    class Meta:
        pb_message_class = tasks_pb2.TaskDetails

    # Integer fields.
    task_id = fields.Integer()
    status = fields.String()
    time_limit = fields.Integer()
    # Optional and nullable; declared as a string here, with the JSON
    # conversion left to the JsonMetadataMixin hooks.
    metadata = fields.String(
        allow_none=True,
        required=False,
    )


class UpdateTaskInputProtoSchema(ProtobufSchema, JsonMetadataMixin):
    """Schema for the ``tasks_pb2.UpdateTaskInput`` protobuf message."""

    class Meta:
        pb_message_class = tasks_pb2.UpdateTaskInput

    # Every string field below except ``metadata`` must contain at least one
    # character.
    type_name = fields.String(
        validate=validate.Length(
            min=1,
            error="Value should not be empty.",
        ),
    )
    task_id = fields.Integer()
    executor_id = fields.String(
        validate=validate.Length(
            min=1,
            error="Value should not be empty.",
        ),
    )
    status = fields.String(
        validate=validate.Length(
            min=1,
            error="Value should not be empty.",
        ),
    )
    # Nullable; declared as a string here, with the JSON conversion left to
    # the JsonMetadataMixin hooks.
    metadata = fields.String(
        allow_none=True,
    )


class UpdateTaskOutputProtoSchema(ProtobufSchema):
    """Schema for the ``tasks_pb2.UpdateTaskOutput`` protobuf message."""

    class Meta:
        pb_message_class = tasks_pb2.UpdateTaskOutput

    # Not required: the field may be left out.
    scheduled_time = PbDateTimeField(
        required=False,
    )
