# -*- coding: utf-8 -*-
from collections import OrderedDict
from itertools import dropwhile

from django.conf import settings
from django.utils.encoding import smart_bytes
from passport.backend.oauth.core.db.eav.attributes import (
    attr_by_name,
    deserialize_attribute,
    serialize_attribute,
    VIRTUAL_ATTR_ENTITY_ID,
)
from passport.backend.oauth.core.db.eav.differ import (
    differ,
    is_in_diff,
)
from passport.backend.oauth.core.db.eav.errors import (
    AttributeNotFoundError,
    EntityNotFoundError,
)
from passport.backend.oauth.core.db.eav.query import (
    AddIndexQuery,
    CountByIndexQuery,
    DeleteAllAttributesQuery,
    DeleteAttributeQuery,
    DeleteIndexQuery,
    IncrementAutoIdQuery,
    SelectByIdQuery,
    SelectByIdsQuery,
    SelectByIndexQuery,
    SelectChunkQuery,
    SetAttributesQuery,
    UpdateIndexQuery,
)
from passport.backend.oauth.core.db.eav.sharder import (
    get_db_name,
    is_sharded,
)
from passport.backend.oauth.core.db.eav.transaction import Transaction


class EavModelDBMixin(object):
    """Чтение и запись модели в БД (eav- и индексные таблицы)"""
    _table = None
    _indexes = {}

    @classmethod
    def _parse_rows_from_db(cls, rows):
        result = OrderedDict()
        for id_, attr_type, serialized_value in rows:
            if id_ not in result:
                result[id_] = {}
            try:
                attr_name, attr_value = deserialize_attribute(cls.entity_name, id_, attr_type, serialized_value)
                result[id_][attr_name] = attr_value
            except AttributeNotFoundError:
                # В БД могут быть атрибуты, которые мы больше не хотим обрабатывать.
                # Или те, которые хотим, но ещё не умеем (например, выкладка нового
                # функционала ещё идёт). В обоих случаях падать не хотим.
                continue
        return result

    @classmethod
    def iterate_by_chunks(cls, chunk_size, last_processed_id=None, retries=None):
        """
        Итерируется по всем сущностям (в том числе, помеченным как удалённые),
        возвращает их чанками размера не более, чем chunk_size.
        Если указан last_processed_id, то листинг начинается с объекта, следующего за указанным.
        ВАЖНО: не гарантируются упорядоченность по id (может нарушаться при переходе
        между диапазонами шардирования) и точное соответствие размера чанка запрошенному.
        """
        from_id = None
        if is_sharded(cls.entity_name):
            db_names = settings.SHARD_NAMES
        else:
            db_names = [settings.CENTRAL_DB_NAME]

        if last_processed_id:
            # пропускаем уже обработанные сущности (и уже обработанные шарды)
            from_id = last_processed_id + 1
            last_processed_db_name = get_db_name(cls.entity_name, last_processed_id)
            db_names = list(dropwhile(
                lambda db_name: db_name != last_processed_db_name,
                db_names,
            ))

        for db_name in db_names:
            while True:
                query = SelectChunkQuery(
                    table=cls._table,
                    db_name=db_name,
                    from_id=from_id,
                    limit=chunk_size,
                )
                rows = query.execute(retries=retries).fetchall()
                result = cls._parse_rows_from_db(rows)
                objects = [
                    cls(entity_id=id_, **attributes)
                    for id_, attributes in result.items()
                ]
                if not objects:
                    from_id = 0
                    break  # выходим из внутреннего цикла, переходим к следующему db_name
                yield objects
                from_id = objects[-1].id + 1

    @classmethod
    def by_id(cls, entity_id, allow_deleted=False):
        """
        Возвращает объект с заданным id. Если его не существует (или он помечен как удалённый) -
        бросает исключение.
        """
        rows = SelectByIdQuery(
            table=cls._table,
            entity_name=cls.entity_name,
            entity_id=entity_id,
        ).execute().fetchall()
        result = cls._parse_rows_from_db(rows)
        if entity_id not in result:
            raise EntityNotFoundError('%s with id %d not found' % (cls.entity_name, entity_id))
        obj = cls(entity_id=entity_id, **result[entity_id])
        if not allow_deleted and getattr(obj, 'is_deleted', False):
            raise EntityNotFoundError('%s with id %d was deleted' % (cls.entity_name, entity_id))
        return obj

    @classmethod
    def by_ids(cls, entity_ids, allow_deleted=False):
        """
        Возвращает словарь (возможно, пустой) объектов с заданными id.
        :rtype: dict
        """
        # FIXME: не посылать запрос в шарды, где заведомо нет нужных сущностей
        if not entity_ids:
            return {}

        if is_sharded(cls.entity_name):
            db_names = settings.SHARD_NAMES
        else:
            db_names = [settings.CENTRAL_DB_NAME]

        queries = [
            SelectByIdsQuery(
                table=cls._table,
                db_name=db_name,
                entity_ids=entity_ids,
            )
            for db_name in db_names
        ]
        rows = sum([query.execute().fetchall() for query in queries], [])
        result = cls._parse_rows_from_db(rows)
        objects = dict(
            (id_, cls(entity_id=id_, **attributes))
            for id_, attributes in result.items()
        )
        return {
            id_: obj
            for id_, obj in objects.items()
            if allow_deleted or not getattr(obj, 'is_deleted', False)
        }

    @classmethod
    def by_index(cls, index_name, limit=None, offset=None, allow_deleted=False, **kwargs):
        """
        Возвращает список (возможно, пустой) объектов, удовлетворяющих заданным ограничениям.
        """
        if is_sharded(cls.entity_name):
            db_names = settings.SHARD_NAMES
        else:
            db_names = [settings.CENTRAL_DB_NAME]

        queries = [
            SelectByIndexQuery(
                table=cls._table,
                index=cls._indexes[index_name],
                db_name=db_name,
                limit=limit,
                offset=offset,
                **kwargs
            )
            for db_name in db_names
        ]
        rows = sum([query.execute().fetchall() for query in queries], [])
        result = cls._parse_rows_from_db(rows)
        objects = [
            cls(entity_id=id_, **attributes)
            for id_, attributes in result.items()
        ]
        return [
            obj
            for obj in objects
            if allow_deleted or not getattr(obj, 'is_deleted', False)
        ]

    @classmethod
    def count_by_index(cls, index_name, **kwargs):
        """
        Возвращает количество объектов (в том числе, удалённых),
        удовлетворяющих заданным ограничениям.
        """
        # TODO: возможно, стоит научить не учитывать удалённые объекты
        if is_sharded(cls.entity_name):
            db_names = settings.SHARD_NAMES
        else:
            db_names = [settings.CENTRAL_DB_NAME]

        queries = [
            CountByIndexQuery(
                table=cls._table,
                index=cls._indexes[index_name],
                db_name=db_name,
                **kwargs
            )
            for db_name in db_names
        ]
        return sum([query.execute().fetchone()[0] for query in queries], 0)

    def pre_save_created(self, generated_id):
        """
        Вызывается после того, как для свежесоздаваемого объекта сгенерировался id,
        но перед тем, как объект начал сериализоваться в БД.
        Может перегружаться в потомках.
        """
        self._id = generated_id

    def serialize(self, old):
        if old is None:
            # Создаётся новый объект, ему нужен id
            yield Transaction(
                queries=[
                    IncrementAutoIdQuery(self.entity_name),
                ],
                callback=lambda results: self.pre_save_created(generated_id=results[0].inserted_primary_key[0]),
            )

        diff = differ(old, self)
        queries = []

        # Обновляем атрибуты
        attr_types_values = dict(
            serialize_attribute(self.entity_name, self.id, attr_name, value)
            for attr_name, value in dict(diff['added'], **diff['changed']).items()
        )
        if attr_types_values:
            queries.append(SetAttributesQuery(
                table=self._table,
                entity_name=self.entity_name,
                entity_id=self.id,
                attr_types_values=attr_types_values,
            ))
        for attr_name in diff['removed']:
            attr_type, _ = attr_by_name(self.entity_name, attr_name)
            queries.append(DeleteAttributeQuery(
                table=self._table,
                entity_name=self.entity_name,
                entity_id=self.id,
                attr_type=attr_type,
            ))

        # Обновляем индексы
        for _, index in sorted(self._indexes.items()):
            key_field_values_changed = any(is_in_diff(field, diff) for field in index.key_fields)

            if not key_field_values_changed:
                # этот индекс по-прежнему актуален, его не трогаем
                continue

            old_key_values = None
            if old is not None:
                old_key_values = index.make_key_values_for_search(self.entity_name, self.id, old._attributes)

            new_index_values = index.make_all_key_values(self.entity_name, self.id, self._attributes)

            if old_key_values and new_index_values:
                queries.append(UpdateIndexQuery(
                    table=index.table,
                    entity_name=self.entity_name,
                    entity_id=self.id,
                    old_key_values=old_key_values,
                    new_all_values=new_index_values,
                ))
            elif old_key_values:
                queries.append(DeleteIndexQuery(
                    table=index.table,
                    entity_name=self.entity_name,
                    entity_id=self.id,
                    key_values=old_key_values,
                ))
            elif new_index_values:
                queries.append(AddIndexQuery(
                    table=index.table,
                    entity_name=self.entity_name,
                    entity_id=self.id,
                    all_values=new_index_values,
                ))

        yield Transaction(queries=queries)

    def delete(self):
        queries = []

        # Удаляем атрибуты
        queries.append(DeleteAllAttributesQuery(
            table=self._table,
            entity_name=self.entity_name,
            entity_id=self.id,
        ))

        # Удаляем индексы
        for _, index in sorted(self._indexes.items()):
            # TODO: удалять просто по id
            key_values = index.make_key_values_for_search(self.entity_name, self.id, self._attributes)
            if key_values:
                queries.append(DeleteIndexQuery(
                    table=index.table,
                    entity_name=self.entity_name,
                    entity_id=self.id,
                    key_values=key_values,
                ))

        # ID не переиспользуем - с таблицей auto_id ничего делать не надо
        yield Transaction(queries=queries)


