# -*- coding: utf-8 -*-
import datetime
import itertools
from enum import Enum
from pymongo import DESCENDING, ASCENDING, ReadPreference

from mpfs.dao.base import (
    BaseDAOItem,
    Session,
    BulkInsertReqGenerator
)
from mpfs.core.versioning.dao.base import (
    BaseVersioningDAO,
    MongoBaseVersioningDAOImplementation,
    PostgresBaseVersioningDAOImplementation
)
from mpfs.dao.fields import (
    AntiVirusStatusField,
    BoolField,
    EnumField,
    HidField,
    IntegerField,
    JsonField,
    Md5Field,
    Sha256Field,
    StringField,
    UidField,
    UuidField,
    MSKDateTimeField,
    StidField,
    DateTimeField
)
from mpfs.metastorage.postgres.schema import (
    version_data,
    storage_files
)
from mpfs.core.filesystem.dao.file import StidsListParser, FileDAO
from mpfs.common.util import get_first
from mpfs.metastorage.postgres import versioning_queries


class VersionType(Enum):
    """Kind of a version record. Each member's value equals its name."""

    # The current binary version; records of this kind should never be stored in the database.
    current = 'current'
    # A regular binary version.
    binary = 'binary'
    # A fake version meaning "the file was deleted".
    trashed = 'trashed'
    # A fake version meaning "the file was restored".
    restored = 'restored'


class VersionDataDAOItem(BaseDAOItem):
    """DAO item describing a single file version.

    The record itself is stored in ``version_data`` (Mongo collection
    ``version_data``). Regular binary versions additionally map fields onto
    ``storage_files`` columns, so ``columns_map`` indexes the columns of both
    tables by column name.
    """
    mongo_collection_name = 'version_data'
    postgres_table_obj = version_data
    uid_field_name = 'uid'
    is_sharded = True
    # Column name -> column object for both tables; on a name clash the
    # storage_files column wins because it is chained last.
    columns_map = dict(
        (column.name, column)
        for column in itertools.chain(version_data.columns, storage_files.columns)
    )

    # --- Record bookkeeping ---
    id = UuidField(mongo_path='_id', pg_path=version_data.c.id)
    uid = UidField(mongo_path='uid', pg_path=version_data.c.uid, default_value=None)
    # When the record was written to the database. Not used by business logic.
    record_date_created = MSKDateTimeField(
        mongo_path='record_date_created',
        pg_path=version_data.c.record_date_created,
        default_value=None,
    )
    # When the version is due to be removed.
    date_to_remove = MSKDateTimeField(
        mongo_path='date_to_remove',
        pg_path=version_data.c.date_to_remove,
        default_value=None,
    )

    # --- Version metadata ---
    # Versions are looked up per (uid, version_link_id) pair.
    version_link_id = UuidField(
        mongo_path='version_link_id',
        pg_path=version_data.c.version_link_id,
        default_value=None,
    )
    # Id of the previous (parent) version. Not used by anything; kept only as
    # a safety net in case something goes wrong.
    parent_version_id = UuidField(
        mongo_path='parent_version_id',
        pg_path=version_data.c.parent_version_id,
        default_value=None,
    )
    # Kind of the version (see VersionType); stored under 'version_type' in Mongo.
    type = EnumField(
        mongo_path='version_type',
        pg_path=version_data.c.type,
        enum_class=VersionType,
        default_value=VersionType.binary,
    )
    # When the version was created.
    date_created = MSKDateTimeField(
        mongo_path='date_created',
        pg_path=version_data.c.date_created,
    )
    # Whether this is a checkpoint (non-folded) version. Used by the version
    # thinning ("folding") mechanism.
    is_checkpoint = BoolField(
        mongo_path='is_checkpoint',
        pg_path=version_data.c.is_checkpoint,
        default_value=True,
    )
    # For checkpoint versions: how many versions were folded into it.
    folded_counter = IntegerField(
        mongo_path='folded_counter',
        pg_path=version_data.c.folded_counter,
        default_value=0,
    )
    # The user who created the version.
    uid_created = UidField(
        mongo_path='uid_created',
        pg_path=version_data.c.uid_created,
    )
    # The platform that produced the version.
    platform_created = StringField(
        mongo_path='platform_created',
        pg_path=version_data.c.platform_created,
        default_value=None,
    )

    # --- Binary (storage_file) fields; only "regular" versions have them ---
    hid = HidField(
        mongo_path='storage_file.hid',
        pg_path=storage_files.c.storage_id,
        default_value=None,
    )
    size = IntegerField(
        mongo_path='storage_file.size',
        pg_path=storage_files.c.size,
        default_value=None,
    )
    md5 = Md5Field(
        mongo_path='storage_file.md5',
        pg_path=storage_files.c.md5_sum,
        default_value=None,
    )
    sha256 = Sha256Field(
        mongo_path='storage_file.sha256',
        pg_path=storage_files.c.sha256_sum,
        default_value=None,
    )
    # The three stids share one Mongo list; each parser picks its own entry.
    file_stid = StidField(
        mongo_path='storage_file.stids',
        mongo_item_parser=StidsListParser('file_mid'),
        pg_path=storage_files.c.stid,
        default_value=None,
    )
    preview_stid = StidField(
        mongo_path='storage_file.stids',
        mongo_item_parser=StidsListParser('pmid', is_optional=True),
        pg_path=storage_files.c.preview_stid,
        default_value=None,
    )
    digest_stid = StidField(
        mongo_path='storage_file.stids',
        mongo_item_parser=StidsListParser('digest_mid'),
        pg_path=storage_files.c.digest_stid,
        default_value=None,
    )
    # Unlike the other binary fields, the Postgres column for this one is in
    # version_data (date_exif), not storage_files.
    exif_time = DateTimeField(
        mongo_path='storage_file.etime',
        pg_path=version_data.c.date_exif,
        default_value=None,
    )
    antivirus_status = AntiVirusStatusField(
        mongo_path='storage_file.drweb',
        pg_path=storage_files.c.av_scan_status,
        default_value=None,
    )
    video_info = JsonField(
        mongo_path='storage_file.video_info',
        pg_path=storage_files.c.video_data,
        default_value=None,
    )
    width = IntegerField(
        mongo_path='storage_file.width',
        pg_path=storage_files.c.width,
        default_value=None,
    )
    height = IntegerField(
        mongo_path='storage_file.height',
        pg_path=storage_files.c.height,
        default_value=None,
    )
    angle = IntegerField(
        mongo_path='storage_file.angle',
        pg_path=storage_files.c.angle,
        default_value=None,
    )


