# -*- coding: utf-8 -*-

from pymongo import ReadPreference as MongoReadPreference

from mpfs.core.filesystem.dao.file import StidsListParser
from mpfs.dao.base import BaseDAO, BaseDAOItem, PostgresBaseDAOImplementation, MongoBaseDAOImplementation, \
    convert_mongo_read_preference, QueryWithParams
from mpfs.dao.cursor import PostgresCursor
from mpfs.dao.fields import IntegerAsBoolField, ByteArrayField, ByteStringField, DAOItemField, DateTimeField, HidField, Md5Field, \
                            IntegerField, MediaTypeField, NullableStringField, ResourceTypeField, StringField, UidField
from mpfs.dao.session import Session
from mpfs.metastorage.postgres.queries import SQL_STIDS_BY_STIDS_IN_MISC_DATA, SQL_FILES_BY_STIDS_IN_MISC_DATA
from mpfs.metastorage.postgres.schema import misc_data


class MiscDataDAOItem(BaseDAOItem):
    """DAO item for ``misc_data`` records.

    Each field maps one attribute to its location in the Mongo document
    (``mongo_path``; dotted paths address keys nested under ``data``) and to a
    column of the PostgreSQL ``misc_data`` table (``pg_path``).
    """
    mongo_collection_name = 'misc_data'
    postgres_table_obj = misc_data
    # NOTE(review): presumably makes the base DAO treat misc_data as a per-shard collection — confirm in BaseDAOItem.
    is_sharded = True

    # Identity and location of the record. Note that `path` is stored under the Mongo key 'key'.
    id = Md5Field(mongo_path='_id', pg_path=misc_data.c.id)
    uid = UidField(mongo_path='uid', pg_path=misc_data.c.uid)
    path = StringField(mongo_path='key', pg_path=misc_data.c.path)
    type = ResourceTypeField(mongo_path='type', pg_path=misc_data.c.type)

    # Optional top-level attributes (default to None when absent).
    version = IntegerField(mongo_path='version', pg_path=misc_data.c.version, default_value=None)
    parent = Md5Field(mongo_path='parent', pg_path=misc_data.c.parent, default_value=None)
    zdata = ByteArrayField(mongo_path='zdata', pg_path=misc_data.c.zdata, default_value=None)

    # File id, stored both plainly and in a "zipped" form.
    file_id = ByteStringField(mongo_path='data.file_id', pg_path=misc_data.c.file_id, default_value=None)
    file_id_zipped = DAOItemField(mongo_path='data.file_id_zipped', pg_path=misc_data.c.file_id_zipped, default_value=None)

    # File metadata. In Mongo these live under `data` (except `hid`); in PostgreSQL each is its own column.
    hid = HidField(mongo_path='hid', pg_path=misc_data.c.hid, default_value=None)
    mimetype = NullableStringField(mongo_path='data.mimetype', pg_path=misc_data.c.mimetype, default_value=None)
    mediatype = MediaTypeField(mongo_path='data.mt', pg_path=misc_data.c.mediatype, default_value=None)
    visible = IntegerAsBoolField(mongo_path='data.visible', pg_path=misc_data.c.visible, default_value=None)
    size = IntegerField(mongo_path='data.size', pg_path=misc_data.c.size, default_value=None)

    # The three stid fields are all read from the same Mongo list 'data.stids'. StidsListParser is
    # given the stid kind to extract ('file_mid', 'pmid', 'digest_mid'). PostgreSQL has one column
    # per kind. The preview stid is parsed with is_optional=True (presumably a missing 'pmid' entry
    # is tolerated — confirm in StidsListParser).
    file_stid = NullableStringField(mongo_path='data.stids', mongo_item_parser=StidsListParser('file_mid'),
                                    pg_path=misc_data.c.file_stid, default_value=None)
    preview_stid = NullableStringField(mongo_path='data.stids', mongo_item_parser=StidsListParser('pmid', is_optional=True),
                                       pg_path=misc_data.c.preview_stid, default_value=None)
    digest_stid = NullableStringField(mongo_path='data.stids', mongo_item_parser=StidsListParser('digest_mid'),
                                      pg_path=misc_data.c.digest_stid, default_value=None)

    # Modification, upload and EXIF timestamps.
    date_modified = DateTimeField(mongo_path='data.mtime', pg_path=misc_data.c.date_modified, default_value=None)
    date_uploaded = DateTimeField(mongo_path='data.utime', pg_path=misc_data.c.date_uploaded, default_value=None)
    date_exif = DateTimeField(mongo_path='data.etime', pg_path=misc_data.c.date_exif, default_value=None)

    # Mongo dict fields skipped by validation. NOTE(review): presumably because zdata is an opaque
    # binary blob — confirm in BaseDAOItem.
    validation_ignored_mongo_dict_fields = ('zdata',)

    # Keys excluded after converting an item to a Mongo document. The nested dict mirrors the Mongo
    # document shape; a None leaf appears to mean "drop this key" (NOTE(review): confirm in
    # BaseDAOItem). The listed keys cover every declared field except id, uid, path (key) and type.
    exclude_keys_after_conversion_to_mongo = {
        'version': None,
        'parent': None,
        'zdata': None,
        'data': {
            'file_id': None,
            'file_id_zipped': None,
            'mimetype': None,
            'mt': None,
            'visible': None,
            'size': None,
            'stids': None,
            'mtime': None,
            'utime': None,
            'etime': None,
        },
        'hid': None,
    }


