from base64 import b64decode, b64encode
from enum import Enum
from typing import List, TypeVar, Generic, Callable, Iterable, Dict, Literal, Optional, Type, Any, Tuple, Set

from django.core.paginator import Paginator
from django.db.models import Q, QuerySet, Model
from pydantic import conint, Field
from pydantic import create_model
from pydantic.generics import GenericModel

from ninja import Schema
from wiki.api_v2.exceptions import BadRequest

# Generic placeholders used by the collection helpers below: `T` is the serialized item type
# (serializer output), `T2` the raw item type (e.g. ORM object) fed to serializers and filters.
T, T2 = TypeVar('T'), TypeVar('T2')


# Sort direction for ordered collections; member values are the strings clients send.
# (Documented with comments, not a docstring, so the generated enum schema stays unchanged.)
class OrderDirection(Enum):
    ASC = 'asc'  # ascending
    DESC = 'desc'  # descending


# Field descriptions shared by the query schemas below (user-facing text, in Russian).
PAGE_SIZE_DESC = 'Число результатов на странице выдачи.'
# The direction query parameter is named `order_direction` (see `OrderingQuery`), so the
# description must reference it by that name rather than a non-existent `direction`.
ORDER_BY_DESC = 'Если указано, отсортировать выдачу по полю в направлении `order_direction`'


# Query parameters of paginated endpoints: legacy page number, page size and an opaque cursor.
class PaginationQuery(Schema):
    page_id: conint(ge=1) = Field(1, description='`legacy` Номер страницы выдачи')
    page_size: conint(ge=1, le=50) = Field(25, description=PAGE_SIZE_DESC)
    # encoded `Cursor` returned in a previous `Collection`; None starts from the first page
    cursor: Optional[str] = None


# Same as `PaginationQuery`, but allows up to 100 items per page (50 by default).
class MediumPaginationQuery(PaginationQuery):
    page_size: conint(le=100, ge=1) = Field(50, description=PAGE_SIZE_DESC)


# Query parameters selecting the sort field and direction. `CollectionFactory.ordering`
# derives subclasses that narrow `order_by` to the endpoint's allowed variants.
class OrderingQuery(Schema):
    # None means "use the default ordering"
    order_by: Optional[str] = Field(default=None, description=ORDER_BY_DESC)
    order_direction: OrderDirection = Field(
        OrderDirection.ASC,
        description='Если указано поле `order_by`, направление сортировки',
    )


# One page of serialized results plus navigation info. The cursors are sent back by the
# client as `PaginationQuery.cursor`; each is None when there is nothing in that direction.
class Collection(GenericModel, Generic[T]):
    results: List[T]
    has_next: bool = Field(
        description='Для обратной совместимости, если задан курсор - то смотреть на `..._cursor`'
    )
    page_id: int = Field(
        description='Для обратной совместимости, если задан курсор - всегда равен 1'
    )

    next_cursor: Optional[str]
    prev_cursor: Optional[str]


# Keyset-pagination cursor, serialized compactly with one-letter aliases:
#   t - `order_by` variant it was created for (None for the default ordering)
#   v - values of the ordering fields of the boundary object
#   i - 1 if it points backwards (previous page), else 0
#   p - primary key of the boundary object
class Cursor(Schema):
    type: Optional[str] = Field(default=None, alias='t')
    val: List[Any] = Field(default_factory=list, alias='v')
    is_prev: conint(ge=0, le=1) = Field(default=0, alias='i')
    pk: Any = Field(alias='p')

    @classmethod
    def decode(cls, encoded_cursor: bytes) -> 'Cursor':
        """Parse a cursor produced by `encode`.

        `b64decode` also accepts an ASCII `str`, so the raw query-parameter string can be passed.

        :raises BadRequest: the input is not base64-encoded cursor JSON.
        """
        try:
            return cls.parse_raw(b64decode(encoded_cursor).decode('utf-8'))
        except Exception:
            raise BadRequest('Invalid cursor')

    def encode(self) -> bytes:
        """Serialize to base64(JSON) with the short aliases, dropping fields left at their defaults."""
        as_json = self.json(by_alias=True, ensure_ascii=False, exclude_defaults=True, exclude_unset=True)
        return b64encode(as_json.encode('utf-8'))