class VersionDataDAO(BaseVersioningDAO):
    """Version-record DAO that dispatches each call to the Mongo or Postgres backend.

    The backend is chosen by ``_get_impl_by_uid``, ``_get_impl_by_shard`` or
    ``_get_impl_by_shard_endpoint`` (inherited). Write methods validate the
    items before dispatching.
    """
    dao_item_cls = VersionDataDAOItem

    def __init__(self):
        super(VersionDataDAO, self).__init__()
        self._mongo_impl = MongoVersionDataDAOImplementation(self.dao_item_cls)
        self._pg_impl = PostgresVersionDataDAOImplementation(self.dao_item_cls)

    def get_all_ascending(self, uid, version_link_id, offset, limit):
        """Return the link's versions in ascending order, paginated by ``offset``/``limit``."""
        return self._get_impl_by_uid(uid).get_all_ascending(uid, version_link_id, offset, limit)

    def get_all(self, uid, version_link_id, border_dt, limit):
        """Return the link's versions, optionally bounded by ``border_dt``, at most ``limit``."""
        return self._get_impl_by_uid(uid).get_all(uid, version_link_id, border_dt, limit)

    def get_checkpoints(self, uid, version_link_id, border_dt, limit):
        """Like :meth:`get_all`, but only checkpoint versions."""
        return self._get_impl_by_uid(uid).get_checkpoints(uid, version_link_id, border_dt, limit)

    def get_latest_version(self, uid, version_link_id):
        """Return the newest version of the link (backend decides what "none" looks like)."""
        return self._get_impl_by_uid(uid).get_latest_version(uid, version_link_id)

    def get_earliest_version(self, uid, version_link_id):
        """Return the oldest version of the link."""
        return self._get_impl_by_uid(uid).get_earliest_version(uid, version_link_id)

    def count_version_link_versions(self, uid, version_link_id):
        """Return how many versions the link has."""
        return self._get_impl_by_uid(uid).count_version_link_versions(uid, version_link_id)

    def count_version_link_versions_greater_than_dt(self, uid, version_link_id, border_dt):
        """Return how many of the link's versions were created after ``border_dt``."""
        return self._get_impl_by_uid(uid).count_version_link_versions_greater_than_dt(uid, version_link_id, border_dt)

    def reset_checkpoint(self, uid, version_id):
        """Clear the ``is_checkpoint`` flag of the given version."""
        return self._get_impl_by_uid(uid).reset_checkpoint(uid, version_id)

    def get_by_id(self, uid, version_link_id, version_id):
        """Return a single version by id within the given link."""
        return self._get_impl_by_uid(uid).get_by_id(uid, version_link_id, version_id)

    def save(self, item):
        """Persist one item.

        :raises ValueError: if the item has no ``uid`` or ``version_link_id``.
        """
        self._validate_item(item)
        return super(VersionDataDAO, self).save(item)

    def bulk_delete(self, uid, items):
        """Delete ``items``, which must all belong to ``uid``. No-op for an empty list.

        :raises ValueError: on an item without ``uid``/``version_link_id`` or
            with a different ``uid``.
        """
        if not items:
            return
        self._validate_items_for_uid(uid, items, 'delete')
        return self._get_impl_by_uid(uid).bulk_delete(uid, items)

    def bulk_insert(self, uid, items):
        """Insert ``items``, which must all belong to ``uid``. No-op for an empty list.

        :raises ValueError: on an item without ``uid``/``version_link_id`` or
            with a different ``uid``.
        """
        if not items:
            return
        self._validate_items_for_uid(uid, items, 'save')
        return self._get_impl_by_uid(uid).bulk_insert(uid, items)

    def fetch_by_stids(self, shard_name, stids, limit):
        """Return an iterator over at most ``limit`` versions on the shard referencing ``stids``.

        The iterator is empty when ``stids`` is empty or ``limit < 1``.
        """
        if not stids or limit < 1:
            # This is a plain function, not a generator: raising StopIteration
            # here would propagate to the caller as an exception, so hand back
            # an empty iterator instead.
            return iter(())
        impl = self._get_impl_by_shard(shard_name)
        return impl.fetch_by_stids(shard_name, stids, limit)

    def fetch_expired_versions_on_shard(self, shard_endpoint, limit):
        """Return at most ``limit`` versions on the shard whose ``date_to_remove`` has passed."""
        return self._get_impl_by_shard_endpoint(shard_endpoint).fetch_expired_versions_on_shard(shard_endpoint, limit)

    def bulk_delete_on_shard(self, shard_endpoint, items):
        """Delete ``items`` on the given shard regardless of uid. No-op for an empty list."""
        if not items:
            return
        return self._get_impl_by_shard_endpoint(shard_endpoint).bulk_delete_on_shard(shard_endpoint, items)

    @staticmethod
    def _validate_item(item):
        """Raise ValueError if ``item`` lacks ``uid`` or ``version_link_id``."""
        if item.uid is None or item.version_link_id is None:
            # Wrap in a tuple so a tuple-like item cannot break the formatting.
            raise ValueError("%r" % (item,))

    @classmethod
    def _validate_items_for_uid(cls, uid, items, action):
        """Validate each item and check that all of them belong to ``uid``."""
        for item in items:
            cls._validate_item(item)
            if uid != item.uid:
                raise ValueError("Can %s only for one uid. %s %s" % (action, uid, item.uid))


