# -*- coding: utf-8 -*-
import mpfs.engine.process
from pymongo import DESCENDING

from mpfs.dao.base import (
    BaseDAO, BaseDAOItem, MongoBaseDAOImplementation,
    PostgresBaseDAOImplementation, MongoHelper, BulkInsertReqGenerator
)
from mpfs.dao.fields import DateTimeField, UidField, FileIdField, UuidField
from mpfs.dao.session import Session
from mpfs.metastorage.postgres.schema import last_files_cache
from mpfs.metastorage.postgres.services import Sharpei
from mpfs.metastorage.postgres.queries import (
    SQL_GET_LAST_FILES_CACHE, SQL_DELETE_CACHE_BY_GID,
    SQL_DELETE_CACHE_BY_GID_AND_UID, SQL_DELETE_CACHE_BY_GID_AND_IGNORE_UIDS,
    SQL_GET_CACHE_BY_GID_AND_UID, SQL_DELETE_CACHE_BY_IDS)


class LastFilesCacheDAOItem(BaseDAOItem):
    """One entry of the "last files" cache.

    Stored in the Mongo collection ``last_files_cache`` and in the Postgres
    table ``last_files_cache``. Each field below declares where it lives in
    both backends (``mongo_path`` for Mongo, ``pg_path`` for Postgres).
    """
    mongo_collection_name = 'last_files_cache'
    postgres_table_obj = last_files_cache
    # NOTE(review): presumably the field BaseDAOItem uses for shard routing
    # of sharded items — confirm in BaseDAOItem.
    uid_field_name = 'uid'
    is_sharded = True

    # Entry identifier: Mongo ``_id`` / Postgres ``id``.
    id = UuidField(mongo_path='_id', pg_path=last_files_cache.c.id)
    # User the cache entry belongs to; reads and per-user deletes filter on it.
    uid = UidField(mongo_path='uid', pg_path=last_files_cache.c.uid)
    # NOTE(review): presumably the owner of the cached file (may differ from
    # ``uid``) — confirm.
    owner_uid = UidField(mongo_path='owner_uid', pg_path=last_files_cache.c.owner_uid)
    # Grouping id; the DAO's ``drop``/``drop_by_gid_uid`` delete entries by it.
    gid = UuidField(mongo_path='gid', pg_path=last_files_cache.c.gid)
    file_id = FileIdField(mongo_path='file_id', pg_path=last_files_cache.c.file_id)
    # Note the name mismatch: Mongo ``creation_time`` vs Postgres ``date_created``.
    creation_time = DateTimeField(mongo_path='creation_time', pg_path=last_files_cache.c.date_created)
    # The Mongo backend sorts reads by this field, newest first.
    file_date_modified = DateTimeField(mongo_path='file_date_modified', pg_path=last_files_cache.c.file_date_modified)


