import datetime
import decimal
from decimal import Decimal
from enum import Enum
from typing import Iterable, Type, TypeVar
from math import floor

from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
from google.protobuf.message import Message
from marshmallow import fields
from marshmallow.exceptions import ValidationError
from marshmallow_enum import EnumField

from maps_adv.common.helpers import Converter


class PbFixedDecimalField(fields.Field):
    """
    Field for working with protobuf integers that must be treated as decimals
    with fixed point position: the integer on the wire equals the decimal
    multiplied by ``10 ** places``.

    @:param int places number of digits after decimal point, must be positive
    @:param bool quantize if true, values are quantized to ``places`` digits
        (with the current decimal context rounding) before serialization;
        otherwise values with more significant fraction digits are rejected
    """

    def __init__(self, places: int, quantize: bool = False, **kwargs):
        if not isinstance(places, int):
            raise TypeError('"places" must be int, not {}'.format(type(places)))
        if places <= 0:
            raise ValueError('"places" must be positive')

        self._places = places
        self._coefficient = 10**places

        if quantize:
            # Exponent template Decimal("1E-<places>") for Decimal.quantize.
            self._quantize_decimal = Decimal((0, (1,), -places))
        else:
            self._quantize_decimal = None

        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Convert a finite Decimal into its scaled integer representation.

        Returns None for None. Raises ValidationError for non-Decimal or
        non-finite values and, when not quantizing, for values that have more
        than ``places`` significant fraction digits.
        """
        if value is None:
            return None

        if not isinstance(value, Decimal):
            raise ValidationError("decimal required (got {})".format(type(value)))
        # NaN/Infinity cannot be scaled to an int: int() and quantize() would
        # raise non-validation errors.
        if not value.is_finite():
            raise ValidationError("finite decimal required (got {})".format(value))

        if self._quantize_decimal is not None:
            value = value.quantize(self._quantize_decimal)
        else:
            fraction_digits = self._count_fraction_digits(value)
            if fraction_digits > self._places:
                raise ValidationError(
                    "Value has {} fraction digits, "
                    "but only {} can be serialized".format(
                        fraction_digits, self._places
                    )
                )

        return int(value * self._coefficient)

    def _deserialize(self, value, attr, data, **kwargs):
        """Convert the scaled protobuf integer back into a Decimal."""
        if not isinstance(value, int):
            raise ValidationError("int required (got {})".format(type(value)))

        return Decimal(value) / self._coefficient

    @staticmethod
    def _count_fraction_digits(value: Decimal) -> int:
        """Return the number of significant fraction digits of ``value``.

        The count uses the exact (sign, digits, exponent) tuple. It is therefore
        correct for negative numbers and exponent notation (e.g.
        ``Decimal("1E-7")`` has 7), and no context rounding is applied.
        Trailing zeros are not counted. Non-finite values yield 0.
        """
        if not value.is_finite():
            return 0
        _, digits, exponent = value.as_tuple()
        if exponent >= 0:
            return 0
        coefficient = "".join(map(str, digits))
        significant = coefficient.rstrip("0")
        if not significant:  # the value is zero
            return 0
        trailing_zeros = len(coefficient) - len(significant)
        return max(0, -exponent - trailing_zeros)


class PbFixedDecimalDictField(PbFixedDecimalField):
    """
    Fixed-point decimal field whose protobuf form is a sub-message wrapping
    a single integer; the integer is handled like PbFixedDecimalField does.

    @:param int places number of digits after decimal point, must be positive
    @:param str field name of the integer field inside the protobuf sub-message
    """

    def __init__(self, places: int, field: str, **kwargs):
        super().__init__(places, **kwargs)
        if isinstance(field, str):
            self._field = field
        else:
            raise TypeError('"field" must be str, not {}'.format(type(field)))

    def _serialize(self, value, attr, obj):
        # Wrap the scaled integer under the configured sub-message key.
        if value is not None:
            return {self._field: super()._serialize(value, attr, obj)}
        return None

    def _deserialize(self, value, attr, data):
        # Unwrap the configured field from the message, then decode it.
        if isinstance(value, Message):
            inner = getattr(value, self._field)
            return super()._deserialize(inner, attr, data)
        raise ValidationError(
            "protobuf message required (got {})".format(type(value))
        )


# Type variable for protobuf-compiled enum wrappers (EnumTypeWrapper instances).
ProtoEnumType = TypeVar("ProtoEnumType", bound=EnumTypeWrapper)
# Type variable for internal python Enum classes.
EnumType = TypeVar("EnumType", bound=Enum)


class PbEnumField(EnumField):
    """
    Field for converting values between internal enum and protobuf enum.

    @:param enum internal python enum class, forwarded to ``EnumField``
    @:param pb_enum protobuf-compiled enum; must be an ``EnumTypeWrapper``
        instance (not a class)
    @:param values_map iterable passed to ``Converter``; ``forward`` is applied
        to protobuf values on load and ``reversed`` to internal values on dump
    """

    def __init__(
        self,
        enum: Type[EnumType],
        *args,
        pb_enum: ProtoEnumType,
        values_map: Iterable,
        **kwargs,
    ):
        super().__init__(enum, *args, **kwargs)

        if not isinstance(pb_enum, EnumTypeWrapper):
            raise TypeError(
                '"pb_enum" must be a protobuf-compiled enum class, not {}'.format(
                    type(pb_enum)
                )
            )

        self._pb_enum = pb_enum
        self._enum_converter = Converter(values_map)

    def _serialize(self, value, attr, obj, **kwargs):
        """Map an internal enum value to its protobuf counterpart."""
        if value is None:
            return None

        try:
            return self._enum_converter.reversed(value)
        except KeyError as e:
            raise ValidationError(
                "No matching value was passed for {}".format(value)
            ) from e

    def _deserialize(self, value, attr, data, **kwargs):
        """Map a protobuf enum number to the internal enum value."""
        # Use the public DESCRIPTOR instead of the private ``_enum_type``.
        try:
            is_member = value in self._pb_enum.DESCRIPTOR.values_by_number
        except TypeError:  # unhashable input can never be an enum number
            is_member = False
        if not is_member:
            raise ValidationError("Value is not a member of proto enum")

        try:
            return self._enum_converter.forward(value)
        except KeyError as e:
            raise ValidationError(
                "No matching value was passed for {}".format(value)
            ) from e


class PbStringEnumField(fields.Field):
    """
    Field for converting values between python string and protobuf enum.

    Dumping maps a member name to its number via ``pb_enum.Value``; loading
    maps a number to its name via ``pb_enum.Name``. Lookup failures
    (TypeError or ValueError) in either direction become ValidationError.

    @:param pb_enum protobuf-compiled enum (an ``EnumTypeWrapper`` instance)
    """

    def __init__(self, pb_enum: ProtoEnumType, **kwargs):
        super().__init__(**kwargs)
        self._pb_enum = pb_enum

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None

        try:
            return self._pb_enum.Value(value)
        except (TypeError, ValueError) as e:
            # TypeError is handled too (e.g. unhashable input), as in _deserialize.
            raise ValidationError(str(e)) from e

    def _deserialize(self, value, attr, data, **kwargs):
        try:
            return self._pb_enum.Name(value)
        except (TypeError, ValueError) as e:
            raise ValidationError(str(e)) from e


class PbDateTimeField(fields.DateTime):
    """
    Convert between a timezone-aware datetime and a protobuf message with
    integer ``seconds`` and ``nanos`` fields (presumably
    ``google.protobuf.Timestamp``). Resulting datetime is in UTC.
    """

    _UNIX_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)

    def __init__(self, **kwargs):
        super().__init__(format="iso", **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``{"seconds": int, "nanos": int}`` for an aware datetime."""
        if value is None:
            return None

        if not isinstance(value, datetime.datetime):
            raise ValidationError(
                "datetime.datetime instance required (got {})".format(type(value))
            )
        # utcoffset() is the authoritative awareness test; tzinfo may be set
        # yet still return None.
        if value.utcoffset() is None:
            raise ValidationError("Value for PbDateTimeField must have tzinfo")

        # Exact integer arithmetic (no float timestamp). timedelta keeps
        # seconds and microseconds non-negative even before the epoch, so
        # nanos stays in [0, 1e9) and seconds carries the sign.
        delta = value - self._UNIX_EPOCH
        seconds = delta.days * 86400 + delta.seconds
        nanos = delta.microseconds * 1000

        return {"seconds": seconds, "nanos": nanos}

    def _deserialize(self, value, attr, data, **kwargs):
        """Build a UTC datetime from the message; nanos are truncated to micros."""
        if not isinstance(value, Message):
            raise ValidationError(
                "protobuf message required (got {})".format(type(value))
            )

        try:
            dt = self._UNIX_EPOCH + datetime.timedelta(
                seconds=value.seconds, microseconds=value.nanos // 1000
            )
        except OverflowError as e:
            raise ValidationError("Timestamp is out of range") from e

        # Delegate to fields.DateTime so parsing/validation stays unchanged.
        return super()._deserialize(dt.isoformat(), attr, data, **kwargs)