class PostgresMiscDataDAOImplementation(PostgresBaseDAOImplementation):
    """PostgreSQL implementation of the misc_data DAO.

    Adds a stid lookup and a fast path in ``find_on_shard``. The fast path handles the Mongo-style
    ``{'data.stids': {'$elemMatch': {'stid': {'$in': [...]}}}}`` spec with a dedicated SQL query
    instead of the base class's generic spec translation.
    """

    def find_stids_on_shard(self, stids, shard_name, limit=None):
        """Return a cursor over misc_data rows matched by ``SQL_STIDS_BY_STIDS_IN_MISC_DATA``.

        :param stids: iterable of stids; materialized into a tuple for the ``stids`` SQL parameter.
        :param shard_name: shard id the session is created for.
        :param limit: optional maximum number of rows; ``None``/0 means no limit.
        :return: ``PostgresCursor`` yielding ``MiscDataDAOItem`` objects.
        """
        session = Session.create_from_shard_id(shard_name)
        sql_query = SQL_STIDS_BY_STIDS_IN_MISC_DATA
        # `limit` used to be accepted but silently ignored, unlike the Mongo implementation.
        # Append it the same way find_on_shard does for SQL_FILES_BY_STIDS_IN_MISC_DATA.
        # NOTE(review): assumes the query text has no trailing ';' — confirm in queries module.
        if limit:
            sql_query += ' LIMIT %d' % limit
        # NOTE(review): an empty `stids` becomes an empty tuple; if the SQL uses `IN %(stids)s`,
        # PostgreSQL rejects `IN ()` — confirm callers never pass an empty collection.
        query = QueryWithParams(sql_query, {'stids': tuple(stids)})
        return PostgresCursor(session, query, MiscDataDAOItem)

    @staticmethod
    def _extract_elem_match_stids(spec):
        """Return the ``$in`` stids of an exact ``data.stids`` ``$elemMatch`` spec, else ``None``.

        Matches only when ``spec`` has the single key ``'data.stids'`` and its value is a dict with
        the single key ``'$elemMatch'``. That ``$elemMatch`` value must be a dict whose ``'stid'``
        entry is a dict with a non-None ``'$in'``. Non-dict values at any level yield ``None``
        instead of raising.
        """
        if not spec or set(spec) != {'data.stids'}:
            return None
        stids_cond = spec['data.stids']
        # set(...) comparison instead of `keys() == [...]`, which only works on Python 2.
        if not isinstance(stids_cond, dict) or set(stids_cond) != {'$elemMatch'}:
            return None
        elem_match = stids_cond['$elemMatch']
        if not isinstance(elem_match, dict):
            return None
        stid_cond = elem_match.get('stid')
        if not isinstance(stid_cond, dict):
            return None
        in_values = stid_cond.get('$in')
        if in_values is None:
            return None
        # NOTE(review): other keys inside $elemMatch / the 'stid' condition are ignored on this
        # path (as before) — confirm callers never rely on extra filters here.
        return tuple(in_values)

    def find_on_shard(self, spec=None, fields=None, skip=0, limit=0, sort=None, shard_name=None, **kwargs):
        """Find misc_data records on ``shard_name``.

        A spec of the form ``{'data.stids': {'$elemMatch': {'stid': {'$in': [...]}}}}`` is served
        by ``SQL_FILES_BY_STIDS_IN_MISC_DATA``. ``skip``/``limit`` are appended as OFFSET/LIMIT,
        and ``fields`` and ``sort`` are ignored on that path. Every other spec is delegated to the
        base implementation.

        The caller's ``spec`` dict is not modified.
        """
        if spec is not None and set(spec) == {'uid', 'path', 'key'}:
            # NOTE(review): 'key' is the Mongo name of MiscDataDAOItem.path, so 'path' is presumably
            # redundant here and is dropped. Work on a copy so the caller's dict is not mutated.
            spec = dict(spec)
            spec.pop('path')

        stids = self._extract_elem_match_stids(spec)
        if stids is None:
            return super(PostgresMiscDataDAOImplementation, self).find_on_shard(
                spec, fields, skip, limit, sort, shard_name, **kwargs
            )

        read_preference = self.get_read_preference(kwargs)
        if read_preference is not None:
            session = Session.create_from_shard_id(
                shard_name,
                read_preference=convert_mongo_read_preference(read_preference),
            )
        else:
            session = Session.create_from_shard_id(shard_name)

        # %d forces integer formatting, so skip/limit cannot inject SQL.
        sql_query = SQL_FILES_BY_STIDS_IN_MISC_DATA
        if skip:
            sql_query += ' OFFSET %d' % skip
        if limit:
            sql_query += ' LIMIT %d' % limit

        return PostgresCursor(session, QueryWithParams(sql_query, {'stids': stids}), MiscDataDAOItem)