class LastFilesCacheDAO(BaseDAO):
    """Facade over the Mongo and Postgres "last files" cache backends.

    Per-user operations are dispatched to the backend chosen by
    ``_get_impl_by_uid`` (inherited from ``BaseDAO``). ``drop`` is not tied
    to a single user, so it is applied to both backends.
    """
    dao_item_cls = LastFilesCacheDAOItem

    def __init__(self):
        super(LastFilesCacheDAO, self).__init__()
        self._mongo_impl = MongoLastFilesCacheDAOImplementation(self.dao_item_cls)
        self._pg_impl = PostgresLastFilesCacheDAOImplementation(self.dao_item_cls)

    def get(self, uid, limit=10):
        """Return a list of at most ``limit`` cache items for ``uid``."""
        return self._get_impl_by_uid(uid).get(uid, limit=limit)

    def get_by_gid_uid(self, gid, uid):
        """Return the list of cache items for ``uid`` that have the given ``gid``."""
        return self._get_impl_by_uid(uid).get_by_gid_uid(gid, uid)

    def set(self, uid, items):
        """Insert ``items`` into ``uid``'s cache.

        Does nothing when ``items`` is empty.

        :raises TypeError: if an item is not a ``LastFilesCacheDAOItem``.
        :raises ValueError: if an item's ``uid`` differs from ``uid``.
        """
        if not items:
            return
        # Validate everything up front so no partial insert happens on bad input.
        for item in items:
            if not isinstance(item, self.dao_item_cls):
                raise TypeError("Expect %r, got item: %r" % (self.dao_item_cls, item))
            if item.uid != uid:
                raise ValueError('Uids mismatch. Uid: %s, item.uid: %s, item %s' % (uid, item.uid, item))
        return self._get_impl_by_uid(uid).set(uid, items)

    def drop_by_gid_uid(self, gid, uid):
        """Delete ``uid``'s cache items that have the given ``gid``."""
        return self._get_impl_by_uid(uid).drop_by_gid_uid(gid, uid)

    def drop(self, gid, ignore_uids=None):
        """Delete cache items with ``gid`` from both backends.

        :param ignore_uids: optional list/tuple of uids whose items are kept.
        :raises TypeError: if ``ignore_uids`` is neither a list nor a tuple.
        """
        if not ignore_uids:
            ignore_uids = []
        elif not isinstance(ignore_uids, (list, tuple)):
            raise TypeError('ignore_uids must be a list or tuple, got %r' % (ignore_uids,))
        # Not routed by uid: the gid may have entries in either backend.
        self._mongo_impl.drop(gid, ignore_uids=ignore_uids)
        self._pg_impl.drop(gid, ignore_uids=ignore_uids)

    def drop_by_uid_ids(self, uid, ids=None):
        """Delete ``uid``'s cache items whose ids are in ``ids``.

        Does nothing when ``ids`` is empty or ``None``.

        :raises TypeError: if ``ids`` is neither a list nor a tuple.
        """
        if not ids:
            return
        if not isinstance(ids, (list, tuple)):
            raise TypeError('ids must be a list or tuple, got %r' % (ids,))
        return self._get_impl_by_uid(uid).drop_by_uid_ids(uid, ids)


class PostgresLastFilesCacheDAOImplementation(PostgresBaseDAOImplementation):
    """PostgreSQL backend of the "last files" cache.

    Per-user operations open a session on the user's shard via
    ``Session.create_from_uid``. ``drop`` runs on every shard reported by
    ``Sharpei``.
    """

    def get(self, uid, limit=10):
        """Return at most ``limit`` items for ``uid`` (``SQL_GET_LAST_FILES_CACHE``)."""
        uid = self.dao_item_cls.get_field_pg_representation('uid', uid)
        session = Session.create_from_uid(uid)
        cursor = session.execute(SQL_GET_LAST_FILES_CACHE, {'uid': uid, 'limit': limit})
        return [self.doc_to_item(i) for i in cursor]

    def get_by_gid_uid(self, gid, uid):
        """Return ``uid``'s items that have the given ``gid``."""
        gid = self.dao_item_cls.get_field_pg_representation('gid', gid)
        uid = self.dao_item_cls.get_field_pg_representation('uid', uid)
        session = Session.create_from_uid(uid)
        cursor = session.execute(SQL_GET_CACHE_BY_GID_AND_UID, {'gid': gid, 'uid': uid})
        return [self.doc_to_item(i) for i in cursor]

    def set(self, uid, items):
        """Bulk-insert ``items`` on ``uid``'s shard. Does nothing for empty ``items``."""
        # Nothing to insert: skip building a bulk INSERT with no rows.
        if not items:
            return
        uid = self.dao_item_cls.get_field_pg_representation('uid', uid)
        gen = BulkInsertReqGenerator(self.dao_item_cls.postgres_table_obj, items)
        session = Session.create_from_uid(uid)
        session.execute(gen.generate_tmpl(), gen.generate_values())

    def drop_by_gid_uid(self, gid, uid):
        """Delete ``uid``'s items that have the given ``gid``."""
        gid = self.dao_item_cls.get_field_pg_representation('gid', gid)
        uid = self.dao_item_cls.get_field_pg_representation('uid', uid)
        session = Session.create_from_uid(uid)
        session.execute(SQL_DELETE_CACHE_BY_GID_AND_UID, {'gid': gid, 'uid': uid})

    def drop_by_uid_ids(self, uid, ids=None):
        """Delete items whose ids are in ``ids`` on ``uid``'s shard.

        Does nothing when ``ids`` is empty or ``None``.
        """
        # ``None`` cannot be iterated. An empty tuple bound to an IN-style
        # parameter is typically invalid SQL.
        # NOTE(review): presumably SQL_DELETE_CACHE_BY_IDS uses ``IN %(ids)s``
        # — confirm.
        if not ids:
            return
        uid = self.dao_item_cls.get_field_pg_representation('uid', uid)
        ids = tuple(self.dao_item_cls.get_field_pg_representation('id', i) for i in ids)
        session = Session.create_from_uid(uid)
        session.execute(SQL_DELETE_CACHE_BY_IDS, {'ids': ids})

    def drop(self, gid, ignore_uids=None):
        """Delete items with ``gid`` on every shard.

        Items of the uids in ``ignore_uids`` are kept.
        """
        gid = self.dao_item_cls.get_field_pg_representation('gid', gid)
        if ignore_uids:
            sql_tmpl = SQL_DELETE_CACHE_BY_GID_AND_IGNORE_UIDS
            params = {
                'gid': gid,
                'ignore_uids': tuple(self.dao_item_cls.get_field_pg_representation('uid', u) for u in ignore_uids)
            }
        else:
            sql_tmpl = SQL_DELETE_CACHE_BY_GID
            params = {
                'gid': gid,
            }
        # A gid is not bound to a single user's shard, so visit all shards.
        for shard_id in Sharpei().get_all_shard_ids():
            session = Session.create_from_shard_id(shard_id)
            session.execute(sql_tmpl, params)


