# -*- coding: utf-8 -*-
from passport.backend.oauth.core.db.eav.query_base import (
    AutoShardedQuery,
    CentralQuery,
    Query,
)
from passport.backend.oauth.core.db.eav.schemas import auto_id
from passport.backend.oauth.core.db.eav.utils import insert_on_duplicate_update
from sqlalchemy import (
    and_,
    func,
    select,
)


class SelectChunkQuery(Query):
    """
    Получает несколько сущностей, начиная с указанного id.
    Для шардированных сущностей этот запрос нужно направить в правильный шард.
    """
    def __init__(self, table, db_name, from_id=None, limit=None, **kwargs):
        super(SelectChunkQuery, self).__init__(table=table, db_name=db_name)
        self._values = kwargs
        self._limit = limit
        self._from_id = from_id or 0

    def to_sql(self):
        ids = select([
            self._table.c.id,
        ]).where(
            self._table.c.id >= self._from_id,
        ).limit(
            self._limit,
        ).distinct().alias()
        return select([
            self._table.c.id,
            self._table.c.type,
            self._table.c.value,
        ]).select_from(
            self._table.join(ids, self._table.c.id == ids.c.id),
        ).order_by(  # order by - для правильной сборки собранных атрибутов в модель
            self._table.c.id,
        )


class SelectByIdQuery(AutoShardedQuery):
    """Получает все атрибуты сущности по её id"""
    def __init__(self, table, entity_name, entity_id):
        super(SelectByIdQuery, self).__init__(table, entity_name, entity_id)
        self._values = {
            'id': entity_id,
        }

    def to_sql(self):
        return select([self._table]).where(self._table.c.id == self.values['id'])


class SelectByIdsQuery(Query):
    """Получает все атрибуты сущностей с заданными id"""
    def __init__(self, table, db_name, entity_ids):
        super(SelectByIdsQuery, self).__init__(table=table, db_name=db_name)
        self._values = {
            'ids': entity_ids,
        }

    def to_sql(self):
        return select([self._table]).where(self._table.c.id.in_(self.values['ids']))


class SelectByIndexQuery(Query):
    """
    Получает атрибуты сущностей, найденных по указанному индексу.
    Для шардированных сущностей подобный запрос нужно направить в каждый из шардов.
    """
    def __init__(self, table, index, db_name, limit=None, offset=None, **kwargs):
        super(SelectByIndexQuery, self).__init__(table=table, db_name=db_name)
        self._index = index
        self._values = kwargs
        self._limit = limit
        self._offset = offset

    def to_sql(self):
        ids = select([
            self._index._table.c.id,
        ]).where(
            self._index.make_clause(self.values),
        ).order_by(  # order by - для правильной пагинации с помощью limit и offset
            self._index._table.c.id,
        ).limit(
            self._limit,
        ).offset(
            self._offset,
        ).alias()
        return select([
            self._table.c.id,
            self._table.c.type,
            self._table.c.value,
        ]).select_from(
            self._table.join(ids, self._table.c.id == ids.c.id),
        ).order_by(  # order by - для правильной сборки собранных атрибутов в модель
            self._table.c.id,
        )


class CountByIndexQuery(Query):
    """
    Получает количество сущностей, найденных по указанному индексу.
    Для шардированных сущностей подобный запрос нужно направить в каждый из шардов.
    """
    def __init__(self, table, index, db_name, **kwargs):
        super(CountByIndexQuery, self).__init__(table=table, db_name=db_name)
        self._index = index
        self._values = kwargs

    def to_sql(self):
        return select([func.count()]).select_from(
            self._index._table,
        ).where(
            self._index.make_clause(self.values),
        )


class SetAttributesQuery(AutoShardedQuery):
    """Для заданной entity создаёт или апдейтит атрибуты"""
    def __init__(self, table, entity_name, entity_id, attr_types_values):
        super(SetAttributesQuery, self).__init__(table, entity_name, entity_id)
        self._values = [
            {
                'id': entity_id,
                'type': attr_type,
                'value': attr_value,
            } for attr_type, attr_value in sorted(attr_types_values.items())
        ]

    def to_sql(self):
        return insert_on_duplicate_update(self._table, ['value']).values(self.values)


class DeleteAttributeQuery(AutoShardedQuery):
    """Для заданной entity удаляет атрибут"""
    def __init__(self, table, entity_name, entity_id, attr_type):
        super(DeleteAttributeQuery, self).__init__(table, entity_name, entity_id)
        self._values = {
            'id': entity_id,
            'type': attr_type,
        }

    def to_sql(self):
        return self._table.delete().where(and_(
            self._table.c.id == self.values['id'],
            self._table.c.type == self.values['type'],
        ))


class DeleteAllAttributesQuery(AutoShardedQuery):
    """Для заданной entity удаляет атрибут"""
    def __init__(self, table, entity_name, entity_id):
        super(DeleteAllAttributesQuery, self).__init__(table, entity_name, entity_id)
        self._values = {
            'id': entity_id,
        }

    def to_sql(self):
        return self._table.delete().where(self._table.c.id == self.values['id'])


class AddIndexQuery(AutoShardedQuery):
    """Добавляет entity в индекс"""
    def __init__(self, table, entity_name, entity_id, all_values):
        super(AddIndexQuery, self).__init__(table, entity_name, entity_id)
        self._values = dict(
            all_values,
            id=self._entity_id,
        )

    def to_sql(self):
        return self._table.insert().values(self.values)


class UpdateIndexQuery(AutoShardedQuery):
    """Добавляет entity в индекс"""
    def __init__(self, table, entity_name, entity_id, old_key_values, new_all_values):
        super(UpdateIndexQuery, self).__init__(table, entity_name, entity_id)
        self._old_key_values = dict(
            old_key_values,
            id=self._entity_id,
        )
        self._new_all_values = dict(
            new_all_values,
            id=self._entity_id,
        )

    def to_sql(self):
        return self._table.update().where(and_(
            self._table.c[field] == self._old_key_values[field] for field in sorted(self._old_key_values)
        )).values(self._new_all_values)


class DeleteIndexQuery(AutoShardedQuery):
    """Удаляет entity из индекса"""
    def __init__(self, table, entity_name, entity_id, key_values):
        super(DeleteIndexQuery, self).__init__(table, entity_name, entity_id)
        self._values = dict(
            key_values,
            id=self._entity_id,
        )

    def to_sql(self):
        return self._table.delete().where(and_(
            self._table.c[field] == self.values[field] for field in sorted(self.values)
        ))


class IncrementAutoIdQuery(CentralQuery):
    """Получает текущий id для новой entity заданного типа"""
    # Не удаётся сделать одним апдейтом счётчика :(
    def __init__(self, entity_name):
        super(IncrementAutoIdQuery, self).__init__(table=auto_id[entity_name])

    def to_sql(self):
        return self._table.insert().values(self.values)