# Builds `Collection` pages from Django querysets. It holds the named `order_by` variants
# an endpoint supports and the default ordering used when the client requests none.
# Results are always ordered by `pk` as the final tie-breaker.
class CollectionFactory(Schema):
    name: str = ''

    # `order_by` variant name -> ORM field paths (may traverse relations with `__`)
    order_fields: Dict[str, List[str]] = Field(default_factory=dict)
    # ORM field paths used when no `order_by` is requested
    order_fields_default: List[str] = Field(default_factory=list)
    # direction used together with `order_fields_default`
    order_direction: OrderDirection = OrderDirection.ASC

    def with_ordering(self, order_fields: Dict[str, List[str]]):
        """Register the allowed `order_by` variants; returns `self` for chaining."""
        self.order_fields = order_fields
        return self

    def default_ordering(self, field_names: Optional[List[str]] = None, direction: Optional[OrderDirection] = None):
        """Set the default ordering fields and/or direction; returns `self` for chaining.

        Falsy arguments leave the corresponding setting unchanged.
        """
        if field_names:
            self.order_fields_default = field_names
        if direction:
            self.order_direction = direction
        return self

    @property
    def ordering(self) -> Type[OrderingQuery]:
        """Return an `OrderingQuery` subclass whose `order_by` only accepts the registered variants.

        A new model class is created on every access.
        """
        variants = tuple(self.order_fields.keys())
        # model name passed positionally: that is the stable form of `create_model`'s first argument
        return create_model(
            self.name.capitalize() + OrderingQuery.__name__,
            __base__=OrderingQuery,
            order_by=(Literal[variants], Field(None, description=ORDER_BY_DESC)),  # noqa
        )

    @classmethod
    def pagination_build(
        cls,
        qs,
        serializer: Callable[[T2], T],
        filter: Callable[[List[T2]], Iterable[T2]] = None,
        pagination: PaginationQuery = None,
        batch_serializer: Callable[[Iterable[T2]], Iterable[T]] = None,
    ) -> 'Collection[T]':
        """Legacy page-number pagination (no cursors) using Django's `Paginator.get_page`.

        :param qs: queryset (or sequence) to paginate; should already be ordered.
        :param serializer: per-object serializer; takes precedence over `batch_serializer`.
        :param filter: optional post-filter applied to the page's objects.
        :param pagination: page number and size; `None` means the `PaginationQuery` defaults.
        :param batch_serializer: serializer applied to the whole page when `serializer` is not given.
        """
        if pagination is None:
            pagination = PaginationQuery()
        page_id = pagination.page_id

        page = Paginator(qs, pagination.page_size).get_page(page_id)
        objects = page.object_list
        if filter:
            objects = filter(objects)

        return Collection(
            results=cls._serialize(objects, serializer, batch_serializer),
            has_next=page.has_next(),
            page_id=page_id,
        )

    @classmethod
    def build(
        cls,
        qs,
        serializer: Callable[[T2], T],
        filter: Callable[[List[T2]], Iterable[T2]] = None,
        pagination: PaginationQuery = None,
    ) -> 'Collection[T]':
        """Build a page with the default (pk-only) ordering; always implicitly ordered by `pk`."""
        return cls().ordered_build(qs=qs, serializer=serializer, filter=filter, pagination=pagination)

    def ordered_build(
        self,
        qs: QuerySet,
        pagination: PaginationQuery = None,
        ordering: OrderingQuery = None,
        serializer: Callable[[T2], T] = None,
        batch_serializer: Callable[[Iterable[T2]], Iterable[T]] = None,
        filter: Callable[[List[T2]], Iterable[T2]] = None,
    ) -> Collection[T]:
        """Build one page of `qs`; it is always implicitly ordered by `pk` as the last key.

        Two schemes are supported. When `pagination.page_id > 1`, legacy page-number
        pagination is used and any cursor is ignored. Otherwise keyset pagination is driven
        by `pagination.cursor`.

        :param qs: queryset to paginate.
        :param pagination: page size, legacy page number and optional cursor; `None` means defaults.
        :param ordering: requested ordering; `None` or `order_by=None` selects the default ordering.
        :param serializer: per-object serializer; takes precedence over `batch_serializer`.
        :param batch_serializer: serializer applied to the whole page at once.
        :param filter: Python-side post-filter applied to each fetched chunk; it may drop objects.
        :raises BadRequest: unknown `order_by`, or a cursor that is malformed or does not match the ordering.
        """
        if pagination is None:
            pagination = PaginationQuery()

        # TODO remove the page-number scheme once the frontend has switched to cursors
        legacy_pagination_mode = pagination.page_id > 1

        order_fields, order_direction, order_by = self.parse_ordering(ordering)
        increase = order_direction == OrderDirection.ASC
        reverse = False  # set when the previous elements were requested; results are flipped back

        if pagination.cursor and not legacy_pagination_mode:
            cursor = Cursor.decode(pagination.cursor)
            if ordering is not None and cursor.type != order_by:
                raise BadRequest('Invalid cursor')

            reverse = bool(cursor.is_prev)
            # True when we need (elements after, ascending) or (elements before, descending)
            increase = reverse ^ increase

            qs = qs.filter(self.make_filter_query(cursor, order_fields, increase))

        qs = qs.order_by(*self.make_order_params(order_fields, increase))

        select_params = self.make_select_params(order_fields)
        if select_params:
            qs = qs.select_related(*select_params)

        if legacy_pagination_mode:
            return self.pagination_build(qs, serializer, filter, pagination, batch_serializer=batch_serializer)

        objects, has_next = self._fetch_page(qs, pagination.page_size, filter)
        if reverse:
            objects.reverse()

        # Without an incoming cursor this is the first page, so there is nothing before it.
        # When walking backwards, "more rows" (has_next) means "more rows before".
        has_prev_cursor = has_next_cursor = False
        if objects:
            has_prev_cursor, has_next_cursor = bool(pagination.cursor), has_next
            if reverse:
                has_prev_cursor, has_next_cursor = has_next_cursor, has_prev_cursor

        prev_cursor = next_cursor = None
        if has_prev_cursor:
            prev_cursor = self.create_cursor(objects[0], order_fields, order_by, is_prev=True).encode()
        if has_next_cursor:
            next_cursor = self.create_cursor(objects[-1], order_fields, order_by, is_prev=False).encode()

        return Collection(
            results=self._serialize(objects, serializer, batch_serializer),
            has_next=has_next,  # kept for backward compatibility
            page_id=1,  # kept for backward compatibility
            next_cursor=next_cursor,
            prev_cursor=prev_cursor,
        )

    @staticmethod
    def _fetch_page(qs, page_size: int, filter_fn) -> Tuple[list, bool]:
        """Fetch up to `page_size` objects, moving past chunks that `filter_fn` empties entirely.

        Returns the (filtered) objects as a list and whether more rows follow the last fetched chunk.
        """
        offset = 0
        while True:
            # one extra row tells whether anything follows this chunk
            rows = list(qs[offset : offset + page_size + 1])
            has_next = len(rows) > page_size
            objects = rows[:page_size]
            if filter_fn:
                # materialize: the filter may return any iterable (e.g. a generator)
                objects = list(filter_fn(objects))
            if objects or not has_next:
                return objects, has_next
            offset += page_size

    @staticmethod
    def _serialize(objects, serializer, batch_serializer):
        """Serialize objects one by one with `serializer` if given, else all at once with `batch_serializer`."""
        if serializer:
            return [serializer(orm) for orm in objects]
        return batch_serializer(objects)

    def create_next_cursor(self, objs):
        """Return an encoded forward cursor after the last of `objs` (default ordering), or None if empty."""
        if not objs:
            return None
        return self.create_cursor(objs[-1], self.order_fields_default, None, is_prev=False).encode()

    def create_prev_cursor(self, objs):
        """Return an encoded backward cursor before the first of `objs` (default ordering), or None if empty."""
        if not objs:
            return None
        return self.create_cursor(objs[0], self.order_fields_default, None, is_prev=True).encode()

    def parse_ordering(self, ordering: Optional[OrderingQuery]) -> Tuple[List[str], OrderDirection, Optional[str]]:
        """Resolve `ordering` into `(order_fields, direction, order_by)`.

        `order_by` is None when the default ordering is used.

        :raises BadRequest: `ordering.order_by` is not a registered variant.
        """
        if ordering is None or ordering.order_by is None:  # means `default`
            return self.order_fields_default, self.order_direction, None

        if ordering.order_by in self.order_fields:
            return self.order_fields[ordering.order_by], ordering.order_direction, ordering.order_by

        allowed_columns = ', '.join(self.order_fields)
        raise BadRequest(f'Ordering by `{ordering.order_by}` is impossible; allowed: {allowed_columns}')

    @staticmethod
    def make_filter_query(cursor: Cursor, fields: List[str], increase: bool) -> Q:
        """Select rows strictly after `cursor` in `(*fields, pk)` order.

        For fields f1..fn with cursor values v1..vn and pk p this is the lexicographic comparison
        `f1 > v1 OR (f1 = v1 AND f2 > v2) OR ... OR (f1 = v1 AND ... AND fn = vn AND pk > p)`,
        with `<` instead of `>` when `increase` is False. This matches `make_order_params`.

        :raises BadRequest: the cursor carries a different number of values than `fields`.
        """
        # NOTE(review): None values in ordering fields are not handled (Django rejects None
        # for `__gt`/`__lt` lookups) -- confirm ordering fields are non-nullable.
        if len(cursor.val) != len(fields):
            raise BadRequest('Invalid cursor')

        suffix = '__gt' if increase else '__lt'
        keys = [*zip(fields, cursor.val), ('pk', cursor.pk)]

        query = Q()
        equal_prefix = Q()  # all previous keys equal to the cursor's values
        for field_name, value in keys:
            query |= equal_prefix & Q(**{field_name + suffix: value})
            equal_prefix &= Q(**{field_name: value})
        return query

    @staticmethod
    def make_order_params(order_fields: List[str], increase) -> List[str]:
        """Return `order_by()` arguments: `order_fields` then `pk`, all in the same direction."""
        prefix = '' if increase else '-'
        return [f'{prefix}{field}' for field in [*order_fields, 'pk']]

    @staticmethod
    def make_select_params(fields: List[str]) -> Set[str]:
        """Return relation paths to `select_related` for field paths with two or more `__` separators.

        `'a__b__c'` yields `'a__b'`; `'a__b'` yields nothing.
        """
        # NOTE(review): single-hop paths (`'a__b'`) are not select_related, so `create_cursor`
        # resolves that relation separately for the boundary objects -- confirm `> 1` is intended.
        return {field.rsplit('__', 1)[0] for field in fields if field.count('__') > 1}

    @staticmethod
    def create_cursor(obj: Model, fields: List[str], order_by: Optional[str], is_prev: bool) -> Cursor:
        """Build a cursor at `obj` holding the values of `fields` (resolved via `__`-separated attributes)."""
        if not fields:
            return Cursor(t=order_by, i=is_prev, p=obj.pk)

        values = []
        for field in fields:
            value = obj
            for chunk in field.split('__'):
                value = getattr(value, chunk)

            # cursor values must be plain (JSON-serializable) values, not model instances
            assert not isinstance(value, Model), 'Добавьте поля для фильтрации явно, не модели'
            values.append(value)

        return Cursor(t=order_by, v=values, i=is_prev, p=obj.pk)