class MongoLastFilesCacheDAOImplementation(MongoBaseDAOImplementation):
    """MongoDB backend of the "last files" cache.

    Per-user operations use the collection returned by
    ``get_collection_by_uid``. ``drop`` iterates over the collection on all
    shards.
    """

    def get(self, uid, limit=10):
        """Return at most ``limit`` items for ``uid``, newest ``file_date_modified`` first."""
        uid = self.dao_item_cls.get_field_mongo_representation('uid', uid)
        cursor = self.get_collection_by_uid(uid).find({'uid': uid}, limit=limit)
        cursor.sort([('file_date_modified', DESCENDING)])
        return [self.doc_to_item(d) for d in cursor]

    def get_by_gid_uid(self, gid, uid):
        """Return ``uid``'s items that have the given ``gid``."""
        uid = self.dao_item_cls.get_field_mongo_representation('uid', uid)
        gid = self.dao_item_cls.get_field_mongo_representation('gid', gid)
        cursor = self.get_collection_by_uid(uid).find({'uid': uid, 'gid': gid})
        return [self.doc_to_item(d) for d in cursor]

    def set(self, uid, items):
        """Insert ``items`` into ``uid``'s collection. Does nothing for empty ``items``."""
        # pymongo refuses an empty bulk insert, so skip the call entirely.
        if not items:
            return
        # Route by the mongo representation, as every other method here does.
        uid = self.dao_item_cls.get_field_mongo_representation('uid', uid)
        insert_data = [i.get_mongo_representation() for i in items]
        self.get_collection_by_uid(uid).insert(insert_data)

    def drop_by_gid_uid(self, gid, uid):
        """Delete ``uid``'s items that have the given ``gid``."""
        uid = self.dao_item_cls.get_field_mongo_representation('uid', uid)
        gid = self.dao_item_cls.get_field_mongo_representation('gid', gid)
        self.get_collection_by_uid(uid).remove({'gid': gid, 'uid': uid})

    def drop_by_uid_ids(self, uid, ids=None):
        """Delete items whose ``_id`` is in ``ids`` from ``uid``'s collection.

        Does nothing when ``ids`` is empty or ``None``.
        """
        # ``None`` cannot be iterated. An empty ``$in`` matches nothing,
        # so skip the round-trip.
        if not ids:
            return
        uid = self.dao_item_cls.get_field_mongo_representation('uid', uid)
        ids = [self.dao_item_cls.get_field_mongo_representation('id', i) for i in ids]
        self.get_collection_by_uid(uid).remove({'_id': {'$in': ids}})

    def drop(self, gid, ignore_uids=None):
        """Delete items with ``gid`` on every shard.

        Items of the uids in ``ignore_uids`` are kept.
        """
        gid = self.dao_item_cls.get_field_mongo_representation('gid', gid)
        if ignore_uids:
            ignore_uids = [self.dao_item_cls.get_field_mongo_representation('uid', u) for u in ignore_uids]
            spec = {
                'gid': gid,
                'uid': {'$nin': ignore_uids},
            }
        else:
            spec = {
                'gid': gid,
            }
        # A gid is not bound to a single user's shard, so visit all shards.
        for routed_coll in MongoHelper().iter_over_all_shards(self.dao_item_cls.mongo_collection_name):
            routed_coll.remove(spec)
