# -*- coding: utf-8 -*-
import itertools
import random
import copy
from abc import ABCMeta, abstractmethod
from collections import defaultdict

from pymongo import ReadPreference as MongoReadPreference
from sqlalchemy.engine import RowProxy

import mpfs.engine.process
from mpfs.common.errors import UserIsReadOnly
from mpfs.common.util import flatten_dict
from mpfs.config import settings
from mpfs.dao.cursor import PostgresCursor, PostgresCursorChain
from mpfs.dao.fields import JsonField, CompressedJsonField, DAOItemField
from mpfs.dao.query_converter import MongoQueryConverter
from mpfs.dao.session import Session
from mpfs.dao.shard_endpoint import ShardType, ShardEndpoint
from mpfs.metastorage.postgres.query_executer import ReadPreference as PGReadPreference, PGQueryExecuter
from mpfs.metastorage.mongo.util import decompress_data, manual_route, parent_for_key, name_for_key

POSTGRES_RECONNECTION_ATTEMPTS = settings.postgres['reconnection_attempts']


def get_all_shard_endpoints():
    result = []
    for shard_name in mpfs.engine.process.dbctl().mapper.rspool.get_all_shards_names():
        result.append(ShardEndpoint(ShardType.MONGO, shard_name))
    for shard_endpoint in PGQueryExecuter().get_all_shard_endpoints():
        result.append(shard_endpoint)
    return result


class BaseDAOItemMetaclass(type):
    """Метакласс класса BaseDAOItem.

    Нужен, чтобы преобразовывать аттрибуты класса, которые отнаследованы от DAOItemField, в проперти, который в свою
    очередь вызывает функцию, умеющую доставать функции из словаря монги или словаря постгреса, а также вызывать
    конвертацию даных, чтобы они были одинаковыми и для данных из монги и для данных из постгреса.

    Пример использования:
        class StringField(DAOItemField):
            ...

        class FileDAOItem(BaseDAOItem):
            path = StringField(mongo_path='key', pg_path='path')

        item = FileDAOItem.create_from_mongo_dict(mongo_data_dict)
        item.path  # это строка, какого бы типа она не была в монге или в постгресе
    """

    def __new__(mcs, name, bases, class_dict):
        dao_fields = {k: v for k, v in
                      itertools.chain(class_dict.iteritems(), *[getattr(b, '_fields', {}).iteritems() for b in bases])
                      if isinstance(v, DAOItemField)}
        class_dict['_fields'] = dao_fields
        postgres_table_obj = class_dict.get('postgres_table_obj')
        if postgres_table_obj is not None and class_dict.get('columns_map') is None:
            class_dict['columns_map'] = {c.name: c for c in postgres_table_obj.columns}
        for k, v in dao_fields.iteritems():
            class_dict[k] = mcs.get_property(v)
        return super(BaseDAOItemMetaclass, mcs).__new__(mcs, name, bases, class_dict)

    @staticmethod
    def get_property(field_obj):
        def getter(self):
            return self._common_getter(field_obj)

        def setter(self, v):
            return self._common_setter(field_obj, v)

        return property(getter, setter)


