import collections.abc
import copy
from datetime import datetime
from typing import Any, Iterable, Mapping, Optional, Tuple, Type

from deprecated import deprecated
from sqlalchemy import Table
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql import ColumnElement, FromClause, Selectable, and_, asc, desc, literal, select
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.schema import Constraint

from .data_mapper import SelectableDataMapper, TableDataDumper
from .types import Entity, OptStrList, OptValuesList, OptValuesMapping, StrList


class RelationDescription:
    """
    Класс описания связанных таблиц для построения запросов.

    Если требуется изменить условия объединения объектов, то
    в наследнике класса надо переопределить функцию join_func
    """

    def __init__(self, *,
                 name: str,  # название связанного объекта
                 base: Table,  # базовый объект, с которым производится объединение
                 related: Table,  # объект, который присоединяется к базовому
                 base_cols: Iterable[str],  # колонки ForeignKey базового объекта
                 related_cols: Iterable[str],  # колонки ForeignKey присоединяемого объекта
                 mapper_cls: Optional[Type[SelectableDataMapper]] = None,  # маппер присоединяемого объекта
                 outer_join: bool = False,  # делать ли outer join (по умолчанию inner)
                 full_join: bool = False,  # делать ли full join
                 ):
        self.name = name
        self.base = base
        self.related = related
        self.mapper = mapper_cls(selectable=self.related, label_prefix=name) if mapper_cls else None
        self.base_cols = tuple(base_cols)
        self.related_cols = tuple(related_cols)
        self.outer_join = outer_join
        self.full_join = full_join

        assert len(self.base_cols) == len(self.related_cols)

    def get_column(self, field: str) -> ColumnElement:
        """Получить колонку присоединяемого объекта по имени"""
        return self.related.c[field]

    def get_relation_cond(self):
        return [
            self.base.c[bf] == self.related.c[tf]
            for bf, tf in zip(self.base_cols, self.related_cols)
        ]

    def join_func(self, from_clause: FromClause) -> FromClause:
        """Функция для построения from запроса"""
        return from_clause.join(
            self.related,
            onclause=and_(*self.get_relation_cond()),
            isouter=self.outer_join,
            full=self.full_join,
        )