class MongoVersionDataDAOImplementation(MongoBaseVersioningDAOImplementation):
    """MongoDB backend for version records.

    The collection is obtained per uid, shard name or shard endpoint through
    the inherited ``get_collection_by_*`` helpers.
    """

    def get_all_ascending(self, uid, version_link_id, offset, limit):
        """Return the link's versions oldest first (by ``date_created``), paginated."""
        spec = {
            'uid': uid,
            'version_link_id': version_link_id,
        }
        coll = self.get_collection_by_uid(uid)
        cursor = coll.find(spec, skip=offset, limit=limit)
        cursor.sort([('date_created', ASCENDING)])
        return [self.doc_to_item(d) for d in cursor]

    def get_all(self, uid, version_link_id, border_dt, limit):
        """Return the link's versions newest first, older than ``border_dt`` if given."""
        return self._get_versions(uid, version_link_id, border_dt, limit, False)

    def get_checkpoints(self, uid, version_link_id, border_dt, limit):
        """Like :meth:`get_all`, restricted to ``is_checkpoint == True``."""
        return self._get_versions(uid, version_link_id, border_dt, limit, True)

    def get_earliest_version(self, uid, version_link_id):
        """Return the version with the smallest ``date_created`` (via ``get_first``)."""
        result = self._get_versions(uid, version_link_id, None, 1, False, date_created_sort_order=ASCENDING)
        return get_first(result)

    def get_latest_version(self, uid, version_link_id):
        """Return the version with the largest ``date_created`` (via ``get_first``)."""
        result = self._get_versions(uid, version_link_id, None, 1, False)
        return get_first(result)

    def get_by_id(self, uid, version_link_id, version_id):
        """Return the version with ``_id == version_id`` in the link, or None if absent."""
        coll = self.get_collection_by_uid(uid)
        doc = coll.find_one({
            'uid': uid,
            '_id': version_id,
            'version_link_id': version_link_id,
        })
        if doc is None:
            # find_one returns None when nothing matches; don't hand None to doc_to_item.
            return None
        return self.doc_to_item(doc)

    def count_version_link_versions(self, uid, version_link_id):
        """Return the number of versions of the link."""
        coll = self.get_collection_by_uid(uid)
        return coll.find({
            'uid': uid,
            'version_link_id': version_link_id,
        }).count()

    def count_version_link_versions_greater_than_dt(self, uid, version_link_id, border_dt):
        """Return the number of the link's versions with ``date_created > border_dt``."""
        coll = self.get_collection_by_uid(uid)
        return coll.find({
            'uid': uid,
            'version_link_id': version_link_id,
            'date_created': {'$gt': border_dt},
        }).count()

    def reset_checkpoint(self, uid, version_id):
        """Set ``is_checkpoint`` to False on the given version."""
        coll = self.get_collection_by_uid(uid)
        coll.update({'uid': uid, '_id': version_id}, {'$set': {'is_checkpoint': False}})

    def _get_versions(self, uid, version_link_id, border_dt, limit, only_checkpoints, date_created_sort_order=DESCENDING):
        """Query the link's versions sorted by ``date_created``.

        :param border_dt: if truthy, only versions with ``date_created < border_dt``.
        :param limit: maximum number of documents; None means no limit
            (translated to 0, which pymongo treats as unlimited).
        :param only_checkpoints: if true, only ``is_checkpoint == True`` versions.
        :param date_created_sort_order: pymongo sort direction, newest first by default.
        """
        spec = {
            'uid': uid,
            'version_link_id': version_link_id,
        }
        if border_dt:
            spec['date_created'] = {'$lt': border_dt}
        if only_checkpoints:
            spec['is_checkpoint'] = True
        if limit is None:
            limit = 0
        coll = self.get_collection_by_uid(uid)
        cursor = coll.find(spec, limit=limit)
        cursor.sort([('date_created', date_created_sort_order)])
        return [self.doc_to_item(d) for d in cursor]

    def bulk_delete(self, uid, items):
        """Remove the given items (matched by uid and ``_id``)."""
        coll = self.get_collection_by_uid(uid)
        coll.remove({'uid': uid, '_id': {'$in': [i.id for i in items]}})

    def fetch_by_stids(self, shard_name, stids, limit):
        """Yield at most ``limit`` versions whose storage_file stids contain any of ``stids``."""
        coll = self.get_collection_by_shard_name(shard_name)
        cursor = coll.find({
            'storage_file.stids': {'$elemMatch': {'stid': {'$in': stids}}},
        }, limit=limit)
        for doc in cursor:
            yield self.doc_to_item(doc)

    def fetch_expired_versions_on_shard(self, shard_endpoint, limit):
        """Yield at most ``limit`` versions with ``date_to_remove <= now``.

        ``now`` is a naive local ``datetime.now()``. Reads prefer the primary.
        """
        border_dt = datetime.datetime.now()
        query = {'date_to_remove': {'$lte': border_dt}}
        coll = self.get_collection_by_shard_endpoint(shard_endpoint)
        for doc in coll.find(query, limit=limit, read_preference=ReadPreference.PRIMARY_PREFERRED):
            yield self.doc_to_item(doc)

    def bulk_delete_on_shard(self, shard_endpoint, items):
        """Remove the given items by ``_id`` only (no uid filter)."""
        coll = self.get_collection_by_shard_endpoint(shard_endpoint)
        coll.remove({'_id': {'$in': [i.id for i in items]}})


