import typing
import dataclasses
from functools import cached_property
from datetime import datetime
from decimal import Decimal
from enum import Enum
from google.protobuf.message import Message
from google.protobuf.timestamp_pb2 import Timestamp

# Names exported by a star-import of this module.
# NOTE(review): `safe_wrapper_coercion` is defined in this module but is not
# listed here -- confirm whether leaving it out is intentional.
__all__ = [
    'BaseStruct',
    'timestamp_or_none',
    'decimal_to_string',
    'decimal_percent_to_string',
    'safe_get_nullable_field'
]


@dataclasses.dataclass
class BaseStruct:
    """Base dataclass for structs that map to and from protobuf messages.

    ``from_proto`` and ``to_proto`` are no-op placeholders here (they return
    ``None``); subclasses presumably override them.

    After the generated ``__init__`` runs, every field annotated with an
    ``Enum`` subclass -- either directly or wrapped in an optional union --
    whose value is neither ``None`` nor already a member of that enum is
    replaced by ``EnumClass(value)``. Each instance also receives an empty
    ``context`` dict.
    """

    @classmethod
    def from_proto(cls, message: 'Message') -> 'BaseStruct':
        """Build an instance from *message*. Placeholder: returns ``None``."""
        pass

    def to_proto(self) -> 'Message':
        """Convert this instance to a message. Placeholder: returns ``None``."""
        pass

    @staticmethod
    def _is_optional(field):
        """Return True if *field* (a type annotation) is a ``typing.Union`` containing ``None``."""
        return typing.get_origin(field) is typing.Union and type(None) in typing.get_args(field)

    @staticmethod
    def _check_enum(tp):
        """Return True if *tp* is an ``Enum`` subclass, False for anything else."""
        try:
            return issubclass(tp, Enum)
        except TypeError:
            # issubclass() rejects non-class arguments such as generic aliases.
            return False

    @staticmethod
    def _get_type(field):
        """Return the annotated type of a dataclass *field*, unwrapping optional unions.

        For an optional union the first member that is not ``NoneType`` is
        returned (``None`` if there is no such member); any other annotation
        is returned unchanged.
        """
        if not BaseStruct._is_optional(field.type):
            return field.type
        # The union members are classes, so NoneType has to be excluded by
        # identity. An isinstance() check never matches the NoneType class
        # itself, which made Union[None, X] resolve to NoneType instead of X.
        none_type = type(None)
        return next((t for t in typing.get_args(field.type) if t is not none_type), None)

    @cached_property
    def enum_fields_list(self):
        """Set of dataclass ``Field`` objects whose (unwrapped) type is an ``Enum`` subclass.

        Computed on first access and then cached on the instance.
        """
        return {
            field
            for field in dataclasses.fields(self)
            if self._check_enum(self._get_type(field))
        }

    def __post_init__(self):
        """Coerce raw values of enum-typed fields to enum members and attach ``context``."""
        for field in self.enum_fields_list:
            value = getattr(self, field.name)
            field_type = self._get_type(field)
            # None is left alone so optional enum fields can stay unset.
            if value is not None and not isinstance(value, field_type):
                setattr(self, field.name, field_type(value))
        self.context = {}


def timestamp_or_none(value: datetime) -> typing.Optional[Timestamp]:
    """Convert *value* to a protobuf ``Timestamp``.

    Returns ``None`` when *value* is falsy (e.g. ``None``); otherwise a new
    ``Timestamp`` populated through ``Timestamp.FromDatetime``.
    """
    if not value:
        return None
    stamp = Timestamp()
    stamp.FromDatetime(value)
    return stamp


def decimal_to_string(value: typing.Union[Decimal, float], fmt="{:f}") -> typing.Optional[str]:
    """Format a number as a string with trailing zeros stripped.

    Args:
        value: A ``Decimal``, ``float`` or ``int``; ``None`` is passed through.
        fmt: ``str.format`` template applied to the normalized ``Decimal``.
            The default ``"{:f}"`` yields plain (non-scientific) notation.

    Returns:
        The formatted string, or ``None`` when *value* is ``None``.
    """
    if value is None:
        return None
    if isinstance(value, float):
        # Decimal(float) copies the exact binary expansion, so 0.1 would be
        # rendered as 0.1000000000000000055511151231. Going through str()
        # uses the shortest repr that round-trips, i.e. what was written.
        value = Decimal(str(value))
    elif isinstance(value, int):
        # Plain ints have no .normalize(); the conversion is exact.
        value = Decimal(value)
    return fmt.format(value.normalize())


def decimal_percent_to_string(value: typing.Union[Decimal, float]) -> typing.Optional[str]:
    """Format a number as a string with exactly three decimal places.

    Args:
        value: A ``Decimal``, ``float`` or ``int``; ``None`` is passed through.

    Returns:
        The value formatted with ``"{:.3f}"``, or ``None`` when *value* is
        ``None``. No scaling is applied to the value.
    """
    if value is None:
        return None
    if isinstance(value, (int, float)):
        # Ints have no .normalize(); both conversions are exact, and the
        # result is rounded to three places by the format below.
        value = Decimal(value)
    return "{:.3f}".format(value.normalize())


def safe_get_nullable_field(message, key, coerce=lambda x: x):
    """Read an optional field from *message*, returning ``None`` when it is unset.

    Args:
        message: Object exposing ``HasField(name)`` and the field as an attribute.
        key: Name of the field to read.
        coerce: Callable applied to the extracted value; identity by default.

    Returns:
        ``None`` if ``message.HasField(key)`` is false or the attribute is
        ``None``. Otherwise ``coerce`` applied to the attribute's ``value``
        attribute when it has one (presumably a wrapper message), or to the
        attribute itself when it does not.
    """
    if not message.HasField(key):
        return None
    field_value = getattr(message, key, None)
    if field_value is None:
        return None
    unwrapped = getattr(field_value, 'value', field_value)
    return coerce(unwrapped)


def safe_wrapper_coercion(value, wrapper_cls):
    """Wrap *value* as ``wrapper_cls(value=value)``, passing ``None`` through unchanged."""
    if value is None:
        return None
    return wrapper_cls(value=value)