TM = TypeVar('TM')
TS = TypeVar('TS')


class CompositeSerializer(Generic[TM, TS]):
    """Mixin for serializers that fill optional, on-demand fields onto a schema object.

    Subclasses declare ``FIELDS`` as ``{attr_name: (fn, deps)}``. The value stored in
    ``attr_name`` is ``fn(obj, **{dep: ctx[dep] for dep in deps})``; ``deps`` may be empty
    or None when ``fn`` needs only ``obj``.
    """

    FIELDS = {
        # attr_name: (fn, deps) -- deps is a list of ctx keys passed to fn as keyword arguments
    }

    @classmethod
    def _add_fields(cls, q: TS, obj: TM, fields: List[str], ctx=None) -> TS:
        """Set every requested field known to ``FIELDS`` on ``q`` and return ``q``.

        Names not present in ``FIELDS`` are ignored. ``ctx`` must provide every dep key of the
        requested fields.
        """
        for field in fields:
            if field not in cls.FIELDS:
                continue
            fn, deps = cls.FIELDS[field]
            kwargs = {dep: ctx[dep] for dep in deps} if deps else {}
            setattr(q, field, fn(obj, **kwargs))
        return q

    @classmethod
    def extract_fields(cls, value: Optional[str]) -> List[str]:
        """Parse a comma-separated list of field names (case-insensitive, whitespace-tolerant).

        Empty items (an empty string, ``'a,,b'`` or a trailing comma) are skipped; ``None``
        yields an empty list.

        :raises BadRequest: a name is not a key of ``FIELDS``.
        """
        if value is None:
            return []
        validated_list = []
        for v in value.split(','):
            field = v.strip().lower()
            if not field:
                continue
            if field not in cls.FIELDS:
                allowed = ', '.join(cls.FIELDS.keys())
                raise BadRequest(f'Unknown field {field}; Allowed fields are {allowed}')
            validated_list.append(field)

        return validated_list
