# -*- coding: utf-8 -*-
from mpfs.dao.migration.migration import MigrationMode
from mpfs.metastorage.mongo.util import manual_route as mongo_manual_route, id_for_key
from mpfs.metastorage.postgres.query_executer import manual_route as pg_manual_route

from mpfs.dao.shard_endpoint import ShardType


class BaseRoutedShardClient(object):
    """Базовый клиент к шарду, предоставляющий операции для миграции
    коллекции для указанного шарда.

    Миграторам новых таблиц наследоваться напрямую от этого класса ЗАПРЕЩЕНО.
    Использовать BaseRoutedShardClientRefactored

    Каждый клиент должен определить следующие операции:
      * _find_by_uid
      * _remove_by_uid
      * _insert_chunk
      * _count_by_uid
    """
    def __init__(self, dao, shard_endpoint, mode=MigrationMode.REAL):
        self._dao_impl = self._get_dao_impl(dao, shard_endpoint)
        self._shard_type = shard_endpoint.get_type()
        self._shard_name = shard_endpoint.get_name()
        self._routing_context = self._get_routing_context(shard_endpoint)
        self._mode = mode

    def find_by_uid(self, uid):
        """Вернуть iterable всех данных пользователя ``uid`` для
        указанного шарда.
        """
        with self._routing_context(self._shard_name):
            return self._find_by_uid(uid)

    def count_by_uid(self, uid):
        """Вернуть количество записей пользователя ``uid`` на
        указанном шарде.
        """
        with self._routing_context(self._shard_name):
            return self._count_by_uid(uid)

    def remove_by_uid(self, uid):
        """Удалить все данные для пользователя ``uid`` на указанном шарде.
        """
        with self._routing_context(self._shard_name):
            return self._remove_by_uid(uid)

    def fix_before_migration(self, uid):
        """Поправить данные на исходном шарде перед миграцией (актуально для монги).
        """
        with self._routing_context(self._shard_name):
            return self._fix_by_uid(uid)

    def prepare_for_insert(self, uid):
        pass

    def insert_chunk(self, chunk):
        """Вставить ``chunk`` записей на указанный шард.
        """
        with self._routing_context(self._shard_name):
            return self._insert_chunk(chunk)

    def _find_by_uid(self, uid):
        return self._dao_impl.find({self.uid_field_name: uid})

    def _count_by_uid(self, uid):
        result = self._dao_impl.find({self.uid_field_name: uid})
        if isinstance(result, list):
            return len(result)
        return result.count()

    def _remove_by_uid(self, uid):
        if self._mode == MigrationMode.VALIDATE:
            return
        return self._dao_impl.remove({self.uid_field_name: uid})

    def _insert_chunk(self, chunk):
        for i in chunk:
            self._dao_impl.dao_item_cls._validate_mongo_dict(i)
        if self._mode == MigrationMode.VALIDATE:
            return
        return self._dao_impl.insert(chunk)

    def _fix_by_uid(self, uid):
        pass

    @property
    def uid_field_name(self):
        return self._dao_impl.dao_item_cls.uid_field_name

    @staticmethod
    def _find_tree_by_uid(uid, collection):
        root = collection.find_one_by_field(uid, {'_id': id_for_key(uid, '/')})
        if root is not None:
            yield root
        children = collection.iter_subtree(uid, '/')
        for child in children:
            yield child

    @staticmethod
    def _get_dao_impl(dao, shard_endpoint):
        """
        :type shard_endpoint: :class:`~ShardEndpoint`
        :rtype: :class:`~mpfs.core.dao.base.BaseDAOImplementation`
        """
        if shard_endpoint.get_type() is ShardType.MONGO:
            return dao.get_mongo_impl()
        return dao.get_pg_impl()

    @staticmethod
    def _get_routing_context(shard_endpoint):
        """
        :type shard_endpoint: :class:`~ShardEndpoint`
        """
        if shard_endpoint.get_type() is ShardType.MONGO:
            return mongo_manual_route
        return pg_manual_route


class BaseRoutedShardClientRefactored(BaseRoutedShardClient):

    def _count_by_uid(self, uid):
        return self._dao_impl.count_by_uid(uid)

    def _find_by_uid(self, uid):
        return self._dao_impl.fetch_by_uid(uid)

    def _remove_by_uid(self, uid):
        if self._mode == MigrationMode.VALIDATE:
            return
        return self._dao_impl.remove_by_uid(uid)

    def _insert_chunk(self, chunk):
        if self._mode == MigrationMode.VALIDATE:
            return
        if not chunk:
            return
        self._dao_impl.bulk_insert(chunk[0].uid, chunk)
