from typing import Any, Callable, Type, Union

import ujson
from marshmallow import Schema, ValidationError

from .exceptions import APIException, InvalidRequest, InvalidResponse


class BaseSchema(Schema):
    """Marshmallow schema that reports validation problems as return values.

    ``load``/``loads``/``dump``/``dumps`` are overridden to always return a
    ``(data, errors)`` pair. This holds whether the underlying marshmallow call
    raises ``ValidationError`` or itself returns a ``(data, errors)`` named
    tuple. ``errors`` is falsy (``None`` or empty) when nothing went wrong.

    The ``*_raising`` variants return only ``data`` and raise an
    ``APIException`` built from ``on_error`` when errors were reported.
    """

    class Meta:
        # Module used by ``loads``/``dumps`` for JSON (de)serialisation.
        json_module = ujson

    @staticmethod
    def _adapt_run(func: Callable, args: Any, kwargs: Any) -> tuple:
        """Call a marshmallow (de)serialisation method and normalise its outcome.

        :param func: The method to run, e.g. ``super().load``.
        :param args: Positional arguments forwarded to ``func``.
        :param kwargs: Keyword arguments forwarded to ``func``.
        :return: ``(data, errors)``. On ``ValidationError``, ``data`` is the
            exception's ``valid_data`` (``None`` if it has none) and ``errors``
            is ``exc.normalized_messages()``. On success, ``errors`` is ``None``
            unless ``func`` itself returned a ``(data, errors)`` result.
        """
        try:
            result = func(*args, **kwargs)
        except ValidationError as exc:
            # ``valid_data`` carries the successfully validated subset; fall back
            # to None for ValidationError versions that lack the attribute rather
            # than masking the validation failure with an AttributeError.
            return getattr(exc, 'valid_data', None), exc.normalized_messages()

        # Legacy marshmallow returns MarshalResult/UnmarshalResult named tuples
        # with fields (data, errors); pass those through unchanged. Match on the
        # field names, not on ``tuple`` alone, because a post_load hook may
        # legitimately return a tuple (e.g. a NamedTuple model), which is data.
        if isinstance(result, tuple) and getattr(result, '_fields', None) == ('data', 'errors'):
            return result
        return result, None

    def _run_raising(self, func: Callable, args: Any, kwargs: Any,
                     on_error: Union[Type[APIException], APIException]) -> Any:
        """Run ``func`` through ``_adapt_run`` and raise if it reported errors.

        :param func: The marshmallow method to run, e.g. ``super().load``.
        :param args: Positional arguments forwarded to ``func``.
        :param kwargs: Keyword arguments forwarded to ``func``.
        :param on_error: Called as ``on_error(params={'reason': errors})`` to
            build the exception to raise.
            NOTE(review): the annotation also admits an APIException instance;
            that only works if instances are callable, so confirm.
        :return: The ``data`` part of the result when there were no errors.
        :raises APIException: Whatever ``on_error(...)`` builds, when errors
            are truthy.
        """
        data, errors = self._adapt_run(func, args, kwargs)
        if errors:
            raise on_error(params={'reason': errors})  # type: ignore
        return data

    def load_raising(self, *args: Any,
                     on_error: Union[Type[APIException], APIException] = InvalidRequest,
                     **kwargs: Any) -> Any:
        """Deserialise like ``Schema.load``; raise ``on_error`` on errors.

        :param on_error: Exception factory called as
            ``on_error(params={'reason': errors})``; defaults to
            ``InvalidRequest``.
        :return: The loaded data.
        """
        return self._run_raising(super().load, args, kwargs, on_error)

    def loads_raising(self, *args: Any,
                      on_error: Union[Type[APIException], APIException] = InvalidRequest,
                      **kwargs: Any) -> Any:
        """Deserialise like ``Schema.loads``; raise ``on_error`` on errors.

        :param on_error: Exception factory called as
            ``on_error(params={'reason': errors})``; defaults to
            ``InvalidRequest``.
        :return: The loaded data.
        """
        return self._run_raising(super().loads, args, kwargs, on_error)

    def dump_raising(self, *args: Any,
                     on_error: Union[Type[APIException], APIException] = InvalidResponse,
                     **kwargs: Any) -> Any:
        """Serialise like ``Schema.dump``; raise ``on_error`` on errors.

        :param on_error: Exception factory called as
            ``on_error(params={'reason': errors})``; defaults to
            ``InvalidResponse``.
        :return: The dumped data.
        """
        return self._run_raising(super().dump, args, kwargs, on_error)

    def dumps_raising(self, *args: Any,
                      on_error: Union[Type[APIException], APIException] = InvalidResponse,
                      **kwargs: Any) -> Any:
        """Serialise like ``Schema.dumps``; raise ``on_error`` on errors.

        :param on_error: Exception factory called as
            ``on_error(params={'reason': errors})``; defaults to
            ``InvalidResponse``.
        :return: The dumped string.
        """
        return self._run_raising(super().dumps, args, kwargs, on_error)

    def load(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``Schema.load`` and return ``(data, errors)`` instead of raising."""
        return self._adapt_run(super().load, args, kwargs)

    def dump(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``Schema.dump`` and return ``(data, errors)`` instead of raising."""
        return self._adapt_run(super().dump, args, kwargs)

    def loads(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``Schema.loads`` and return ``(data, errors)`` instead of raising."""
        return self._adapt_run(super().loads, args, kwargs)

    def dumps(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``Schema.dumps`` and return ``(data, errors)`` instead of raising."""
        return self._adapt_run(super().dumps, args, kwargs)
