import base64
import math
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Union

from marshmallow import fields, validate
from marshmallow_enum import EnumField

from sendr_utils.schemas.validators import DecimalExponentValidator


def _add_validators(
    *extra_validators: 'validate.Validator',
    kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Return field ``kwargs`` with ``extra_validators`` appended to ``validate``.

    ``kwargs['validate']`` may be absent, ``None``, a single callable (such as a
    ``marshmallow.validate.Validator`` instance) or a list/tuple of callables.
    These are the same forms marshmallow's ``Field`` accepts. Existing validators
    keep their order and the extra ones run after them.

    The caller's validator list is never mutated; a fresh list is stored in
    ``kwargs['validate']``. A module-level list of validators shared between
    several fields therefore does not accumulate duplicates.

    :param extra_validators: validators to append; if none are given, ``kwargs``
        is returned untouched.
    :param kwargs: field keyword arguments; updated in place and returned.
    :raises ValueError: if ``kwargs['validate']`` has an unsupported type.
    """
    if not extra_validators:
        return kwargs

    existing = kwargs.get('validate')
    if existing is None:
        current = []
    elif callable(existing):
        current = [existing]
    elif isinstance(existing, (list, tuple)):
        # Copy so the caller's (possibly shared) sequence is left intact.
        current = list(existing)
    else:
        raise ValueError(
            f"'validate' must be a callable or a list/tuple of callables, got {type(existing).__name__}"
        )

    kwargs['validate'] = [*current, *extra_validators]
    return kwargs


class EpochTimestamp(fields.Field):
    """Datetime field (de)serialized as a Unix epoch timestamp.

    Serialization emits an integer number of seconds, or of milliseconds when
    ``in_milliseconds`` is set. It is rendered as a string when ``as_string``
    is set. Deserialization accepts anything ``float()`` accepts and returns a
    timezone-aware ``datetime`` in UTC.

    NOTE(review): naive datetimes are serialized as *local* time (the same
    convention as ``datetime.timestamp()``). Confirm callers always pass aware
    values.
    """

    default_error_messages = {
        'invalid': 'Not a valid number.',
        'too_large': 'Number too large.',
    }

    # Reference point for exact integer arithmetic in ``_serialize``.
    _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)

    def __init__(self, as_string: bool = False, in_milliseconds: bool = False, **kwargs: Any):
        super().__init__(**kwargs)
        self.as_string = as_string
        self.in_milliseconds = in_milliseconds

    def _serialize(self, value: Optional[datetime], attr: Any, obj: Any) -> Optional[Union[int, str]]:
        """Convert ``value`` to an integer epoch timestamp (or its string form).

        Sub-unit precision is truncated toward zero.
        """
        if value is None:
            return None

        if value.utcoffset() is None:
            # Naive datetime: interpret it in local time, as datetime.timestamp() does.
            value = value.astimezone(timezone.utc)
        delta = value - self._EPOCH

        # Exact integer microseconds. The float form ``timestamp() * 1000`` can
        # round to just below an integer and then truncate one unit too low.
        micros = (delta.days * 86400 + delta.seconds) * 1000000 + delta.microseconds
        unit = 1000 if self.in_milliseconds else 1000000
        # Truncate toward zero (the same rounding direction as int() on a float).
        whole = abs(micros) // unit
        timestamp = whole if micros >= 0 else -whole

        if self.as_string:
            return str(timestamp)
        return timestamp

    def _validated(self, value: Any) -> float:  # type: ignore
        """Return ``value`` as a finite float or fail with a validation error.

        Unparsable values and NaN fail with ``invalid``; values too large for a
        float, and infinities, fail with ``too_large``.
        """
        try:
            number = float(value)
        except (TypeError, ValueError):
            self.fail('invalid')
        except OverflowError:
            self.fail('too_large')

        if math.isnan(number):
            self.fail('invalid')
        if math.isinf(number):
            self.fail('too_large')
        return number

    def _deserialize(self, value: Optional[Union[int, str, float]], *args: Any, **kwargs: Any) -> Optional[datetime]:
        """Parse an epoch timestamp into an aware UTC ``datetime``.

        Timestamps outside the range ``datetime``/the platform supports fail
        with ``too_large`` (also used for very negative values) instead of
        raising a raw error.
        """
        if value is None:
            return None

        timestamp = self._validated(value)
        if self.in_milliseconds:
            timestamp /= 1000

        try:
            return datetime.fromtimestamp(timestamp, tz=timezone.utc)
        except (OverflowError, ValueError, OSError):
            self.fail('too_large')


class CurrencyField(fields.String):
    """String field holding an ISO 4217 alphabetic currency code (e.g. ``RUB``).

    The value must be exactly three uppercase ASCII letters. Validators passed
    via ``validate`` run before the format check.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('description', 'ISO 4217 alpha code. E.g. RUB, USD, XTS')

        # In Python, ``$`` also matches just before a trailing "\n", so a bare
        # ``^[A-Z]{3}$`` accepts "USD\n". The ``(?!\n)`` lookahead closes that
        # gap. It stays valid ECMA-262 syntax (``\Z`` is not), in case the
        # pattern is exported into an API schema.
        kwargs = _add_validators(
            validate.Regexp(r'^[A-Z]{3}(?!\n)$', error='Not a valid ISO 4217 alpha code.'),
            kwargs=kwargs,
        )

        super().__init__(*args, **kwargs)


class AmountField(fields.Decimal):
    """Decimal money amount field, serialized as a string by default.

    :param format_enforced: when true, add ``DecimalExponentValidator(exponent=2)``.
        Presumably this rejects values with more than two fractional digits, as
        the default description states; confirm against ``sendr_utils``.
    :param as_string: forwarded to ``fields.Decimal``; defaults to ``True`` here,
        unlike the base class.
    """

    def __init__(self, *args: Any, format_enforced: bool = False, as_string: bool = True, **kwargs: Any):
        # Default API description (Russian). In English: "Must not contain more
        # than two digits after the decimal point. E.g.: 1.12, 5.1, 10, 11.00 ."
        description = 'Не должно содержать больше двух знаков после запятой.\nНапример: 1.12, 5.1, 10, 11.00 .'
        kwargs.setdefault('description', description)

        if format_enforced:
            kwargs = _add_validators(
                DecimalExponentValidator(exponent=2),
                kwargs=kwargs,
            )

        # Deserialization is inherited unchanged from ``fields.Decimal``; the former
        # ``_deserialize`` override only delegated to ``super()`` and was removed.
        super().__init__(*args, as_string=as_string, **kwargs)


class CIEnumField(EnumField):
    """``EnumField`` whose by-value lookup is case-insensitive (``str.casefold``)."""

    def _deserialize_by_value(self, value, attr, data):
        """Return the first enum member whose string value casefold-equals ``value``.

        Non-string input (e.g. a JSON number) cannot match and fails with the
        field's ``by_value`` error instead of raising ``AttributeError``.
        Members with non-string values are skipped.
        """
        if isinstance(value, str):
            needle = value.casefold()  # computed once, not per member
            for member in self.enum:
                if isinstance(member.value, str) and member.value.casefold() == needle:
                    return member

        self.fail('by_value', input=value, value=value)


class BytesField(fields.Field):
    """Binary field (de)serialized as a standard base64 string."""

    default_error_messages = {
        'invalid': 'Not a valid base64 string.',
    }

    def _serialize(self, value: Optional[bytes], attr: Any, obj: Any) -> Optional[str]:
        """Encode ``value`` as standard-alphabet base64 text."""
        if value is None:
            return None

        return base64.b64encode(value).decode('ascii')

    def _deserialize(self, value: Optional[str], *args: Any, **kwargs: Any) -> Optional[bytes]:
        """Decode base64 input, failing with ``invalid`` on malformed input.

        NOTE(review): ``b64decode`` is called without ``validate=True``, so
        characters outside the base64 alphabet are silently discarded rather
        than rejected.
        """
        if value is None:
            return None

        try:
            return base64.b64decode(value)
        except (TypeError, ValueError):
            # ValueError covers binascii.Error (bad padding) and non-ASCII str;
            # TypeError covers non-string input such as a JSON number.
            # self.fail() raises the ValidationError itself.
            self.fail('invalid')
