import logging
import collections
import datetime as dt

import concurrent.futures

import sandbox.common.types.misc as ctm

from sandbox import common
from sandbox.services import base
from sandbox.yasandbox import controller
from sandbox.common.types import task as ctt
from sandbox.yasandbox.database import mapping

logger = logging.getLogger(__name__)


class AutoRestartTasks(base.SingletonService):
    """
    Service that periodically restarts tasks in the "TEMPORARY" status.

    On every tick, TEMPORARY tasks that are due for a restart are collected.
    They are then restarted in parallel through the Sandbox REST API
    (batch "start" operation) on a small thread pool.
    """

    def __init__(self, *args, **kwargs):
        super(AutoRestartTasks, self).__init__(*args, **kwargs)

        # For backward compatibility with old AutoRestartTasks to avoid various races
        self.zk_name = type(self).__name__

        # Restart requests are issued concurrently; see restart() / _restart_target().
        self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)

        # Token passed to the REST client; stays None when server auth is disabled.
        self.oauth_token = None
        if self.sandbox_config.server.auth.enabled:
            self.oauth_token = common.utils.read_settings_value_from_file(self.sandbox_config.server.auth.oauth.token)

    @property
    def tick_interval(self):
        """
        Return ``run_interval`` from the service config.

        It is also used, in seconds, as the fallback restart interval in tick().
        """
        return self.service_config["run_interval"]

    @staticmethod
    def _restart_target(oauth_token, id_):
        """
        Ask the Sandbox API to restart one task and log the per-task result.

        Executed on a worker thread of ``self.pool``.

        :param oauth_token: OAuth token for the REST client (None when auth is disabled)
        :param id_: ID of the task to restart
        """
        try:
            api = common.rest.ThreadLocalCachableClient(auth=oauth_token)
            results = api.batch.tasks.start.update(id=[id_], comment="Autorestart from TEMPORARY")

            for result in results:
                status = result["status"]
                if status in (ctm.BatchResultStatus.WARNING, ctm.BatchResultStatus.ERROR):
                    outcome = "failed to restart"
                    # The status string doubles as the logging level name ("WARNING"/"ERROR").
                    level = logging.getLevelName(status)
                else:
                    outcome = "has been restarted"
                    level = logging.DEBUG

                logger.log(level, "Task #%s %s: %s", result["id"], outcome, result["message"])

        except common.rest.Client.HTTPError as exc:
            logger.error("Failed to restart task #%s: %s", id_, exc)
        except Exception:
            # The executor would store this in the Future, and tick() never inspects
            # future results, so log it here instead of losing it silently.
            logger.exception("Unexpected error while restarting task #%s", id_)

    def restart(self, id_):
        """Schedule a restart of task ``id_`` on the thread pool and return its Future."""
        return self.pool.submit(self._restart_target, self.oauth_token, id_)

    def tick(self):
        """
        Restart every TEMPORARY task that is due and log summary counters.

        A task is due when one of these holds:
          * ``execution.auto_restart.left`` is None;
          * ``left`` is <= 0;
          * the restart interval (task's own, or ``tick_interval`` seconds) has
            elapsed since the first set timestamp among finished/updated/started.
        Tasks with ``left`` None or > 0 are skipped while a valid OAuth session
        is still cached for them.
        """
        now = dt.datetime.utcnow()
        counters = collections.Counter()

        # Materialize once: the rows are iterated below and then counted with len().
        temporary_tasks = list(
            mapping.Task.objects(execution__status=ctt.Status.TEMPORARY)
            .fast_scalar(
                "id",
                "execution__auto_restart__left",
                "execution__auto_restart__interval",
                "execution__time__finished",
                "time__updated",
                "execution__time__started",
            )
        )

        to_restart = []

        for (task_id, left, interval, finished, updated, started) in temporary_tasks:
            has_attempts_left = left is None or left > 0

            if left is not None and left > 0:
                delta = dt.timedelta(seconds=interval or self.tick_interval)
                last_activity = finished or updated or started
                if last_activity + delta > now:
                    # Restart interval has not elapsed yet.
                    continue

            if has_attempts_left:
                oauth = mapping.OAuthCache.objects(app_id=str(task_id)).first()
                if oauth and controller.OAuthCache.is_valid_token(oauth):
                    logger.warning("Task #%s still have valid session - skipping restart", task_id)
                    continue

            to_restart.append(task_id)
            counters["restart attempted"] += 1

        futures = [self.restart(id_) for id_ in to_restart]
        concurrent.futures.wait(futures)

        totally_checked = len(temporary_tasks)

        counters["skipped"] = totally_checked - sum(counters.itervalues())

        logger.info(
            "In total %d TEMPORARY task(s) checked, %s", totally_checked,
            ", ".join("{} - {}".format(v, k) for k, v in counters.iteritems())
        )

    def on_stop(self):
        """Stop the service and wait for in-flight restart requests to finish."""
        super(AutoRestartTasks, self).on_stop()
        self.pool.shutdown(wait=True)
