import os
import subprocess
import tempfile

import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
from sandbox.sandboxsdk import environments

from sandbox import sdk2


class RunCommandInContainerBatch(sdk2.Task):
    """
    Task spawns several RUN_COMMAND_IN_CONTAINER children,
    waits for them and then executes some command itself.

    The parent's own command is written to a temporary executable script
    and run through the shell; the ids of the children that finished with
    status SUCCESS are appended to it as positional arguments.
    """

    class Context(sdk2.Task.Context):
        # Ids of the RUN_COMMAND_IN_CONTAINER child tasks created by this task.
        workers = []

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        command = sdk2.parameters.String("Command to execute", required=True, multiline=True)
        venv_pip = sdk2.parameters.String("Venv pip packages comma-separated list", required=False)
        child_count = sdk2.parameters.Integer("Child task count", required=True)
        child_script = sdk2.parameters.String("Command to execute in child tasks", required=True, multiline=True)
        child_container = sdk2.parameters.Integer("Container for child tasks", required=True)
        child_disk_space = sdk2.parameters.Integer("Disk space for child tasks", required=True, default=32212254720)  # 30GB
        child_ram = sdk2.parameters.Integer("Ram for child tasks", required=True, default=12884901888)  # 12GB
        child_ncpu = sdk2.parameters.Integer("NCpu for child tasks", required=True, default=16)
        child_kill_timeout = sdk2.parameters.Integer("Kill_timeout for child tasks", required=True, default=3 * 3600)  # 3 hours
        child_client_tags = sdk2.parameters.String("Client_tags for child tasks", required=True, default='GENERIC & INTEL_E5_2650')
        child_privileged = sdk2.parameters.Bool("Privileged flag for child tasks", required=True, default=True)
        vault_env = sdk2.parameters.Dict("Vault items to put in the environment")

        # NOTE(review): disk_space, ram and cores are not read anywhere in this
        # class (only client_tags is applied in on_save) — confirm whether they
        # are meant to feed this task's own Requirements.
        disk_space = sdk2.parameters.Integer(
            "Required disk space", default=322122547  # 0.3GB
        )
        ram = sdk2.parameters.Integer(
            "Required RAM size", default=1288490188  # 1.2GB
        )
        cores = sdk2.parameters.Integer("Required CPU cores", default=1)
        client_tags = sdk2.parameters.ClientTags(
            "Client tags",
            default=ctc.Tag.GENERIC
        )

    def on_save(self):
        """Copy the client_tags parameter into this task's requirements."""
        self.Requirements.client_tags = self.Parameters.client_tags

    def _create_child_tasks(self):
        """Create and enqueue `child_count` RUN_COMMAND_IN_CONTAINER children.

        Each child's requirements are overridden via the server API with the
        child_* parameters, and its id is recorded in Context.workers.
        Returns the list of created task objects.
        """
        task_type = sdk2.Task["RUN_COMMAND_IN_CONTAINER"]
        tasks = []
        for index in range(self.Parameters.child_count):
            task = task_type(
                task_type.current,
                description="Child task #" + str(index),
                priority=ctt.Priority(ctt.Priority.Class.SERVICE, ctt.Priority.Subclass.NORMAL),
                script=self.Parameters.child_script,
                container=self.Parameters.child_container,
                kill_timeout=self.Parameters.child_kill_timeout
            )
            sdk2.Task.server.task[task.id].update(
                requirements={
                    "disk_space": self.Parameters.child_disk_space,
                    'ram': self.Parameters.child_ram,
                    'ncpu': self.Parameters.child_ncpu,
                    'client_tags': self.Parameters.child_client_tags,
                    'privileged': self.Parameters.child_privileged
                }
            )

            tasks.append(task)
            self.Context.workers.append(task.id)
        # Enqueue only after every child has been created and configured.
        for task in tasks:
            task.enqueue()
        return tasks

    @staticmethod
    def _split_packages(spec):
        """Split a comma-separated package list into clean package names.

        Whitespace around entries is stripped and empty entries (e.g. from a
        trailing comma or an unset parameter) are dropped.
        """
        if not spec:
            return []
        return [name.strip() for name in spec.split(",") if name.strip()]

    def _build_command_env(self):
        """Return a copy of os.environ extended with the configured vault secrets."""
        command_env = os.environ.copy()
        # vault_env is optional; treat an unset value as "no secrets".
        for vault_item, env_name in (self.Parameters.vault_env or {}).items():
            command_env[env_name] = sdk2.Vault.data(vault_item)
        return command_env

    def _successful_worker_ids(self):
        """Return ids (as strings) of child tasks that finished with SUCCESS."""
        return [
            str(worker_id)
            for worker_id in self.Context.workers
            if sdk2.Task[worker_id].status == ctt.Status.SUCCESS
        ]

    def _run_command(self, command_env):
        """Write the command to an executable temp script and run it.

        Ids of successful children are passed as arguments. Raises
        subprocess.CalledProcessError if the script exits with non-zero status.
        """
        script_text = self.Parameters.command
        # The temp file is opened in binary mode: encode text explicitly so a
        # command containing non-ASCII characters does not fail to write.
        if not isinstance(script_text, bytes):
            script_text = script_text.encode("utf-8")

        script = tempfile.NamedTemporaryFile(delete=False)

        # File must be closed before calling subprocess, so we have to
        # manage its lifetime ourselves and there's no point in using `with'.
        try:
            script.write(script_text)
            script.close()
            os.chmod(script.name, 0o770)

            # shell=True lets scripts without a shebang run via /bin/sh.
            command = " ".join([script.name] + self._successful_worker_ids())

            with sdk2.helpers.ProcessLog(self, logger="command") as pl:
                subprocess.check_call(
                    command,
                    env=command_env,
                    shell=True,
                    stdout=pl.stdout,
                    stderr=pl.stderr,
                    close_fds=True
                )
        finally:
            # close() is idempotent; it matters when write() raised.
            script.close()
            os.remove(script.name)

    def on_execute(self):
        """Spawn children once, wait for them, then run the command in a venv."""
        command_env = self._build_command_env()

        with self.memoize_stage.create_children:
            subtasks = self._create_child_tasks()
            # With child_count == 0 there is nothing to wait for.
            if subtasks:
                raise sdk2.WaitTask(subtasks, ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=True)

        with environments.VirtualEnvironment() as venv:
            command_env["VENV"] = venv.executable
            for package in self._split_packages(self.Parameters.venv_pip):
                venv.pip(package)

            # The command runs while the virtual environment is still alive.
            self._run_command(command_env)