class CRUDQueries:
    """
    Генератор типовых запросов.
    На основе переданного описания базовой и связанных сущностей позволяет
    генерировать запросы для типовых случаев получения и изменения данных.

    При указании полей в фильтрах, сортировке и т.п. для
    """

    def __init__(
        self,
        base: Table,
        id_fields: StrList = ('id',),
        mapper_cls: Type[SelectableDataMapper] = SelectableDataMapper,
        dumper_cls: Type[TableDataDumper] = TableDataDumper,
        related: Optional[Iterable[RelationDescription]] = None,
    ):
        self.base = base
        self._mapper = mapper_cls(selectable=self.base)
        self._dumper = dumper_cls(table=self.base)
        self._id_fields = id_fields
        self._related = {x.name: x for x in related or ()}

    def with_related(self, related: RelationDescription):  # type: ignore
        """
        Создать аналогичный QueryBuilder с информацией о связанной сущности.
        """
        cloned = copy.copy(self)  # sic: shallow copy
        cloned._related = dict(self._related.items())
        cloned._related[related.name] = related
        return cloned

    @property
    def id_fields(self):
        return self._id_fields

    def insert(self,
               obj: Entity,
               ignore_fields: OptStrList = None,
               on_conflict_do_update_constraint: Optional[Constraint] = None,
               on_conflict_do_update_set: Optional[dict] = None,
               ) -> Tuple[Executable, SelectableDataMapper]:
        """
        Генерация insert'а по переданному объекту. `ignore_vields` позволяет вставлять часть полей
        """
        values = self._dumper(obj, skip_fields=ignore_fields)
        query = (
            insert(self.base).
            values(**values).
            returning(*self._mapper.columns)
        )
        if on_conflict_do_update_constraint is not None:
            query = query.on_conflict_do_update(
                constraint=on_conflict_do_update_constraint,
                set_=values if on_conflict_do_update_set is None else {**values, **on_conflict_do_update_set},
            )
        return query, self._mapper

    def delete(self, obj: Entity, filter_fields: OptStrList = None, filters: OptValuesMapping = None) -> Executable:
        """
        Генерация delete по переданному объекту.
        `filter_fields` позволяет дополнительно задать какие поля должны быть равны удаляемому объекту
        `filters` позволяет задать произвольные дополнительные фильтры
        """
        query = (
            self.base.delete().
            returning(literal(True))
        )
        query = self._filter_object(query, obj, filter_fields, filters)
        return query

    def update(self, obj: Entity, filter_fields: OptStrList = None,
               filters: OptValuesMapping = None, only_fields: OptStrList = None,
               ignore_fields: OptStrList = None) -> Tuple[Executable, SelectableDataMapper]:
        """
        Генерация update по переданному объекту.
        `filter_fields` позволяет дополнительно задать какие поля должны быть равны удаляемому объекту
        `filters` позволяет задать произвольные дополнительные фильтры
        `only_fields` или `ignore_fields` ограничивают набор обновляемых полей
        """
        query = (
            self.base.update().
            values(**self._dumper(obj, keep_fields=only_fields, skip_fields=ignore_fields)).
            returning(*self._mapper.columns)
        )
        query = self._filter_object(query, obj, filter_fields, filters)
        return query, self._mapper

    def select(self, id_values: OptValuesList = None, filters: OptValuesMapping = None, order: OptStrList = None,
               for_update: bool = False, skip_locked: bool = False,
               offset: Optional[int] = None, limit: Optional[int] = None) -> Tuple[Selectable, SelectableDataMapper]:
        """
        Генерация select для запросов из основной сущности.
        `id_values` - значения для полей, перечисленных в параметре `id_fields` конструктора
        `filters` - фильтры для ограничения выборки
        `order` - список полей для сортировки, для сортировки в обратном порядке перед полем надо поставить "-"
        `for_update` - добавить к запросу with_for_update
        `skip_locked` - при указании `for_update` добавить skip_locked
        `offset`, `limit` - ограничить выборку
        """
        query, mapper, _ = self._select(id_values, filters, order, offset, limit, with_related=False)
        if for_update:
            query = query.with_for_update(skip_locked=skip_locked, key_share=True)
        return query, mapper

    def select_related(self, id_values: OptValuesList = None, filters: OptValuesMapping = None,
                       order: OptStrList = None, for_update: bool = False, skip_locked: bool = False,
                       offset: Optional[int] = None, limit: Optional[int] = None,
                       **kwargs: Any
                       ) -> Tuple[Selectable, SelectableDataMapper, Mapping[str, SelectableDataMapper]]:
        """
        Генерация select для запросов из основной сущности и related. В ответе дополнительно вернутся
        DataMapper'ы, через которые можно получить значения присоединенных сущностей
        `id_values` - значения для полей, перечисленных в параметре `id_fields` конструктора
        `filters` - фильтры для ограничения выборки
        `order` - список полей для сортировки, для сортировки в обратном порядке перед полем надо поставить "-"
        `for_update` - добавить к запросу with_for_update(of=self.base)
        `skip_locked` - при указании `for_update` добавить skip_locked
        `offset`, `limit` - ограничить выборку
        """
        query, mapper, rel_mappers = self._select(id_values, filters, order, offset, limit, with_related=True)
        if for_update:
            query = query.with_for_update(of=kwargs.get('lock_of', self.base), skip_locked=skip_locked, key_share=True)
        return query, mapper, rel_mappers

    def _select(self, id_values=None, filters=None, order=None, offset=None, limit=None, with_related=False):
        if with_related:
            from_clause, rel_mappers = self._make_from()
        else:
            from_clause = self.base
            rel_mappers = {}

        query = (
            select(sum((x.columns for x in rel_mappers.values() if x), self._mapper.columns)).
            select_from(from_clause)
        )

        if id_values:
            query = self._filter_query(query, {k: v for k, v in zip(self._id_fields, id_values)})
        if filters:
            query = self._filter_query(query, filters)
        if order:
            query = self._add_order(query, order)
        if offset is not None:
            query = query.offset(offset)
        if limit is not None:
            query = query.limit(limit)

        return query, self._mapper, rel_mappers

    def _filter_object(self, query, obj, filter_fields, filters):
        query = self._filter_object_id(query, obj)
        if filter_fields:
            query = self._filter_object_fields(query, obj, filter_fields)
        if filters:
            query = self._filter_query(query, filters)
        return query

    def _filter_object_id(self, query, obj):
        for field, value in self._dumper(obj, keep_fields=self._id_fields).items():
            query = query.where(self.base.c[field] == value)
        return query

    def _filter_object_fields(self, query, obj, fields):
        for field, value in self._dumper(obj, keep_fields=fields).items():
            query = query.where(self.base.c[field] == value)
        return query

    def _filter_query(self, query, filters):
        for field, value in filters.items():
            col = self._get_column(field)
            if callable(value):
                query = query.where(value(col))
            else:
                query = query.where(col == value)
        return query

    def _get_column(self, field):
        parts = field.split('.', maxsplit=1)
        if len(parts) < 2:
            return self.base.c[field]
        else:
            related, colname = parts
            return self._related[related].get_column(colname)

    def _make_from(self):
        from_clause = self.base
        rel_mappers = {}
        for relation in self._related.values():
            rel_mappers[relation.name] = relation.mapper
            from_clause = relation.join_func(from_clause)
        return from_clause, rel_mappers

    def _add_order(self, query, order):
        order_fields = []
        for order_field in order:
            order_dir = asc
            if order_field.startswith('-'):
                order_dir = desc
                order_field = order_field[1:]
            column = self._get_column(order_field)
            order_fields.append(order_dir(column))

        return query.order_by(*order_fields)


class Filters(collections.abc.MutableMapping):
    """
    Dictlike объект для описания фильтров с методами для более легкого формирования
    условий фильтрации
    """

    def __init__(self):
        self._filters = collections.defaultdict(list)

    def add_not_none(self, field, value, expr=None):
        if value is None:
            return
        if expr is None:
            self._filters[field].append(value)
        else:
            self._filters[field].append(expr)

    def add_range(self, field: str, from_: Optional[Any] = None, to_: Optional[Any] = None) -> None:
        """
        Добавить фильтр поля даты по диапазону `from_ <= field < to_`
        """
        if from_ is not None:
            self._filters[field].append(lambda field: field >= from_)
        if to_ is not None:
            self._filters[field].append(lambda field: field < to_)

    @deprecated
    def add_time_range(self, field: str, from_: Optional[datetime] = None, to_: Optional[datetime] = None) -> None:
        self.add_range(field, from_=from_, to_=to_)

    def __getitem__(self, key):
        flt = self._filters[key]
        if len(flt) == 1:
            return flt[0]
        elif len(flt) > 1:
            return self._all_filter(flt)
        else:
            raise KeyError

    def __setitem__(self, key, value):
        self._filters[key].append(value)

    def __delitem__(self, key):
        del self._filters[key]

    def __len__(self):
        return len(self._filters)

    def __iter__(self):
        return iter(self._filters)

    def _all_filter(self, filters):
        def _inner(col):
            result = []
            for flt in filters:
                if callable(flt):
                    result.append(flt(col))
                else:
                    result.append(col == flt)
            return and_(*result)

        return _inner
