# -*- coding: utf-8 -*-
import time
import mpfs.engine.process
from mpfs.common.util import chunks2
from mpfs.common.util.rps_limiter import InMemoryRPSLimiter
from mpfs.core.services.queller_service import QuellerTask
from mpfs.core.versioning.logic.version import VersionManager
from mpfs.dao.base import get_all_shard_endpoints
from mpfs.dao.shard_endpoint import ShardEndpoint
from mpfs.core.versioning.dao.version_links import VersionLinkDAO

# Module logger: the default log obtained from mpfs.engine.process.
logger = mpfs.engine.process.get_default_log()


class VersionsCleanerWorker(QuellerTask):
    """Queller task that removes expired file versions on a single shard.

    One task is queued per shard via ``put``. A run repeatedly fetches up to
    ``get_versions_limit`` expired versions, removes them in chunks of
    ``batch_size`` under ``rps_limiter``, and finally deletes version links
    that no longer have versions.
    """
    # Key under which the serialized ShardEndpoint is stored in task kwargs.
    RAW_SHARD_ENDPOINT = 'raw_shard_endpoint'
    # Wall-clock budget for one run, in seconds (compared against time.time()).
    task_timeout = 3500
    # Class attribute, so all instances of this worker share one limiter.
    rps_limiter = InMemoryRPSLimiter(200)
    # Number of versions passed to one bulk_remove_versions_on_shard call.
    batch_size = 10
    # Maximum number of expired versions fetched per round.
    get_versions_limit = 1000
    # Batch size handed to delete_version_links_without_versions.
    version_links_batch = 10000

    def run(self, *args, **kwargs):
        """Remove expired versions on the shard passed in ``kwargs``.

        Keeps fetching rounds until one returns fewer than
        ``get_versions_limit`` versions, then deletes version links without
        versions. If ``task_timeout`` is exceeded, the run returns early and
        the link cleanup is skipped for this run.

        :param kwargs: must contain ``RAW_SHARD_ENDPOINT`` holding a value
            produced by ``ShardEndpoint.serialize``.
        """
        shard_endpoint = ShardEndpoint.parse(kwargs[self.RAW_SHARD_ENDPOINT])
        start_ts = time.time()
        while True:
            # Materialized because its length is needed after the batches run.
            versions = list(VersionManager.fetch_expired_versions_on_shard(shard_endpoint, self.get_versions_limit))

            for versions_batch in chunks2(versions, chunk_size=self.batch_size):
                # Throttle by the number of versions about to be removed.
                self.rps_limiter.block_until_allowed(requests_num=len(versions_batch))
                VersionManager.bulk_remove_versions_on_shard(shard_endpoint, versions_batch)

                # Checked after each batch, so the run stops between batches
                # once it is over budget.
                if time.time() - start_ts > self.task_timeout:
                    logger.warning('Timeout. Exit.')
                    return

            # Presumably a short round means nothing else is expired right now;
            # a full round triggers another fetch.
            if len(versions) < self.get_versions_limit:
                break
        VersionLinkDAO().delete_version_links_without_versions(shard_endpoint, self.version_links_batch)

    @classmethod
    def put(cls, shard_endpoint):
        """Queue a cleaning task for ``shard_endpoint``.

        :param shard_endpoint: shard whose expired versions should be removed.
        :type shard_endpoint: ShardEndpoint
        :raises TypeError: if ``shard_endpoint`` is not a ShardEndpoint.
        :return: whatever ``QuellerTask._put`` returns.
        """
        if not isinstance(shard_endpoint, ShardEndpoint):
            raise TypeError('shard_endpoint must be a ShardEndpoint, got %s' % type(shard_endpoint).__name__)
        return cls._put({cls.RAW_SHARD_ENDPOINT: shard_endpoint.serialize()})


class VersionsCleanerManager(object):
    """Queues one VersionsCleanerWorker task per shard endpoint."""

    @staticmethod
    def run():
        """Queue a cleaning task for every shard endpoint.

        Returns without queueing anything if any VersionsCleanerWorker task is
        still unfinished, so a new round is not started on top of an old one.
        """
        if VersionsCleanerWorker.not_finished_len() > 0:
            logger.info('Find active worker. Exit.')
            return

        shard_endpoints = get_all_shard_endpoints()
        for shard_endpoint in shard_endpoints:
            # ``put`` is a classmethod; constructing a worker instance is unnecessary.
            VersionsCleanerWorker.put(shard_endpoint)
            # Pass args to the logger instead of pre-formatting with ``%``.
            logger.info('Create tasks for shard %r', shard_endpoint)