class PbDateField(fields.Date):
    """
    Convert between ``datetime.date`` and a protobuf message carrying
    ``year``, ``month`` and ``day`` fields. ``datetime.datetime`` is rejected.
    """

    def _serialize(self, value, attr, obj):
        if value is None:
            return None

        # datetime subclasses date, so it is screened out explicitly.
        if isinstance(value, datetime.datetime):
            raise ValidationError("datetime.datetime is not suitable for this field")
        if not isinstance(value, datetime.date):
            raise ValidationError(
                "datetime.date instance required (got {})".format(type(value))
            )

        return dict(year=value.year, month=value.month, day=value.day)

    def _deserialize(self, value, attr, data):
        if isinstance(value, Message):
            try:
                return datetime.date(value.year, value.month, value.day)
            except ValueError:
                raise ValidationError("Invalid date")

        raise ValidationError(
            "protobuf message required (got {})".format(type(value))
        )


class PbDecimalField(fields.Field):
    """
    Field for converting value between Decimal and its string
    representation in protobuf.

    @:param int places maximum number of significant fraction digits accepted
        in either direction, must be positive
    """

    def __init__(self, places: int, **kwargs):
        if not isinstance(places, int):
            raise TypeError(f'"places" must be int, not {type(places)}')
        if places <= 0:
            raise ValueError('"places" must be positive')

        self._places = places
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None

        if not isinstance(value, Decimal):
            raise ValidationError(f"Decimal type is required (got {type(value)})")

        serialized_value = str(value)
        self._validate_precision(value, serialized_value, "serialization")
        return serialized_value

    def _deserialize(self, value, attr, data, **kwargs):
        if not isinstance(value, str):
            raise ValidationError(f"str type is required (got {type(value)})")

        # Parse first so the precision comes from the exact Decimal; a string
        # check would miss exponent notation such as "1E-9".
        try:
            result = Decimal(value)
        except decimal.InvalidOperation as e:
            raise ValidationError(
                f"'{value}' value can't be converted to Decimal"
            ) from e

        self._validate_precision(result, value, "deserialization")
        return result

    def _validate_precision(self, number: Decimal, shown: str, direction: str) -> None:
        """Raise ValidationError unless ``number`` is finite and fits ``places``.

        ``shown`` is the text used in error messages; ``direction`` names the
        operation ("serialization" / "deserialization").
        """
        if not number.is_finite():
            raise ValidationError(f"{shown} is not a finite decimal number")

        precision = self._count_fraction_digits(number)
        if precision > self._places:
            raise ValidationError(
                f"{shown} has {precision} fraction digits, "
                f"but precision limit for {direction} is {self._places} digits"
            )

    @staticmethod
    def _count_fraction_digits(value: Decimal) -> int:
        """Return the number of significant fraction digits of ``value``.

        Uses the exact (sign, digits, exponent) tuple, so exponent notation and
        negative values are handled and no context rounding applies. Trailing
        zeros are not counted; non-finite values yield 0.
        """
        if not value.is_finite():
            return 0
        _, digits, exponent = value.as_tuple()
        if exponent >= 0:
            return 0
        coefficient = "".join(map(str, digits))
        significant = coefficient.rstrip("0")
        if not significant:  # the value is zero
            return 0
        trailing_zeros = len(coefficient) - len(significant)
        return max(0, -exponent - trailing_zeros)

    @staticmethod
    def _parse_precision(value: str) -> int:
        """Return the number of significant fraction digits of decimal string ``value``.

        String-based wrapper around ``_count_fraction_digits``; raises
        ``decimal.InvalidOperation`` if ``value`` is not a valid decimal.
        """
        return PbDecimalField._count_fraction_digits(Decimal(value))


