import time
import logging

from sandbox.common.types import task as ctt

from sandbox import sdk2


class Hydra(sdk2.Task):
    """ Task that creates many child tasks to acquire given semaphore completely """

    class Requirements(sdk2.Requirements):
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            # Intentionally left empty: no shared caches.
            pass

    class Parameters(sdk2.Parameters):
        semaphore_name = sdk2.parameters.String(
            "Semaphore name",
            default="hydra_semaphore",
        )
        semaphore_capacity = sdk2.parameters.Integer(
            "Initial semaphore capacity",
            default=10,
        )
        sleep_seconds = sdk2.parameters.Integer(
            "Time to sleep in seconds",
            default=20 * 60,  # twenty minutes
        )

    def on_enqueue(self):
        """Attach a single semaphore acquisition request to the task requirements.

        The request is built from the ``semaphore_name`` and
        ``semaphore_capacity`` parameters.
        """
        acquire = ctt.Semaphores.Acquire(
            name=self.Parameters.semaphore_name,
            capacity=self.Parameters.semaphore_capacity,
        )
        self.Requirements.semaphores = ctt.Semaphores(acquires=[acquire])

    def on_execute(self):
        """Start additional tasks contending for the semaphore, then sleep.

        Steps:
          1. Look up the id of the semaphore named by ``semaphore_name``.
          2. Count enqueued ``Hydra`` tasks of the same owner.
          3. Create and start as many extra tasks as the computed shortfall.
          4. Sleep for ``sleep_seconds``.

        Raises:
            ValueError: if no listed semaphore has the requested name.
        """
        wanted_name = self.Parameters.semaphore_name

        listing = self.server.semaphore.read(name=wanted_name, owner=self.owner, limit=100)
        logging.info("Semaphores: %s", listing)

        # When several entries carry the requested name, the last one is used.
        matching_ids = [entry["id"] for entry in listing["items"] if entry["name"] == wanted_name]
        sem_id = matching_ids[-1] if matching_ids else None
        if sem_id is None:
            raise ValueError("Semaphore '{}' is not found".format(wanted_name))

        # Only the "total" counter is needed, hence limit=0.
        enqueued_query = dict(
            owner=self.owner,
            type=Hydra.type,
            status=ctt.TaskStatus.ENQUEUED,
            children=True,
            limit=0,
        )
        enqueued_total = self.server.task.read(**enqueued_query)["total"]
        logging.info("Contenders: %s", enqueued_total)

        details = self.server.semaphore[sem_id].read()
        logging.info("Semaphore info: %s", details)

        # Target is 110% of the capacity (truncated to int), reduced by the
        # semaphore's reported "value" and by the already enqueued contenders.
        # NOTE(review): "value" is presumably the currently acquired amount - confirm.
        target = int(details["capacity"] * 1.1)
        shortfall = target - details["value"] - enqueued_total

        # A non-positive shortfall yields an empty range, so nothing is started.
        for _unused in range(shortfall):
            # NOTE(review): source=self.id presumably makes the new task a copy of this one - confirm.
            spawned = self.server.task(source=self.id, owner=self.owner, children=True, notifications=[])
            started = self.server.batch.tasks.start.update([spawned["id"]])
            logging.info("Starting #%s: %s", spawned["id"], started)

        time.sleep(self.Parameters.sleep_seconds)