class EavModelBlackboxMixin(object):
    @classmethod
    def _parse_attrs_from_bb(cls, entity_id, attrs_and_values):
        result = {}
        for attr_type, serialized_value in attrs_and_values.items():
            attr_type = int(attr_type)
            serialized_value = smart_bytes(serialized_value)
            try:
                attr_name, attr_value = deserialize_attribute(cls.entity_name, entity_id, attr_type, serialized_value)
                result[attr_name] = attr_value
            except AttributeNotFoundError:
                # В БД могут быть атрибуты, которые мы больше не хотим обрабатывать.
                # Или те, которые хотим, но ещё не умеем (например, выкладка нового
                # функционала ещё идёт). В обоих случаях падать не хотим.
                continue
        return result

    @classmethod
    def _parse(cls, oauth_block):
        """Возвращает словарь атрибутов"""
        raise NotImplementedError()  # pragma: no cover

    @classmethod
    def parse(cls, bb_response):
        """Парсит модель из успешного ответа метода oauth ЧЯ"""
        if 'oauth' not in bb_response:
            raise ValueError('Only method=oauth response can be parsed')
        oauth_block = bb_response['oauth']
        attrs_and_values = cls._parse(oauth_block)
        entity_id = int(attrs_and_values[str(VIRTUAL_ATTR_ENTITY_ID)])
        result = cls._parse_attrs_from_bb(entity_id=entity_id, attrs_and_values=attrs_and_values)
        return cls(entity_id=entity_id, **result)