class BaseDAOItem(object):
    """Базовый класс для объекта, который возвращает класс DAO (наследник класса BaseDAO).

    Используется для описания полей, которые мы достаем из базы, и их типов. Умеет доставать данные из могового словаря,
    из постгресового словаря, вызывать конвертацию из одного представления в другое.

    Пример использования:
        class FileDAOItem(BaseDAOItem):
            path = StringField(mongo_path='key', pg_path='path')

        item = FileDAOItem.create_from_mongo_dict(mongo_data_dict)
        item.get_postgres_representation()  # сконвертирует данные из монгового представления в постгресовое
    """

    __metaclass__ = BaseDAOItemMetaclass

    is_sharded = False
    is_migrated_to_postgres = False  # флаг только для нешардированных коллекций, для шардированных ходим в usermap
    is_mongo_readonly = False  # Флаг для включения ридонли на монгоколлекциях
    uid_field_name = 'uid'  # для шардированных коллекций в каком поле искать uid в монговом спеке
    mongo_collection_name = None
    """строка - имя коллекции в монге"""
    postgres_table_obj = None
    """объект типа sqlalchemy.sql.schema.Table, описывающий таблицу в постгресе"""

    mongo_compressed_field = 'zdata'
    validation_ignored_mongo_dict_fields = ()

    exclude_keys_after_conversion_to_mongo = None  # словарь-отображение, которое содержит пары вида ключ-значение,
    # которое надо выкинуть из результата get_mongo_representation, если документ достали из постгреса (если None, то
    # вернется все). Это нужно для того, чтобы не возвращать больше полей, чем лежало в монге, скажем, всякие пустые
    # public_hash и прочее. Пример можно посмотреть в FolderDAOItem.

    _fields = None

    def __init__(self):
        self._mongo_dict = None
        self._mongo_unpacked_zdata = None
        self._pg_data = None


    @classmethod
    def check_is_mongo_readonly(cls):
        if cls.is_mongo_readonly:
            return True

        if cls.mongo_collection_name not in settings.common_pg:
            return False

        return settings.common_pg[cls.mongo_collection_name].get('mongo_ro', False)


    @classmethod
    def get_field_mongo_representation(cls, field_name, value):
        return cls._fields[field_name].to_mongo(value)

    @classmethod
    def get_field_pg_representation(cls, field_name, value):
        return cls._fields[field_name].to_postgres(value)

    @classmethod
    def create_from_mongo_dict(cls, data, validate=False):
        if validate:
            cls._validate_mongo_dict(data)

        instance = cls()
        instance._mongo_dict = data
        return instance

    @classmethod
    def create_from_pg_data(cls, data):
        if isinstance(data, RowProxy):
            pg_data = cls._convert_row_proxy_to_pg_data(data)
        else:
            pg_data = data

        instance = cls()
        instance._pg_data = pg_data
        return instance

    @classmethod
    def create_from_raw_pg_dict(cls, pg_dict):
        pg_data = {}
        for column_name, column in cls.columns_map.iteritems():
            if column_name not in pg_dict:
                continue
            pg_data[column] = pg_dict[column_name]

        instance = cls()
        instance._pg_data = pg_data
        return instance

    def as_raw_pg_dict(self):
        return {c.name: v for c, v in self.get_postgres_representation().iteritems()}

    def get_mongo_representation(self, skip_missing_fields=False):
        if self._mongo_dict:
            return self._mongo_dict
        else:
            mongo_dict = {}
            common_values = self._common_bulk_getter()

            for field_name, field_obj in self._fields.iteritems():
                try:
                    dao_field_common_value = common_values[field_name]
                except KeyError:
                    if skip_missing_fields:
                        continue
                    raise
                except Exception as e:
                    raise type(e)(e.message + ' (field \'%s\')' % field_name)

                key_parts = field_obj.mongo_path.split('.')
                path_to_dict, dict_key = key_parts[:-1], key_parts[-1]

                cur_dict = mongo_dict

                for p in path_to_dict:
                    if p not in cur_dict:
                        cur_dict[p] = {}
                    cur_dict = cur_dict[p]

                if field_obj.default_value is None and dao_field_common_value == field_obj.default_value:
                    mongo_value = dao_field_common_value
                else:
                    mongo_value = field_obj.to_mongo(dao_field_common_value)
                if field_obj.mongo_item_parser:
                    mongo_value = field_obj.mongo_item_parser.format(mongo_value)

                if isinstance(cur_dict.get(dict_key), list) and isinstance(mongo_value, list):
                    # костыль специально для поддержки хранения стидов в монге в виде списка, будь он неладен
                    cur_dict[dict_key].extend(mongo_value)
                else:
                    cur_dict[dict_key] = mongo_value

            if self.exclude_keys_after_conversion_to_mongo:
                mongo_dict = self._exclude_values(mongo_dict, self.exclude_keys_after_conversion_to_mongo)

            return mongo_dict

    def get_postgres_representation(self, skip_missing_fields=False):
        if self._pg_data:
            return self._pg_data
        else:
            pg_data = {}
            common_values = self._common_bulk_getter()

            for field_name, field_obj in self._fields.iteritems():
                try:
                    dao_field_common_value = common_values[field_name]
                except KeyError:
                    if skip_missing_fields:
                        continue
                    raise
                except Exception as e:
                    raise type(e)(e.message + ' (field \'%s\')' % field_name)

                if field_obj.pg_path is not None:
                    if field_obj.default_value is None and dao_field_common_value == field_obj.default_value:
                        pg_value = dao_field_common_value
                    else:
                        pg_value = field_obj.to_postgres(dao_field_common_value)
                    pg_data[field_obj.pg_path] = pg_value
            return pg_data

    @classmethod
    def convert_postgres_value_to_mongo_for_coll(cls, coll, value):
        for field_name, field_obj in cls._fields.iteritems():
            if field_obj.pg_path.name == coll.name:
                if value is None:
                    return field_obj.mongo_path, None
                return field_obj.mongo_path, field_obj.to_mongo(field_obj.from_postgres(value))
        raise LookupError()

    @classmethod
    def convert_mongo_value_to_postgres_for_key(cls, key, value):
        for field_name, field_obj in cls._fields.iteritems():
            if field_obj.mongo_path == key:
                if value is None:
                    return field_obj.pg_path.name, None
                return field_obj.pg_path.name, field_obj.to_postgres(field_obj.from_mongo(value))
        raise LookupError()

    @classmethod
    def convert_mongo_key_to_postgres(cls, key):
        for field_name, field_obj in cls._fields.iteritems():
            if field_obj.mongo_path == key:
                if isinstance(field_obj.pg_path, str):
                    # кастомный случай, как для path в FileDAOItem - там мы не храним path в базе
                    return field_obj.pg_path
                return field_obj.pg_path.name
        raise LookupError()

    @classmethod
    def convert_from_postgres(cls, value_name, value):
        for field_name, field_obj in cls._fields.iteritems():
            if field_name != value_name:
                continue
            return field_obj.from_postgres(value)

    @classmethod
    def convert_from_mongo(cls, value_name, value):
        for field_name, field_obj in cls._fields.iteritems():
            if field_name != value_name:
                continue
            return field_obj.from_mongo(value)

    @classmethod
    def get_postgres_primary_key(cls):
        table = cls.postgres_table_obj
        if table.primary_key:
            return list(table.primary_key.columns)[0].name
        return None

    def copy(self):
        new_dao_item = self.__class__()
        if self._mongo_dict:
            new_dao_item._mongo_dict = copy.deepcopy(self._mongo_dict)
        elif self._pg_data:
            pg_data = {}
            for k, v in self._pg_data.iteritems():
                # buffer (bytea) поля не копируются deepcopy
                if not isinstance(v, buffer):
                    v = copy.deepcopy(v)
                pg_data[k] = v
            new_dao_item._pg_data = pg_data
        else:
            raise NotImplementedError()
        return new_dao_item

    @classmethod
    def _validate_mongo_dict(cls, mongo_dict):
        mongo_dict = copy.deepcopy(mongo_dict)

        if cls.mongo_compressed_field in mongo_dict:
            mongo_dict[cls.mongo_compressed_field] = decompress_data(mongo_dict[cls.mongo_compressed_field])

        flat_dict = flatten_dict(mongo_dict)
        mongo_paths = cls._fetch_paths()
        mongo_json_paths = cls._fetch_json_paths()

        for k, v in flat_dict.iteritems():
            if k in cls.validation_ignored_mongo_dict_fields:
                continue

            if k in mongo_paths:
                field = cls._get_field_by_mongo_path(k)
                if field.mongo_item_parser is not None:
                    v = field.mongo_item_parser.parse(v)
                field.validate_mongo_value(v)
                continue

            parts = k.split('.')
            k_parents = ['.'.join(parts[:e]) for e in xrange(1, len(parts) + 1)]
            if any(p in mongo_json_paths or p in cls.validation_ignored_mongo_dict_fields for p in k_parents):
                continue

            raise ValueError('Key "%s" not found in dao item class "%s"' % (k, cls.__name__))

    @classmethod
    def _fetch_paths(cls):
        paths = []
        for field_name, field_obj in cls._fields.iteritems():
            paths.append(field_obj.mongo_path)
        return paths

    @classmethod
    def _fetch_json_paths(cls):
        paths = []
        for field_name, field_obj in cls._fields.iteritems():
            if isinstance(field_obj, (JsonField, CompressedJsonField)):
                paths.append(field_obj.mongo_path)
        return paths

    @classmethod
    def _get_field_by_mongo_path(cls, mongo_path):
        for _, field_obj in cls._fields.iteritems():
            if field_obj.mongo_path == mongo_path:
                return field_obj
        raise LookupError('Field `%s` not found in class `%s`' % (mongo_path, cls.__name__))

    def _common_bulk_getter(self):
        """
        Функция написана так и не разбита на подфункции с целью оптимизации. Если ты разобьешь ее на читаемый код, то
        замедлятся все преобразования из монги в постгрес и обратно.
        """
        if not self._mongo_dict and not self._pg_data:
            raise ValueError('DAO item getter is called before item was initialized. Use '
                             'create_from_mongo_dict/create_from_pg_data for initialization.')
        if self._mongo_dict:
            common_data = {}
            for field_name, field_obj in self._fields.iteritems():
                try:
                    raw_value = self._mongo_getter(field_obj.mongo_path)
                    if field_obj.mongo_item_parser:
                        raw_value = field_obj.mongo_item_parser.parse(raw_value)

                    if field_obj.default_value is not DAOItemField.not_specified_default and \
                            type(raw_value) is type(field_obj.default_value) and \
                            raw_value == field_obj.default_value:
                        common_data[field_name] = raw_value
                    else:
                        common_data[field_name] = field_obj.from_mongo(raw_value)
                except KeyError:
                    if isinstance(field_obj.default_value, DAOItemField.NotSpecifiedDefault):
                        continue
                    common_data[field_name] = field_obj.default_value
                except Exception as e:
                    raise type(e)(e.message + ' (field \'%s\')' % field_name)
            return common_data
        elif self._pg_data:
            common_data = {}
            for field_name, field_obj in self._fields.iteritems():
                pg_field_name = field_obj.pg_path
                if pg_field_name not in self._pg_data:
                    continue

                raw_value = self._pg_data[pg_field_name]
                if raw_value is None:
                    common_data[field_name] = field_obj.default_value
                else:
                    try:
                        common_data[field_name] = field_obj.from_postgres(raw_value)
                    except Exception as e:
                        raise type(e)(e.message + ' (field \'%s\')' % field_name)
            return common_data

    def _common_getter(self, field_obj):
        if self._mongo_dict:
            try:
                raw_value = self._mongo_getter(field_obj.mongo_path)
                if field_obj.mongo_item_parser:
                    raw_value = field_obj.mongo_item_parser.parse(raw_value)

                if field_obj.default_value is not DAOItemField.not_specified_default and \
                        type(raw_value) is type(field_obj.default_value) and \
                        raw_value == field_obj.default_value:
                    return raw_value

                return field_obj.from_mongo(raw_value)
            except KeyError:
                if isinstance(field_obj.default_value, DAOItemField.NotSpecifiedDefault):
                    raise
                return field_obj.default_value
        elif self._pg_data:
            pg_field_name = field_obj.pg_path
            raw_value = self._pg_data[pg_field_name]
            if raw_value is None:
                return field_obj.default_value
            return field_obj.from_postgres(raw_value)
        else:
            raise ValueError('DAO item getter is called before item was initialized. Use '
                             'create_from_mongo_dict/create_from_pg_data for initialization.')

    def _common_setter(self, field_obj, field_value):
        if field_obj.default_value is not None and field_value is None:
            raise ValueError("Can't set 'None' to field with specified default_value. Field: %r" % field_obj)

        if self._pg_data is not None:
            converter = field_obj.to_postgres
            is_pg = True
        elif self._mongo_dict is not None:
            converter = field_obj.to_mongo
            is_pg = False
        else:
            self._pg_data = {}
            converter = field_obj.to_postgres
            is_pg = True

        if field_obj.default_value is None and field_value == field_obj.default_value:
            raw_value = field_value
        else:
            raw_value = converter(field_value)

        if is_pg:
            self._pg_data[field_obj.pg_path] = raw_value
        else:
            self._mongo_setter(field_obj.mongo_path, raw_value)

    def _mongo_getter(self, key):
        key_parts = key.split('.')
        current_item, key_parts = self._get_unpacked_item(key_parts)

        for k in key_parts:
            current_item = current_item[k]
        return current_item

    def _mongo_setter(self, key, value):
        key_parts = key.split('.')
        current_item, key_parts = self._get_unpacked_item(key_parts)
        last_key = key_parts[-1]
        key_parts = key_parts[:-1]

        for k in key_parts:
            try:
                current_item = current_item[k]
            except KeyError:
                current_item[k] = {}
                current_item = current_item[k]
        current_item[last_key] = value

    def _get_unpacked_item(self, key_parts):
        if len(key_parts) > 1 and key_parts[0] == self.mongo_compressed_field:
            # если запрашиваем item из zdata, то предварительно распаковываем ее (если еще не распаковали)
            # и сохраняем распакованные данные в отдельную переменную (чтобы не записать ее в базу потом обратно)
            if self._mongo_unpacked_zdata is None:
                self._mongo_unpacked_zdata = decompress_data(self._mongo_dict[self.mongo_compressed_field])
            return self._mongo_unpacked_zdata, key_parts[1:]
        return self._mongo_dict, key_parts

    def _exclude_values(self, mongo_dict, exclude_keys_for_value):
        result = {}
        for key, value in exclude_keys_for_value.iteritems():
            if key in mongo_dict and isinstance(value, dict):
                new_value = self._exclude_values(mongo_dict[key], value)
                if new_value:  # пустые словари нам не нужны
                    result[key] = new_value
            elif key in mongo_dict and mongo_dict[key] != value:
                result[key] = mongo_dict[key]
        for key in mongo_dict.viewkeys() - exclude_keys_for_value.viewkeys():
            result[key] = mongo_dict[key]
        return result

    @classmethod
    def _convert_row_proxy_to_pg_data(cls, raw_query):
        result = {}
        for key, value in raw_query.items():
            # игнорируем поля из PG, которых нет в DaoItem
            try:
                result[cls.columns_map[key]] = value
            except KeyError:
                pass
        return result

    def __repr__(self):
        result = ''
        for key in self._fields.iterkeys():
            try:
                value = getattr(self, key)
            except KeyError:
                value = '<NOT SET>'
            result += "%s=%r, " % (key, value)
        if result:
            result = result[:-2]
        return "<%s(%s)>" % (self.__class__.__name__, result)


