import logging
import datetime as dt

import concurrent.futures

from sandbox import common

from sandbox.services import base
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping

logger = logging.getLogger(__name__)


class HostsCacheUpdater(base.SingletonService):
    """
    Actualize host lists matching tags expressions on a periodic basis.

    On every tick the service:
      * drops cache entries not accessed for ``CACHE_ITEM_TTL`` hours;
      * recalculates hosts for every remaining client tags expression and saves (or deletes, when empty)
        the entries whose host set changed;
      * for every changed expression, recalculates host lists of the queued tasks using that expression,
        in chunks, on a thread pool, and pushes them back to the queue.
    """

    CACHE_ITEM_TTL = 48  # in hours
    REST_API_MIN_TIMEOUT = 180  # in seconds
    REST_API_TOTAL_TIMEOUT = 360  # in seconds
    TASKS_HOSTS_CHUNK_SIZE = 20  # max number of task ids handled by a single `update_tasks_hosts` call
    # TODO: increase after SANDBOX-8613
    MAX_WORKERS = 16  # max number of threads for calculating hosts lists for tasks

    # NOTE(review): presumably both are in seconds and consumed by `base.SingletonService` -- confirm
    tick_interval = 600
    notification_timeout = 80

    def __init__(self, *args, **kws):
        super(HostsCacheUpdater, self).__init__(*args, **kws)
        # The OAuth token is read from disk only when server authentication is enabled
        oauth_token = (
            common.utils.read_settings_value_from_file(common.config.Registry().server.auth.oauth.token)
            if common.config.Registry().server.auth.enabled else
            None
        )
        self.api = common.rest.ThreadLocalCachableClient(
            auth=oauth_token, min_timeout=self.REST_API_MIN_TIMEOUT, total_wait=self.REST_API_TOTAL_TIMEOUT
        )
        # Pool reused across ticks for per-task host list recalculation; shut down in `on_stop`
        self.__workers = concurrent.futures.ThreadPoolExecutor(max_workers=self.MAX_WORKERS)

    def update_tasks_hosts(self, tids, cache):
        """
        Recalculate host lists of the given tasks and sync them to the task queue.

        Runs on the worker thread pool.

        :param tids: ids of the tasks to update
        :param cache: `controller.TaskQueue.HostsCache` instance shared by all chunks of the current tick
        """
        to_sync = []
        for task in mapping.Task.objects(id__in=tids):
            # Only the second item of `task_hosts_list` (the host list) is needed here
            hosts = controller.TaskQueue.task_hosts_list(task, cache)[1]
            # NOTE(review): the `None` positions are presumably queue fields left untouched by `sync` -- confirm
            to_sync.append((task.id, None, controller.TaskQueue.compress_hosts(hosts), None, task.score or 0))

        controller.TaskQueue.qclient.sync(to_sync)

    @staticmethod
    def _wait_tasks_hosts(futures):
        """
        Wait for submitted host list recalculations and log every one that failed.

        `concurrent.futures.wait` does not propagate exceptions raised in workers, so each future's result
        is checked explicitly. Failures are logged instead of raised so one broken chunk does not abort the tick.

        :param futures: mapping of submitted future -> chunk of task ids it processes
        """
        concurrent.futures.wait(list(futures))
        for future, tids in futures.items():
            try:
                future.result()
            except Exception:
                logger.exception("Failed to update hosts lists for tasks %r", tids)

    def tick(self):
        """Run one cache refresh cycle (see the class docstring) and log per-stage timings."""
        with common.utils.Timer() as timer:
            with timer["removing"]:
                logger.info("Removing outdated items from cache")
                mapping.ClientTagsToHostsCache.objects(
                    accessed__lt=dt.datetime.utcnow() - dt.timedelta(hours=self.CACHE_ITEM_TTL)
                ).delete()
            tasks_hosts_futures = {}  # submitted future -> chunk of task ids it processes
            with timer["updating"] as timer_updating:
                with timer_updating["fetching"]:
                    objs = list(mapping.ClientTagsToHostsCache.objects())
                logger.debug("Updating %s item(s)", len(objs))
                with timer_updating["queue"]:
                    queue = controller.TaskQueue.qclient.queue(secondary=True)
                cache = controller.TaskQueue.HostsCache()
                # Task ids already submitted during this tick, so each task is recalculated at most once
                scheduled_tids = set()

                for obj in objs:
                    with timer_updating["hosts"]:
                        hosts = cache.client_tags.hosts(obj.client_tags)
                        new_hosts, old_hosts = set(hosts), set(obj.hosts)
                        if new_hosts == old_hosts:
                            continue
                        logger.debug(
                            "Hosts for %s changed: added %s, removed %s",
                            obj.client_tags, sorted(new_hosts - old_hosts), sorted(old_hosts - new_hosts)
                        )
                        if hosts:
                            obj.hosts = hosts
                            obj.save()
                        else:
                            logger.debug("No hosts for %s, removing from cache", obj.client_tags)
                            obj.delete()
                        # Queued tasks using exactly this tags expression need their host lists recalculated
                        refresh_queue_tids = {
                            item.task_id
                            for item in queue
                            if item.task_info.client_tags is not None and item.task_info.client_tags == obj.client_tags
                        }
                        refresh_queue_tids -= scheduled_tids
                        if not refresh_queue_tids:
                            continue
                        for tids in common.utils.chunker(list(refresh_queue_tids), self.TASKS_HOSTS_CHUNK_SIZE):
                            logger.info("Tasks in queue to update: %r", tids)
                            future = self.__workers.submit(self.update_tasks_hosts, tids, cache)
                            tasks_hosts_futures[future] = tids
                        scheduled_tids.update(refresh_queue_tids)

            if tasks_hosts_futures:
                with timer["tasks_hosts"]:
                    logger.info("Waiting calculating of tasks hosts to complete")
                    self._wait_tasks_hosts(tasks_hosts_futures)
        logger.info("Queue updated in %s", timer)

    def on_stop(self):
        """Shut down the worker pool, waiting for already submitted recalculations to finish."""
        self.__workers.shutdown(wait=True)
