import logging
import multiprocessing.pool

from sandbox.common import itertools as common_itertools

from sandbox.services import base
from sandbox.yasandbox.database import mapping

logger = logging.getLogger(__name__)


class TaskTagsChecker(base.SingletonService):
    """Reconcile cached per-tag hit counters with the actual task collection.

    On every tick the service loads all ``(tag, hits)`` pairs from
    ``TaskTagCache``, recounts in chunks how many tasks carry each tag, and
    then:

    * schedules an update for every tag whose cached counter differs from the
      recounted one;
    * removes cache entries for tags that no task references any more.
    """

    # Interval between ticks; presumably seconds -- confirm against
    # base.SingletonService.
    tick_interval = 600

    # Size of the thread pool that writes updated counters back to the cache.
    MAX_WORKERS = 20
    # Maximum number of tags checked by a single aggregation query.
    TAGS_CHUNK_SIZE = 5000

    @staticmethod
    def _worker_proc(tag, hits):
        """Store ``hits`` as the cached hit counter of ``tag``.

        Runs inside the thread pool. The caller never inspects the async
        result, so a failure is logged here before being re-raised; otherwise
        it would be lost silently.
        """
        try:
            mapping.TaskTagCache.objects(tag=tag).update(set__hits=hits)
        except Exception:
            logger.exception("Failed to update hits for tag %r", tag)
            raise

    @staticmethod
    def _hits_pipeline(tags_chunk):
        """Build the aggregation pipeline counting tasks per tag of the chunk.

        :param tags_chunk: list of tags to count tasks for
        :return: list of aggregation stages; every resulting document has the
            shape ``{"_id": {"tags": <tag>}, "hits": <number of tasks>}``
        """
        chunk_filter = {"tags": {"$in": tags_chunk}}
        only_id_and_tags = {"_id": 1, "tags": 1}
        return [
            # Narrow down to the tasks having at least one tag of the chunk.
            {"$match": dict(chunk_filter)},
            {"$project": dict(only_id_and_tags)},
            # One document per (task, tag) pair.
            {"$unwind": "$tags"},
            # Drop the pairs whose tag does not belong to the chunk.
            {"$match": dict(chunk_filter)},
            {"$project": dict(only_id_and_tags)},
            {"$group": {
                "_id": {"tags": "$tags"},
                "hits": {"$sum": 1},
            }},
        ]

    def tick(self):
        """Recount tasks per cached tag, fix counters and drop unused tags."""
        tags = dict(mapping.TaskTagCache.objects(
            read_preference=mapping.ReadPreference.SECONDARY
        ).fast_scalar("tag", "hits"))
        logger.info("There are %s tags to check", len(tags))
        thread_pool = multiprocessing.pool.ThreadPool(self.MAX_WORKERS)
        try:
            with mapping.switch_db(mapping.Task, mapping.ReadPreference.SECONDARY) as Task:
                for tags_chunk in common_itertools.chunker(list(tags), self.TAGS_CHUNK_SIZE):
                    pipeline = self._hits_pipeline(tags_chunk)
                    seen = set()
                    updates = 0
                    for item in Task.aggregate(pipeline, allowDiskUse=True):
                        tag = item["_id"]["tags"]
                        hits = item["hits"]
                        seen.add(tag)
                        if tags[tag] != hits:
                            thread_pool.apply_async(self._worker_proc, (tag, hits))
                            updates += 1
                    logger.info("Added %s tags to update", updates)
                    # `$group` emits a document only for tags found on at least
                    # one task, so a zero counter is never reported: unused
                    # tags are exactly those missing from the result.
                    # NOTE(review): counts are read from a SECONDARY, so a tag
                    # attached to a task moments ago may look unused because of
                    # replication lag -- confirm this is acceptable.
                    to_delete = [tag for tag in tags_chunk if tag not in seen]
                    if to_delete:
                        logger.info("%s tags has no tasks, clearing them from cache", len(to_delete))
                        mapping.TaskTagCache.objects(tag__in=to_delete).delete()
            logger.info("Waiting all updates to complete")
        finally:
            # Always shut the pool down, otherwise every failed tick leaks
            # its worker threads.
            thread_pool.close()
            thread_pool.join()