class MongoHelper(object):
    """Вспомогательный класс для работы с монговыми коллекциями из классов-наследников от BaseDAO.

    Единственный способ для работы с монговой базой. Обертка, чтобы ничего не звалось напрямую.
    Содержит в себе инстанс монговой базы данных, умеет доставать из нее коллекции.
    """
    storage_name_to_collection_name = {
        'additional': 'additional_data',
        'disk': 'user_data',
        'trash': 'trash',
        'hidden': 'hidden_data',
        'notes': 'notes_data',
        'photounlim': 'photounlim_data',
        'attach': 'attach_data',
        'narod': 'narod_data',
        'client': 'client_data',
    }

    def __init__(self):
        self._dbctl = mpfs.engine.process.dbctl()
        self._db = self._dbctl.database()

    def get_collection(self, name, shard_name=None):
        if shard_name is not None:
            return self._dbctl.mapper.get_collection_for_rs(name, shard_name)
        return self._db[name]

    def get_routed_collection_for_uid(self, uid, coll_name):
        mapper = self._dbctl.mapper
        rs_name = mapper.get_manual_route_shard_name()
        if rs_name is None:
            rs_name = mapper.get_rsname_for_uid(uid)
        return mapper.get_collection_for_rs(coll_name, rs_name)

    def get_routed_collection_for_shard(self, shard_name, coll_name):
        return self._dbctl.mapper.get_collection_for_rs(coll_name, shard_name)

    def get_routed_collection_for_shard_endpoint(self, shard_endpoint, coll_name):
        if not shard_endpoint.is_mongo():
            raise ValueError('Except mongo shard endpoint. Got %s' % shard_endpoint)
        return self.get_routed_collection_for_shard(shard_endpoint.get_name(), coll_name)

    def get_collection_by_address(self, uid, address):
        return self.get_routed_collection_for_uid(uid, self.storage_name_to_collection_name[address.storage_name])

    def iter_over_all_shards(self, coll_name):
        rs_names = self._dbctl.mapper.rspool.get_all_shards_names()
        # более равномерная нагрузка на шарды
        random.shuffle(rs_names)
        for rs_name in rs_names:
            yield self._dbctl.mapper.get_collection_for_rs(coll_name, rs_name)


