import base64
import binascii
import json
from typing import Any, ClassVar, Dict, NoReturn, Type

from marshmallow import ValidationError, fields, post_dump, post_load, pre_load, validate

from mail.payments.payments.api.schemas.base import BaseSchema
from mail.payments.payments.core.entities.keyset import BaseKeysetEntity, KeysetEntry
from mail.payments.payments.core.exceptions import KeysetInvalidError


class BaseKeysetEntrySchema(BaseSchema):
    """Schema for a single keyset entry.

    An entry pairs a ``barrier`` value with the sort ``order`` it is used with.
    Here ``barrier`` accepts any value (``fields.Raw``); subclasses replace it
    with a concrete field type. The ``fields.Field`` annotation lets those
    overrides type-check against the base declaration.
    """

    barrier: fields.Field = fields.Raw(required=True)
    # Only the two ordering directions are accepted.
    order = fields.String(
        validate=validate.OneOf(['asc', 'desc']),
        required=True,
    )


class KeysetEntryDateTimeSchema(BaseKeysetEntrySchema):
    """Keyset entry whose ``barrier`` is a datetime."""

    barrier = fields.DateTime(required=True)


class KeysetEntryIntegerSchema(BaseKeysetEntrySchema):
    """Keyset entry whose ``barrier`` is an integer."""

    barrier = fields.Integer(required=True)


class KeysetEntryDecimalSchema(BaseKeysetEntrySchema):
    """Keyset entry whose ``barrier`` is a decimal number."""

    # as_string=True dumps the value as a string instead of a Decimal, keeping
    # the dumped data JSON-serializable (plain json.dumps rejects Decimal).
    barrier = fields.Decimal(as_string=True, required=True)


class BaseKeysetEntitySchema(BaseSchema):
    """Base schema for an opaque, base64-wrapped JSON keyset token.

    Dumping JSON-encodes the serialized fields and wraps them in base64
    (``to_json``). Loading reverses that (``parse_json``), validates the
    fields, and builds an instance of ``KEYSET_ENTITY_CLS`` (``to_dataclass``).
    Every loaded field other than ``sort_order`` is turned into a
    ``KeysetEntry``. Subclasses presumably declare those fields as nested entry
    schemas (TODO confirm against subclasses).

    Any malformed token is reported as ``ValidationError('Invalid keyset')``.
    """

    # Entity class built by ``to_dataclass``; concrete subclasses must set it.
    KEYSET_ENTITY_CLS: ClassVar[Type[BaseKeysetEntity]]

    sort_order = fields.List(fields.String(), required=True)

    def _raise_invalid_keyset(self) -> NoReturn:
        """Raise the single validation error used for every malformed keyset."""
        raise ValidationError('Invalid keyset')

    @pre_load
    def parse_json(self, data: str) -> Dict[str, Any]:
        """Decode a base64-wrapped JSON token into a plain dict.

        :param data: the raw keyset token.
        :returns: the decoded JSON object.
        :raises ValidationError: if ``data`` is not a base64 string/bytes, does
            not decode to UTF-8 JSON, or the JSON is not an object.
        """
        try:
            parsed = json.loads(base64.b64decode(data).decode('utf-8'))
        except (binascii.Error, json.JSONDecodeError, ValueError, TypeError):
            # Several failures are plain ValueError rather than the two named
            # exceptions above:
            #   - UnicodeDecodeError, for a payload that is not UTF-8;
            #   - the error b64decode raises for a non-ASCII str.
            # TypeError comes from input that is neither str nor bytes (e.g. None).
            self._raise_invalid_keyset()
        # A valid JSON scalar or array is still not a keyset.
        if not isinstance(parsed, dict):
            self._raise_invalid_keyset()
        return parsed

    @post_load
    def to_dataclass(self, data: Dict[str, Any], *args: Any, **kwargs: Any) -> BaseKeysetEntity:
        """Build ``KEYSET_ENTITY_CLS`` from the validated fields.

        ``sort_order`` is passed through unchanged. Every other field's loaded
        value is expanded into ``KeysetEntry(**value)``.

        :raises ValidationError: if the entity rejects the combination
            (``KeysetInvalidError``).
        """
        entity_kwargs = {
            key: value if key == 'sort_order' else KeysetEntry(**value)
            for key, value in data.items()
        }
        try:
            return self.KEYSET_ENTITY_CLS(**entity_kwargs)
        except KeysetInvalidError:
            self._raise_invalid_keyset()

    @post_dump
    def to_json(self, data: Dict[str, Any]) -> str:
        """Encode the dumped fields as JSON, then base64, returning an ASCII str."""
        return base64.b64encode(json.dumps(data).encode('utf-8')).decode('ascii')