class PbMapField(fields.Field):
    """
    Class for protobuf map<_, _> fields.

    Keys are copied unchanged; each value is converted with ``value_field``
    (a plain ``fields.Field`` by default). Passing a suitable field as
    ``value_field`` supports nested messages as values.
    """

    def __init__(self, value_field=None, **kwargs):
        # Build the default per instance: a Field() in the signature would be
        # evaluated once and shared by every PbMapField.
        self._value_field = fields.Field() if value_field is None else value_field
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None

        serialize_item = self._value_field._serialize
        return {key: serialize_item(value[key], attr, obj) for key in value}

    def _deserialize(self, value, attr, data, **kwargs):
        if value is None:
            return None

        deserialize_item = self._value_field._deserialize
        return {key: deserialize_item(value[key], attr, data) for key in value}


class PbTruncatingDecimalField(fields.Field):
    """
    Field for converting value between Decimal and its string
    representation. Truncates values (rounds toward zero) to the specified
    number of fraction digits.

    @:param int places number of fraction digits to keep, must be positive
    """

    def __init__(self, places: int, **kwargs):
        if not isinstance(places, int):
            raise TypeError(f'"places" must be int, not {type(places)}')
        if places <= 0:
            raise ValueError('"places" must be positive')

        self._places = places
        # Do not pass ``self`` positionally: Field.__init__ would treat it as
        # its first parameter (a default value) or reject it.
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            return None

        if not isinstance(value, Decimal):
            raise ValidationError(f"Decimal type is required (got {type(value)})")
        if not value.is_finite():
            raise ValidationError(f"'{value}' is not a finite decimal number")

        return str(self._truncate_decimal(value))

    def _deserialize(self, value, attr, data, **kwargs):
        if not isinstance(value, str):
            raise ValidationError(f"str type is required (got {type(value)})")

        try:
            number = Decimal(value)
        except decimal.InvalidOperation as e:
            raise ValidationError(
                f"'{value}' value can't be converted to Decimal"
            ) from e

        # "nan"/"inf" parse successfully but cannot be truncated to an int.
        if not number.is_finite():
            raise ValidationError(f"'{value}' is not a finite decimal number")

        return self._truncate_decimal(number)

    def _truncate_decimal(self, x: Decimal) -> Decimal:
        """Drop fraction digits beyond ``places``, rounding toward zero.

        ``int()`` on a Decimal truncates toward zero (unlike ``floor``, which
        would turn -1.239 into -1.24). Dividing by an int power of ten does not
        keep trailing zeros, e.g. 1.50 -> 1.5.
        """
        scale = 10**self._places
        return Decimal(int(x * scale)) / scale