class MongoMiscDataDAOImplementation(MongoBaseDAOImplementation):
    """MongoDB implementation of the misc_data DAO."""

    def find_stids_on_shard(self, stids, shard_name, limit=None):
        """Return a cursor over misc_data documents with a ``data.stids`` entry whose ``stid`` is in ``stids``.

        Only the ``data.stids`` field is projected.

        :param stids: iterable of stids. It is materialized into a list because BSON cannot encode
            arbitrary iterables such as sets or generators. This matches the PostgreSQL
            implementation, which also accepts any iterable.
        :param shard_name: shard to query.
        :param limit: forwarded unchanged to ``find_on_shard``.
        """
        spec = {
            'data.stids': {
                '$elemMatch': {'stid': {'$in': list(stids)}},
            },
        }
        return self.find_on_shard(spec, {'data.stids': True}, shard_name=shard_name, limit=limit)


class MiscDataDAO(BaseDAO):
    """DAO facade for misc_data.

    Holds both a PostgreSQL and a Mongo implementation. ``_get_impl_by_shard`` (inherited from
    ``BaseDAO``) picks which one serves a given shard.
    """
    dao_item_cls = MiscDataDAOItem

    def __init__(self):
        super(MiscDataDAO, self).__init__()
        item_cls = self.dao_item_cls
        self._pg_impl = PostgresMiscDataDAOImplementation(item_cls)
        self._mongo_impl = MongoMiscDataDAOImplementation(item_cls)

    def find_stids_on_shard(self, stids, shard_name, limit=None):
        """Look up records referencing ``stids`` on ``shard_name`` using that shard's backend.

        ``limit`` is forwarded unchanged to the selected implementation.
        """
        return self._get_impl_by_shard(shard_name).find_stids_on_shard(stids, shard_name, limit=limit)