class QueryWithParams(object):
    def __init__(self, query, params):
        self.query = query
        self.params = params


def convert_mongo_read_preference(read_preference):
    if read_preference == MongoReadPreference.PRIMARY:
        return PGReadPreference.primary
    elif read_preference == MongoReadPreference.PRIMARY_PREFERRED:
        return PGReadPreference.primary_preferred
    elif read_preference == MongoReadPreference.SECONDARY:
        return PGReadPreference.secondary
    elif read_preference == MongoReadPreference.SECONDARY_PREFERRED:
        return PGReadPreference.secondary_preferred
    return PGReadPreference.primary


class BaseDAOImplementation(object):
    __metaclass__ = ABCMeta

    @abstractmethod
    def find(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        raise NotImplementedError()

    def find_one(self, spec=None, fields=None, **kwargs):
        cursor = self.find(spec, fields, **kwargs)
        return next(iter(cursor), None)

    @abstractmethod
    def find_on_shard(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        raise NotImplementedError()

    @abstractmethod
    def insert(self, doc_or_docs, manipulate=True, continue_on_error=False, **kwargs):
        raise NotImplementedError()

    @abstractmethod
    def remove(self, spec_or_id=None, multi=True, **kwargs):
        raise NotImplementedError()

    @abstractmethod
    def update(self, spec, document, upsert=False, multi=False, **kwargs):
        raise NotImplementedError()

    def find_one_on_shard(self, spec=None, fields=None, shard_name=None, **kwargs):
        cursor = self.find_on_shard(spec, fields, shard_name=shard_name, **kwargs)
        return next(iter(cursor), None)

    def get_read_preference(self, kwargs):
        read_preference = None
        if 'read_preference' in kwargs:
            read_preference = kwargs['read_preference']
        else:
            global_read_preference = mpfs.engine.process.get_read_preference()
            if global_read_preference is not None:
                read_preference = global_read_preference
        return read_preference


def check_collection_readonly(func):
    def wrapped(dao, *args, **kwargs):
        if isinstance(dao, MongoBaseDAOImplementation) and dao.dao_item_cls.check_is_mongo_readonly():
            raise UserIsReadOnly('Collection %s is in readonly mode' % dao.dao_item_cls.mongo_collection_name)
        return func(dao, *args, **kwargs)

    return wrapped


class MongoBaseDAOImplementation(BaseDAOImplementation):
    def __init__(self, dao_item_cls):
        self._mongo_helper = MongoHelper()
        self.dao_item_cls = dao_item_cls

    def doc_to_item(self, mongo_doc):
        """Преобразовать один документ в dao item.

        :rtype: (None|DAOItem)
        """
        if mongo_doc:
            return self.dao_item_cls.create_from_mongo_dict(mongo_doc)
        return None

    def get_collection_by_shard_name(self, shard_name):
        """Получить монго-коллекцию по имени шарда

        В новом коде лучше использовать get_collection_by_shard_endpoint
        """
        return self._mongo_helper.get_routed_collection_for_shard(
            shard_name,
            self.dao_item_cls.mongo_collection_name
        )

    def get_collection_by_shard_endpoint(self, shard_endpoint):
        """Получить монго-коллекцию по ShardEndpoint"""
        return self._mongo_helper.get_routed_collection_for_shard_endpoint(
            shard_endpoint,
            self.dao_item_cls.mongo_collection_name
        )

    def get_collection_by_uid(self, uid):
        """Получить монго-коллекцию для шарда, на котором живет uid"""
        return self._mongo_helper.get_routed_collection_for_uid(
            uid,
            self.dao_item_cls.mongo_collection_name
        )

    def get_field_repr(self, field_name, field_value):
        return self.dao_item_cls.get_field_mongo_representation(field_name, field_value)

    def find(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        coll = self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name)
        return self._find_in_routed_collection(coll, spec, fields, skip, limit, sort, **kwargs)

    def find_and_modify(self, spec=None, update=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        coll = self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name)
        return self._find_and_modify_in_routed_collection(coll, spec, update, fields, skip, limit, sort, **kwargs)

    @check_collection_readonly
    def insert(self, doc_or_docs, manipulate=True, continue_on_error=False, **kwargs):
        kwargs.update({
            'manipulate': manipulate,
            'continue_on_error': continue_on_error
        })
        return self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name).insert(doc_or_docs, **kwargs)

    @check_collection_readonly
    def remove(self, spec_or_id=None, multi=True, **kwargs):
        kwargs.update({
            'multi': multi
        })
        return self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name).remove(spec_or_id, **kwargs)

    @check_collection_readonly
    def update(self, spec, document, upsert=False, multi=False, **kwargs):
        kwargs.update({
            'document': document,
            'upsert': upsert,
            'multi': multi
        })
        return self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name).update(spec, **kwargs)

    def find_on_shard(self, spec=None, fields=None, skip=0, limit=0, sort=None, shard_name=None, **kwargs):
        if not shard_name:
            raise ValueError('Shard name is not specified')
        coll = self._mongo_helper.get_collection(self.dao_item_cls.mongo_collection_name, shard_name)
        return self._find_in_routed_collection(coll, spec, fields, skip, limit, sort, **kwargs)

    @staticmethod
    def _find_in_routed_collection(collection, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        kwargs.update({
            'fields': fields,
            'skip': skip,
            'limit': limit,
            'sort': sort,
        })
        return collection.find(spec, **kwargs)

    @staticmethod
    def _find_and_modify_in_routed_collection(collection, spec=None, update=None, fields=None, skip=0, limit=0,
                                              sort=None, **kwargs):
        kwargs.update({
            'fields': fields,
            'skip': skip,
            'limit': limit,
            'sort': sort,
        })
        return collection.find_and_modify(spec, update, **kwargs)

    def count(self):
        return self.find().count()


class PostgresBaseDAOImplementation(BaseDAOImplementation):
    def __init__(self, dao_item_cls, session=None):
        self.dao_item_cls = dao_item_cls
        self.session = session

    def get_session(self, uid):
        session = self.session
        if session is None:
            session = Session.create_from_uid(uid)
        return session

    def get_field_repr(self, field_name, field_value):
        return self.dao_item_cls.get_field_pg_representation(field_name, field_value)

    def fetch_one_item(self, result_proxy):
        """Получить один dao item.

        :rtype: (None|DAOItem)
        """
        doc = result_proxy.fetchone()
        return self.doc_to_item(doc)

    def doc_to_item(self, pg_doc):
        """Преобразовать один документ в dao item.

        :rtype: (None|DAOItem)
        """
        if pg_doc:
            return self.dao_item_cls.create_from_pg_data(pg_doc)
        return None

    def find(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        query, params = MongoQueryConverter(self.dao_item_cls).find_to_sql(spec=spec, fields=fields, skip=skip,
                                                                           limit=limit, sort=sort)
        read_preference = self.get_read_preference(kwargs)
        if read_preference is None:
            session = self._get_session(spec)
        else:
            session = self._get_session(spec, read_preference=convert_mongo_read_preference(read_preference))

        return PostgresCursor(session, QueryWithParams(query, params), self.dao_item_cls)

    def count(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        return self.find(spec=spec, field=fields, skip=skip, limit=limit, sort=sort, **kwargs).count()

    def find_on_shard(self, spec=None, fields=None, skip=0, limit=0, sort=None, shard_name=None, **kwargs):
        if not shard_name:
            raise ValueError('Shard name is not specified')
        query, params = MongoQueryConverter(self.dao_item_cls).find_to_sql(spec=spec, fields=fields, skip=skip,
                                                                           limit=limit, sort=sort)

        read_preference = self.get_read_preference(kwargs)
        if read_preference is None:
            read_preference = PGReadPreference.primary_preferred
        else:
            read_preference = convert_mongo_read_preference(read_preference)

        try:
            uid = int(spec['uid'])
        except Exception:
            session = Session.create_from_shard_id(shard_name, read_preference=read_preference)
        else:
            session = Session.create_from_shard_id(shard_name, ucache_hint_uid=uid, read_preference=read_preference)

        return PostgresCursor(session, QueryWithParams(query, params), self.dao_item_cls)

    def count(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        return self.find(spec=spec, field=fields, skip=skip, limit=limit, sort=sort, **kwargs).count()

    def insert(self, doc_or_docs, manipulate=True, continue_on_error=False, **kwargs):
        if not isinstance(doc_or_docs, list):
            documents = [doc_or_docs]
        else:
            documents = doc_or_docs

        if not self.dao_item_cls.is_sharded:
            docs_enumerated = enumerate(documents)
            session = self._get_session()
            ids_enumerated = self._insert(session, continue_on_error, docs_enumerated, manipulate)
        else:
            # тут в разных документах могут быть вставки на разные шарды - группируем по uid'ам
            docs_enumerated = enumerate(documents)
            uid_to_docs = self._group_docs_enumerated_by_uid(docs_enumerated)

            ids_enumerated = []
            for uid, docs in uid_to_docs.iteritems():
                session = self._get_session(uid)
                ids_enumerated += self._insert(session, continue_on_error, docs, manipulate)

        if manipulate:
            if not isinstance(doc_or_docs, list):
                return ids_enumerated[0][1]
            else:
                ids_enumerated.sort()
                return [item[1] for item in ids_enumerated]
        else:
            return None

    def remove(self, spec_or_id=None, multi=True, **kwargs):
        session = self._get_session(spec_or_id)

        query, params = MongoQueryConverter(self.dao_item_cls).remove_to_sql(spec_or_id, multi=multi)

        with session.begin():
            result = session.execute(query, params)
            return {'n': result.rowcount}  # эмулируем ответ от монги (там больше значений, но мы их не используем)

    def _pg_update_result_to_mongo(self, pg_result, upsert=False):
        rows_updated = pg_result.rowcount
        if upsert:
            if rows_updated:
                return {'n': rows_updated, 'updatedExisting': True}
            else:
                return {'n': 1, 'updatedExisting': False}
        else:
            return {'n': rows_updated, 'updatedExisting': rows_updated > 0}

    def update(self, spec, document, upsert=False, multi=False, **kwargs):
        session = self._get_session(spec)

        query, params = MongoQueryConverter(self.dao_item_cls).update_to_sql(spec, document, upsert, multi)

        with session.begin():
            result = session.execute(query, params)
            return self._pg_update_result_to_mongo(result, upsert)

    def _insert(self, session, continue_on_error, docs_enumerated, manipulate):
        # тут все документы разбиваем на две пачки - те, у которых есть _id и те, которым надо их сгенерировать
        docs_with_id_enumerated, docs_without_id_enumerated = self._group_docs_by_having_id_enumerated(docs_enumerated)

        docs_with_id = [d[1] for d in docs_with_id_enumerated]
        docs_without_id = [d[1] for d in docs_without_id_enumerated]

        ids_enumerated = []
        if docs_with_id:
            query, params = MongoQueryConverter(self.dao_item_cls).insert_to_sql(
                docs_with_id,
                continue_on_error,
                manipulate=True,
                contains_id_field=True
            )
            with session.begin():
                cursor = session.execute(query, params)
                if manipulate:
                    pk_name = self.dao_item_cls.get_postgres_primary_key()
                    if pk_name is None:
                        pk_name = '_id'
                    coll = self.dao_item_cls.postgres_table_obj.c[pk_name]
                    for num, row in enumerate(cursor):
                        _, id_ = self.dao_item_cls.convert_postgres_value_to_mongo_for_coll(coll, row[pk_name])
                        index = docs_with_id_enumerated[num][0]
                        ids_enumerated.append((index, id_))

        if docs_without_id:
            query, params = MongoQueryConverter(self.dao_item_cls).insert_to_sql(
                docs_without_id,
                continue_on_error,
                manipulate=True,
                contains_id_field=False
            )
            with session.begin():
                cursor = session.execute(query, params)
                if manipulate:
                    pk_name = self.dao_item_cls.get_postgres_primary_key()
                    if pk_name is None:
                        pk_name = '_id'
                    coll = self.dao_item_cls.postgres_table_obj.c[pk_name]
                    for num, row in enumerate(cursor):
                        _, id_ = self.dao_item_cls.convert_postgres_value_to_mongo_for_coll(coll, row[pk_name])
                        index = docs_without_id_enumerated[num][0]
                        ids_enumerated.append((index, id_))

        return ids_enumerated

    def _group_docs_enumerated_by_uid(self, docs_enumerated):
        """На вход получаем список документов, возвращаем map вида [(uid, [(num, doc)])]
        """
        uid_to_docs_enumerated = defaultdict(list)
        for num, doc in docs_enumerated:
            uid = doc[self.dao_item_cls.uid_field_name]
            uid_to_docs_enumerated[uid].append((num, doc))
        return uid_to_docs_enumerated

    @staticmethod
    def _group_docs_by_having_id_enumerated(docs_enumerated):
        """На вход получаем список пар вида <num, doc>, на выход выдаем два списка.
         Первый - список документов, содержащих _id, второй - нет. При этом элемент списка состоит из пары вида
         <num, doc>, где num - его индекс в изначальном списке. Нужно для склеивания результата.
        """
        docs_with_id = []
        docs_without_id = []

        for num, doc in docs_enumerated:
            if '_id' in doc:
                docs_with_id.append((num, doc))
            else:
                docs_without_id.append((num, doc))

        return docs_with_id, docs_without_id

    def _get_session(self, spec_or_uid=None, read_preference=PGReadPreference.primary):
        if not self.dao_item_cls.is_sharded:
            session = Session.create_common_shard(read_preference=read_preference)
        else:
            if isinstance(spec_or_uid, dict):
                assert self.dao_item_cls.uid_field_name in spec_or_uid
                uid = spec_or_uid[self.dao_item_cls.uid_field_name]
            else:
                uid = spec_or_uid
            session = Session.create_from_uid(uid, read_preference=read_preference)
        return session


class BaseDAO(object):
    """
    Базовый класс для DAO объектов.
    Нужен для обеспечения доступа к БД посредством функций find/update/remove/insert.

    Наследник должен определять аттрибут dao_item_cls, в котором необходимо указать объект класса-наследника от
    BaseDAOItem.

    :type dao_item_cls: BaseDAOItem
    """

    dao_item_cls = None
    allowed_write_options = {'w', 'fsync'}

    def __init__(self, session=None):
        self._mongo_impl = MongoBaseDAOImplementation(self.dao_item_cls)
        self._pg_impl = PostgresBaseDAOImplementation(self.dao_item_cls)
        self.session = session

    def get_session(self, uid):
        session = self.session
        if session is None:
            session = Session.create_from_uid(uid)
        return session

    @classmethod
    def _check_write_options(cls, options):
        invalid_options = options.viewkeys() - cls.allowed_write_options
        if invalid_options:
            raise TypeError('%s are an invalid keyword arguments for this function' % list(invalid_options))

    def _is_migrated_to_postgres(self):
        if self.dao_item_cls.is_migrated_to_postgres:
            return True

        if self.dao_item_cls.mongo_collection_name not in settings.common_pg:
            return False

        return settings.common_pg[self.dao_item_cls.mongo_collection_name].get('use_pg', False)

    def is_readonly(self):
        if self._is_migrated_to_postgres():
            return False

        return self.dao_item_cls.check_is_mongo_readonly()

    def check_readonly(self):
        if self.is_readonly():
            raise UserIsReadOnly('%s is in readonly mode' % self.dao_item_cls.mongo_collection_name)

    def _get_impl(self, item):
        if self.dao_item_cls.is_sharded:
            if isinstance(item, dict):
                uid = item[self.dao_item_cls.uid_field_name]
                if isinstance(uid, dict) and len(uid) == 1 and uid.keys()[0] == '$in':
                    uids = set(uid['$in'])
                else:
                    uids = {uid}
            elif isinstance(item, list):
                uids = {i[self.dao_item_cls.uid_field_name] for i in item}
            elif isinstance(item, basestring):
                uids = {str(item)}
            else:
                raise RuntimeError('could not get implementation for item "%s"' % item)

            # тут будет хитрая проверка: если все uid'ы из списка в монге или в постгресе, то достаем правильную
            # имплементацию и зовем на ней вставку, иначе - райзим исключение, что не поддерживаем - пока не надо
            # если ты читаешь этот код и понимаешь, что поддержка нужна - добавь ее
            coll_in_postgres_by_uid = [
                mpfs.engine.process.usrctl().is_collection_in_postgres(uid, self.dao_item_cls.mongo_collection_name)
                for uid in uids
            ]
            if all(coll_in_postgres_by_uid):
                return self._pg_impl
            elif all([i is False for i in coll_in_postgres_by_uid]):
                return self._mongo_impl

            raise RuntimeError('item "%s" contains users from mongo and postgres, we do not support it' % item)
        else:
            if self._is_migrated_to_postgres():
                return self._pg_impl
            return self._mongo_impl

    def _get_impl_by_shard_endpoint(self, shard_endpoint):
        if shard_endpoint.is_pg():
            return self._pg_impl
        return self._mongo_impl

    def _get_impl_by_shard(self, shard_name):
        # TODO: грустный способ определять постгресовый шард это или нет по имени, надо бы поправить
        # Используй _get_impl_by_shard_endpoint
        is_postgres_shard = not shard_name.startswith('disk')
        if is_postgres_shard:
            return self._pg_impl
        return self._mongo_impl

    def _get_impl_by_uid(self, uid):
        if mpfs.engine.process.usrctl().is_user_in_postgres(uid):
            return self._pg_impl
        return self._mongo_impl

    def get_pg_impl(self):
        return self._pg_impl

    def get_mongo_impl(self):
        return self._mongo_impl

    def find(self, spec=None, fields=None, skip=0, limit=0, sort=None, **kwargs):
        impl = self._get_impl(spec)
        return impl.find(spec, fields, skip, limit, sort, **kwargs)

    def find_one(self, spec=None, fields=None, **kwargs):
        impl = self._get_impl(spec)
        return impl.find_one(spec, fields, **kwargs)

    def find_on_shard(self, spec=None, fields=None, skip=0, limit=0, sort=None, shard_name=None, **kwargs):
        impl = self._get_impl_by_shard(shard_name)
        return impl.find_on_shard(spec, fields, skip, limit, sort, shard_name, **kwargs)

    def insert(self, doc_or_docs, manipulate=True, continue_on_error=False, **kwargs):
        self._check_write_options(kwargs)
        impl = self._get_impl(doc_or_docs)
        return impl.insert(doc_or_docs, manipulate, continue_on_error, **kwargs)

    def remove(self, spec_or_id=None, multi=True, **kwargs):
        self._check_write_options(kwargs)
        impl = self._get_impl(spec_or_id)
        return impl.remove(spec_or_id, multi, **kwargs)

    def update(self, spec, document, upsert=False, multi=False, **kwargs):
        self._check_write_options(kwargs)
        impl = self._get_impl(spec)
        return impl.update(spec, document, upsert, multi, **kwargs)

    def find_one_on_shard(self, spec=None, fields=None, shard_name=None, **kwargs):
        cursor = self.find_on_shard(spec, fields, shard_name=shard_name, limit=1, **kwargs)
        return next(iter(cursor), None)


class FakeColumn(object):
    """
    Класс для имитации sqlalchemy.sql.schema.Column
    Нужен для таких случаев, когда мы не храним какое-то поле в базе, но хотим его вернуть. Например, это поле "path" и
    таблица "folders". Мы возвращаем его в запросах к получению папки, но реально в таблице его нет.
    """

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        if isinstance(other, FakeColumn):
            return self.name == other.name
        elif isinstance(other, basestring):
            return self.name == other
        return False

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return 'FakeColumn(%s)' % self.name


class ValuesTemplateGenerator(object):
    """
    Генератор переменных и параметров для SQL запросов

    Использование:
    >>> uid = 1111
    >>> values = [{'name': '1.txt', 'uid': 12345}, {'name': '2.txt', 'uid': 54321},]
    >>> template_generator = ValuesTemplateGenerator(('name', 'uid', ))
    >>> values_template = template_generator.get_values_tmpl(len(values))
    >>> values_for_template = template_generator.get_values_for_tmpl(values)
    >>> query = 'SELECT * FROM (VALUES %(values)s) AS t (name, uid)' % values_template
    >>> session = Session.create_from_uid(uid)
    >>> result_proxy = session.execute(query, values_for_template)
    """

    def __init__(self, columns, expected_values_count=None):
        self._column_names = columns
        self._expected_values_count = expected_values_count
        self._cached_values_template = None

        # предполагается, что шаблон чаще всего будет формироваться определенного размера
        if expected_values_count:
            self._cached_values_template = self._get_values_tmpl(expected_values_count)

    def get_values_row_template(self, index):
        return '(%s)' % ', '.join(':%s_%i' % (c, index) for c in self._column_names)

    def _get_values_tmpl(self, values_count):
        values_tmpl = ', '.join(self.get_values_row_template(i) for i in xrange(values_count))
        return values_tmpl.rstrip(', ')

    def get_values_tmpl(self, values_count):
        """
        Возвращает шаблон для подстановки значений в запрос вида:
        (:a_1, :b_1, :c_1), (:a_2, :b_2, :c_2), (:a_3, :b_3, :c_3)

        :param int values_count: количество значений, для которых нужно вернуть шаблон
        :return str: шаблон для подстановки значений
        """
        if values_count == self._expected_values_count and self._cached_values_template:
            return self._cached_values_template
        return self._get_values_tmpl(values_count)

    def get_values_for_tmpl(self, values):
        """
        Возвращает данные для подстановки в шаблон сгенерированный методом get_values_tmpl

        :param iterable values: iterable of dicts
        :return dict: данные для подстановки в шаблон
        """
        insert_data = {}
        for i, item in enumerate(values):
            for key in self._column_names:
                insert_data["%s_%i" % (key, i)] = item[key]
        return insert_data


class BulkInsertReqGenerator(object):
    def __init__(self, table_schema, items, skip_columns=None, on_conflict_do_nothing=False):
        super(BulkInsertReqGenerator, self).__init__()
        if not items:
            raise ValueError()

        self._table_schema = table_schema
        skip_columns = skip_columns or []
        column_names = []
        for c in self._table_schema.columns:
            if c in skip_columns:
                continue
            column_names.append(c.name)
        column_names.sort()
        self._values_template_generator = ValuesTemplateGenerator(column_names)
        self._items = items
        self._on_conflict = 'ON CONFLICT DO NOTHING' if on_conflict_do_nothing else ''

    def generate_tmpl(self):
        return 'INSERT INTO %(table_name)s (%(columns)s) VALUES %(values)s %(on_conflict)s;' % {
            'table_name': self._table_schema.fullname,
            'columns': ', '.join(self._values_template_generator._column_names),
            'values': self._values_template_generator.get_values_tmpl(len(self._items)),
            'on_conflict': self._on_conflict,
        }

    def generate_values(self):
        return self._values_template_generator.get_values_for_tmpl(item.as_raw_pg_dict() for item in self._items)


class BulkDeleteReqGenerator(object):
    def __init__(self, table_schema, items, reference_columns):
        super(BulkDeleteReqGenerator, self).__init__()
        if not items:
            raise ValueError()

        self._table_schema = table_schema
        self._items = items
        self._values_template_generator = ValuesTemplateGenerator(sorted(reference_columns))

    def generate_tmpl(self):
        pg_clauses = []
        for e, item in enumerate(self._items):
            item_clause = '(%s)' % ' AND '.join([
                '%s = :%s_%s' % (column, column, e)
                for column in self._values_template_generator._column_names
            ])
            pg_clauses.append(item_clause)

        where_clause = ' OR\n'.join(pg_clauses)
        return 'DELETE FROM %(table_name)s\nWHERE %(clauses)s' % {
            'table_name': self._table_schema.fullname,
            'clauses': where_clause,
        }

    def generate_values(self):
        return self._values_template_generator.get_values_for_tmpl(item.as_raw_pg_dict() for item in self._items)


class DAOPath(object):
    """
    Класс для представления пути на уровне абстракции dao. Содержит только путь, начинающийся с / и ничего
    не знает про uid. Умеет отдавать родительский путь и имя ресурса.
    """

    def __init__(self, path):
        if not path.startswith('/'):
            raise ValueError('Path should start with / but it does\'t: %s' % path)
        self._path = path

    def __eq__(self, other):
        return self.get_value() == other.get_value()

    def get_value(self):
        return self._path

    def get_parent_path(self):
        return DAOPath(parent_for_key(self._path))

    def get_name(self):
        return name_for_key(self._path)
