import inspect
import re

from typing import Generic, TypeVar, Optional, List, Any, Union

from pydantic.generics import GenericModel
from pydantic import BaseConfig, BaseModel, ConstrainedStr
from pydantic.fields import ModelField
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.utils import ROOT_KEY
from pydantic.validators import (
    constr_lower,
    constr_strip_whitespace,
    str_validator,
    strict_str_validator,
    constr_length_validator
)
from pydantic.errors import AnyStrMaxLengthError, AnyStrMinLengthError, StrRegexError

from watcher.logic.exceptions import SlugInvalidValue, NameInvalidValue, SlugInvalidLength


T = TypeVar("T")  # element type parameter for generic models such as CursorPaginationResponse


def optional(*fields):
    """Mark pydantic model fields as not required.

    Use it bare (``@optional``) to mark every field of the decorated model, or
    with field names (``@optional('a', 'b')``) to mark only those fields. The
    model's field objects are modified in place and the same class is returned.
    Adapted from https://github.com/samuelcolvin/pydantic/issues/1223#issuecomment-775363074

    Raises:
        KeyError: if a given field name is not declared on the model.
    """
    def _mark_not_required(model_cls, field_names):
        # Flip ``required`` on each named field and hand the class back so this
        # can serve as the body of a class decorator.
        for field_name in field_names:
            model_cls.__fields__[field_name].required = False
        return model_cls

    used_bare = (
        bool(fields)
        and inspect.isclass(fields[0])
        and issubclass(fields[0], BaseModel)
    )
    if used_bare:
        # ``@optional`` without parentheses: the model arrives as the first
        # positional argument and every declared field is affected.
        model_cls = fields[0]
        return _mark_not_required(model_cls, model_cls.__fields__)

    def dec(_cls):
        return _mark_not_required(_cls, fields)

    return dec


class BaseSchema(BaseModel):
    # Base for API schemas that rejects input keys which match neither a
    # declared field name nor a field alias.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic v1 may use an inherited docstring as the schema description of
    # subclasses that lack their own.

    def __init__(__pydantic_self__, **data: Any) -> None:
        """Validate ``data`` as usual, then reject keys the schema does not declare.

        Raises:
            ValidationError: from pydantic's own validation (raised first), or
                when ``data`` contains keys that are neither field names nor
                field aliases.
        """
        super().__init__(**data)
        model_cls = __pydantic_self__.__class__
        model_fields = model_cls.__fields__
        # Config enables allow_population_by_field_name, so input may use either
        # the field name or its alias; both must count as known keys.
        known_keys = set(model_fields)
        known_keys.update(field.alias for field in model_fields.values())
        unknown_fields = set(data) - known_keys
        if unknown_fields:
            # Pass the model class, not None: ValidationError reads attributes
            # of the model when rendering str() and errors().
            raise ValidationError(
                [
                    ErrorWrapper(
                        ValueError(f'Unknown fields: {unknown_fields}'),
                        loc=ROOT_KEY,
                    )
                ],
                model_cls,
            )

    def remove_fields(self, *fields: str) -> None:
        """Drop the named attributes from this instance's ``__dict__``.

        Names that are not present are silently ignored.
        """
        for field in fields:
            self.__dict__.pop(field, None)

    class Config(BaseConfig):
        orm_mode = True
        allow_population_by_field_name = True


class CursorPaginationResponse(GenericModel, Generic[T]):
    # Generic cursor-paginated page of ``T`` items. Comments are used instead of
    # a docstring so the generated schema description stays unchanged.
    result: List[T]  # items on this page
    next: Union[str, None]  # next-page cursor; presumably None on the last page — confirm with callers
    prev: Union[str, None]  # previous-page cursor; presumably None on the first page — confirm with callers


class WatcherConstrainedStr(ConstrainedStr):
    """Constrained string whose validation failures surface as watcher domain errors.

    The ``type_`` class attribute picks the error family. With ``'slug'``,
    length failures raise ``SlugInvalidLength`` and pattern failures raise
    ``SlugInvalidValue``. With any other value, both raise ``NameInvalidValue``.
    """
    # Placeholder; concrete subclasses assign a discriminator string such as 'slug' or 'name'.
    type_ = Optional[str]

    @classmethod
    def __get_validators__(cls):
        """Yield validators in order: str (or strict str) coercion,
        ``constr_strip_whitespace``, ``constr_lower``, the wrapped length check,
        and finally ``validate``."""
        coerce = strict_str_validator if cls.strict else str_validator
        chain = (
            coerce,
            constr_strip_whitespace,
            constr_lower,
            cls.watcher_constr_length_validator,
            cls.validate,
        )
        yield from chain

    @classmethod
    def validate(cls, value: Union[str]) -> Union[str]:
        """Run ``ConstrainedStr.validate``; turn ``StrRegexError`` into a domain error."""
        try:
            return super().validate(value=value)
        except StrRegexError:
            raise SlugInvalidValue if cls.type_ == 'slug' else NameInvalidValue

    @classmethod
    def watcher_constr_length_validator(cls, value: Union[str], field: ModelField, config: BaseConfig) -> Union[str]:
        """Run pydantic's ``constr_length_validator``; turn min/max length errors into domain errors."""
        try:
            return constr_length_validator(value, field, config)
        except (AnyStrMaxLengthError, AnyStrMinLengthError):
            raise SlugInvalidLength if cls.type_ == 'slug' else NameInvalidValue


def str_constr(**kwargs):
    """Create a ``WatcherConstrainedStr`` subclass configured by ``kwargs``.

    Every keyword becomes a class attribute of the new type (for example
    ``min_length``, ``max_length``, ``strip_whitespace``, ``to_lower``,
    ``regex`` and ``type_``). A non-None ``regex`` is pre-compiled with
    ``re.compile``. ``regex=None`` is kept as-is instead of being passed to
    ``re.compile``, which would raise ``TypeError``.
    """
    pattern = kwargs.get('regex')
    if pattern is not None:
        kwargs['regex'] = re.compile(pattern)
    return type('PatchedConstrainedStrValue', (WatcherConstrainedStr,), kwargs)


# Slug: stripped and lower-cased, 3-50 characters, only word characters or hyphens.
slug_type = str_constr(
    type_='slug',
    regex=r'^[-\w]+$',
    min_length=3,
    max_length=50,
    strip_whitespace=True,
    to_lower=True,
)

# Name: stripped, 1-150 characters, no pattern restriction.
name_type = str_constr(
    type_='name',
    min_length=1,
    max_length=150,
    strip_whitespace=True,
)