class PostgresVersionDataDAOImplementation(PostgresBaseVersioningDAOImplementation):
    """PostgreSQL backend for version records.

    Version rows live in ``version_data``; binary metadata lives in
    ``storage_files``. Uid-scoped methods open a session with
    ``Session.create_from_uid``; shard-wide methods open one by shard id or
    shard endpoint. All SQL comes from ``versioning_queries``.
    """

    def get_all_ascending(self, uid, version_link_id, offset, limit):
        """Return the link's versions using the ascending query, paginated by ``offset``/``limit``."""
        uid = self.get_field_repr('uid', uid)
        version_link_id = self.get_field_repr('version_link_id', version_link_id)

        session = Session.create_from_uid(uid)
        result_proxy = session.execute(
            versioning_queries.SQL_VERSIONING_GET_ALL_VERSIONS_ASC,
            {'uid': uid, 'version_link_id': version_link_id, 'limit': limit, 'offset': offset}
        )
        return [self.doc_to_item(i) for i in result_proxy]

    def get_all(self, uid, version_link_id, border_dt, limit):
        """Return the link's versions, optionally bounded by ``border_dt``."""
        return self._get_versions(uid, version_link_id, border_dt, limit, only_checkpoints=False)

    def get_checkpoints(self, uid, version_link_id, border_dt, limit):
        """Return the link's checkpoint versions, optionally bounded by ``border_dt``."""
        return self._get_versions(uid, version_link_id, border_dt, limit, only_checkpoints=True)

    def get_latest_version(self, uid, version_link_id):
        """Return the first row of the unbounded (default DESC) query, via ``get_first``."""
        result = self._get_versions(uid, version_link_id, None, 1, only_checkpoints=False)
        return get_first(result)

    def get_earliest_version(self, uid, version_link_id):
        """Return the first row of the ascending query (limit 1, offset 0)."""
        uid = self.get_field_repr('uid', uid)
        version_link_id = self.get_field_repr('version_link_id', version_link_id)

        session = Session.create_from_uid(uid)
        result_proxy = session.execute(
            versioning_queries.SQL_VERSIONING_GET_ALL_VERSIONS_ASC,
            {'uid': uid, 'version_link_id': version_link_id, 'limit': 1, 'offset': 0}
        )
        return self.fetch_one_item(result_proxy)

    def get_by_id(self, uid, version_link_id, version_id):
        """Return the version with the given id within the link (via ``fetch_one_item``)."""
        uid = self.get_field_repr('uid', uid)
        version_link_id = self.get_field_repr('version_link_id', version_link_id)
        version_id = self.get_field_repr('id', version_id)

        session = Session.create_from_uid(uid)
        result_proxy = session.execute(
            versioning_queries.SQL_VERSIONING_GET_VERSION_BY_ID,
            {'uid': uid, 'version_link_id': version_link_id, 'id': version_id}
        )
        return self.fetch_one_item(result_proxy)

    def reset_checkpoint(self, uid, version_id):
        """Set ``is_checkpoint`` to False on the given version."""
        uid = self.get_field_repr('uid', uid)
        version_id = self.get_field_repr('id', version_id)

        session = Session.create_from_uid(uid)
        session.execute(
            versioning_queries.SQL_VERSIONING_SET_CHECKPOINT,
            {'uid': uid, 'is_checkpoint': False, 'id': version_id}
        )

    def save(self, item):
        """Save the version row of ``item`` inside a transaction.

        Only ``SQL_VERSIONING_SAVE_VERSION`` is executed; the item's
        storage_files fields are not written by this method.
        """
        # Bind parameters are keyed by column name.
        params = {c.name: v for c, v in item.get_postgres_representation().iteritems()}
        session = Session.create_from_uid(params['uid'])
        with session.begin():
            session.execute(versioning_queries.SQL_VERSIONING_SAVE_VERSION, params)
            # TODO: also save the storage_files row.

    def bulk_insert(self, uid, items):
        """Insert ``items`` into version_data and their binaries into storage_files in one transaction.

        Only items with a ``hid`` produce storage_files rows; those inserts use
        ON CONFLICT DO NOTHING and skip the ``date_origin`` column.
        """
        uid = self.get_field_repr('uid', uid)
        session = Session.create_from_uid(uid)

        version_data_gen = BulkInsertReqGenerator(version_data, items)
        items_with_storage_files = [i for i in items if i.hid is not None]
        storage_files_gen = None
        if items_with_storage_files:
            storage_files_gen = BulkInsertReqGenerator(
                storage_files,
                items_with_storage_files,
                skip_columns=[storage_files.c.date_origin],
                on_conflict_do_nothing=True
            )

        with session.begin():
            session.execute(
                version_data_gen.generate_tmpl(),
                version_data_gen.generate_values()
            )
            if storage_files_gen:
                session.execute(
                    storage_files_gen.generate_tmpl(),
                    storage_files_gen.generate_values()
                )

    def count_version_link_versions(self, uid, version_link_id):
        """Return the number of versions of the link."""
        uid = self.get_field_repr('uid', uid)
        params = {
            'uid': uid,
            'version_link_id': self.get_field_repr('version_link_id', version_link_id),
        }
        session = Session.create_from_uid(uid)
        result_proxy = session.execute(versioning_queries.SQL_VERSIONING_COUNT_VERSION_LINK_VERSIONS, params)
        return result_proxy.fetchone()[0]

    def count_version_link_versions_greater_than_dt(self, uid, version_link_id, border_dt):
        """Return the number of the link's versions created after ``border_dt``."""
        uid = self.get_field_repr('uid', uid)
        params = {
            'uid': uid,
            'version_link_id': self.get_field_repr('version_link_id', version_link_id),
            'date_created': self.get_field_repr('date_created', border_dt),
        }
        session = Session.create_from_uid(uid)
        result_proxy = session.execute(versioning_queries.SQL_VERSIONING_COUNT_VERSION_LINK_VERSIONS_GREATER_THAN_DT, params)
        return result_proxy.fetchone()[0]

    def count_by_uid(self, uid):
        """Return the number of version rows of the user."""
        uid = self.get_field_repr('uid', uid)
        session = Session.create_from_uid(uid)
        result_proxy = session.execute(versioning_queries.SQL_VERSIONING_COUNT_VERSIONS_BY_UID, {'uid': uid})
        return result_proxy.fetchone()[0]

    def remove_by_uid(self, uid):
        """Delete all version rows of the user."""
        uid = self.get_field_repr('uid', uid)
        session = Session.create_from_uid(uid)
        session.execute(versioning_queries.SQL_VERSIONING_DELETE_VERSIONS_BY_UID, {'uid': uid})

    def fetch_by_uid(self, uid):
        """Yield all version rows of the user as items."""
        uid = self.get_field_repr('uid', uid)
        session = Session.create_from_uid(uid)
        result_proxy = session.execute(versioning_queries.SQL_VERSIONING_GET_VERSIONS_BY_UID, {'uid': uid})
        for doc in result_proxy:
            yield self.doc_to_item(doc)

    def _get_versions(self, uid, version_link_id, border_dt, limit, only_checkpoints=False, order='DESC'):
        """Run one of four version queries chosen by ``border_dt`` and ``only_checkpoints``.

        :param border_dt: if truthy, a ``*_WITH_BORDER`` query is used and the
            value is bound as ``date_created``.
        :param only_checkpoints: selects the checkpoint-only queries.
        :param order: ``'DESC'`` or ``'ASC'``, passed to the query as the bind
            parameter ``order``. NOTE(review): a bind parameter cannot change
            ORDER BY direction in SQL; presumably the queries hard-code their
            ordering — confirm.
        :raises ValueError: if ``order`` is not ``'DESC'`` or ``'ASC'``.
        """
        uid = self.get_field_repr('uid', uid)
        version_link_id = self.get_field_repr('version_link_id', version_link_id)
        only_checkpoints = bool(only_checkpoints)
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if order not in ('DESC', 'ASC'):
            raise ValueError("order must be 'DESC' or 'ASC', got %r" % (order,))

        params = {'uid': uid, 'version_link_id': version_link_id, 'order': order, 'limit': limit}
        # Every branch assigns ``query``, so it can never be left unbound.
        if border_dt:
            params['date_created'] = self.get_field_repr('date_created', border_dt)
            if only_checkpoints:
                query = versioning_queries.SQL_VERSIONING_GET_CHECKPOINT_VERSIONS_WITH_BORDER
            else:
                query = versioning_queries.SQL_VERSIONING_GET_ALL_VERSIONS_WITH_BORDER
        elif only_checkpoints:
            query = versioning_queries.SQL_VERSIONING_GET_CHECKPOINT_VERSIONS
        else:
            query = versioning_queries.SQL_VERSIONING_GET_ALL_VERSIONS

        session = Session.create_from_uid(uid)
        result_proxy = session.execute(query, params)
        return [self.doc_to_item(i) for i in result_proxy]

    def fetch_by_stids(self, shard_name, stids, limit):
        """Yield at most ``limit`` versions on the shard that reference any of ``stids``."""
        session = Session.create_from_shard_id(shard_name)
        result_proxy = session.execute(
            versioning_queries.SQL_VERSIONING_GET_ALL_VERSIONS_BY_STIDS,
            {'stids': tuple(stids), 'limit': limit}
        )
        for doc in result_proxy:
            yield self.doc_to_item(doc)

    def fetch_expired_versions_on_shard(self, shard_endpoint, limit):
        """Yield at most ``limit`` expired versions, using naive local ``datetime.now()`` as the border."""
        border_dt = datetime.datetime.now()
        session = Session.create_from_shard_endpoint(shard_endpoint)
        result_proxy = session.execute(
            versioning_queries.SQL_VERSIONING_GET_EXPIRED_VERSIONS,
            {'date_to_remove': border_dt, 'limit': limit}
        )
        for doc in result_proxy:
            yield self.doc_to_item(doc)

    def bulk_delete(self, uid, items):
        """Delete ``items`` on the user's shard (see :meth:`_bulk_delete`)."""
        uid = self.get_field_repr('uid', uid)
        session = Session.create_from_uid(uid)
        self._bulk_delete(session, items)

    def bulk_delete_on_shard(self, shard_endpoint, items):
        """Delete ``items`` on the given shard endpoint (see :meth:`_bulk_delete`)."""
        session = Session.create_from_shard_endpoint(shard_endpoint)
        self._bulk_delete(session, items)

    def _bulk_delete(self, session, items):
        """Delete version rows by id and, in the same transaction, clean up their storage files.

        Storage-file cleanup is delegated to ``FileDAO.remove_hanging_storage_files``
        for the hids of items that have one (presumably it only removes files no
        longer referenced — confirm against FileDAO).
        """
        storage_file_ids = tuple(self.get_field_repr('hid', i.hid) for i in items if i.hid is not None)
        with session.begin():
            session.execute(
                versioning_queries.SQL_VERSIONING_DELETE_VERSIONS,
                {'ids': tuple(self.get_field_repr('id', i.id) for i in items)}
            )
            if storage_file_ids:
                FileDAO(session).remove_hanging_storage_files(storage_file_ids)
