import json
import logging
import traceback

from . import helpers
from . import params
from . import result
from . import subtasks

import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
import sandbox.projects.common.binary_task as cbt
import sandbox.sdk2 as sdk2

logger = logging.getLogger(__name__)


class TaskRequirements(sdk2.Requirements):
    """Resource requirements of the parent task itself.

    The parent only spawns, waits for and summarizes child tasks; the children
    receive their own requirements separately when they are created.
    """

    # Size multiplier. Assumes sdk2 measures ram/disk_space in MiB, so one
    # "GB" is 1024 units — TODO confirm against sdk2.Requirements.
    GB = 1024

    cores = 1
    ram = 2 * GB
    disk_space = 8 * GB

    # Host selection: both tag expressions are combined with `&`
    # (presumably "generic AND Linux" hosts — confirm ctc.Tag query semantics).
    client_tags = ctc.Tag.GENERIC & ctc.Tag.Group.LINUX

    # MULTISLOT
    # NOTE(review): an empty Caches declaration presumably tells Sandbox the
    # task uses no shared caches, which is what permits multislot hosts — confirm.
    class Caches(sdk2.Requirements.Caches):
        pass


class TaskContext(sdk2.Context):
    """Context of YaTestParent2 (presumably persisted across task re-entries — confirm)."""

    # IDs of the child tasks the parent waits for: either parsed from the
    # debug_task_ids parameter or returned by spawn_tasks. The empty list is a
    # class-level default; the visible code only ever reassigns this field and
    # never mutates the list in place, so the shared default is not modified.
    child_task_ids = []


class YaTestParent2(cbt.LastBinaryTaskRelease, sdk2.Task):
    """Parent task that spawns child tasks, waits for them and reports results.

    ``on_execute`` may run several times, because the task is re-entered after
    every ``sdk2.WaitTask``:

    1. On the first run only, ``start_subtasks`` either adopts existing task IDs
       from the ``debug_task_ids`` parameter or spawns and enqueues new children.
       It then suspends the task until all children reach a terminal status.
    2. On every run, ``wait_for_tasks`` re-reads the child statuses and suspends
       again if any child is still not terminal.
    3. Once all children are terminal, ``process_results`` is invoked.

    Errors other than the wait signal are not re-raised. They are shown in the
    task info and passed to ``process_results``.
    """

    Requirements = TaskRequirements
    Context = TaskContext
    Parameters = params.Parameters

    def on_execute(self):
        """Start the children once, wait for them, then process the results.

        A failure while starting or waiting does not fail the task. The
        formatted traceback is written to the task info and handed to
        ``process_results`` so that result reporting still happens.
        """
        # A child in any of these statuses is considered done (successfully or not).
        self.terminate_statuses = list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK)
        try:
            # memoize_stage guarantees children are started only on the first run.
            with self.memoize_stage.subtask:
                self.start_subtasks()
            self.wait_for_tasks()
        except sdk2.WaitTask:
            # Suspending is control flow, not an error. Never let the broad
            # handler below swallow it, regardless of WaitTask's base class.
            raise
        except Exception:
            logger.exception("Failed to start or wait for child tasks")
            error = traceback.format_exc()
            self.set_info(error)
            self.process_results(error)  # Do not rethrow an exception
        else:
            self.process_results()

    def start_subtasks(self):
        """Create (or adopt) the child tasks and suspend until they finish.

        If ``debug_task_ids`` is set, it is parsed as a ``;``-separated list of
        existing task IDs and no new tasks are spawned. Otherwise the child
        description comes from ``subtasks.get_task``, the children are spawned,
        ``result.process_start`` is notified and the task suspends.

        Raises:
            ValueError: if a non-blank ``debug_task_ids`` entry is not an integer.
        """
        if self.Parameters.debug_task_ids:
            self.Context.child_task_ids = self._parse_task_ids(self.Parameters.debug_task_ids)
            return

        task = subtasks.get_task(self)
        ctor = sdk2.Task[task.task_type]

        logger.debug("Task constructor: %s (task description: %s)", ctor, type(task))
        logger.debug("Child task requirements:\n%s", json.dumps(task.requirements, indent=4, sort_keys=True))
        logger.debug("Child task context:\n%s", json.dumps(task.context, indent=4, sort_keys=True))

        self.Context.child_task_ids = self.spawn_tasks(ctor, task)

        result.process_start(self)

        self.wait_subtasks()

    @staticmethod
    def _parse_task_ids(raw):
        """Parse ``"1; 2;3;"`` into ``[1, 2, 3]``, ignoring blank entries.

        Blank entries can come from a trailing or doubled ``;``. Each remaining
        entry is converted with ``int()``, which raises ValueError on garbage.
        """
        tokens = (part.strip() for part in raw.split(';'))
        return [int(token) for token in tokens if token]

    def spawn_tasks(self, ctor, task):
        """Create and enqueue ``helpers.get_subtasks_count(self)`` child tasks.

        Each child is created through *ctor* with this task as its parent, a
        numbered description, this task's owner and tags, the requirements
        from *task*, and ``task.context`` expanded as extra keyword arguments.

        Returns:
            list of the created task IDs, in creation order.
        """
        ids = []
        for i in range(1, helpers.get_subtasks_count(self) + 1):
            inst = ctor(
                self,
                description="{} Run #{}".format(self.Parameters.description, i),
                owner=self.owner,
                tags=self.Parameters.tags,
                __requirements__=task.requirements,
                **task.context
            )
            inst.enqueue()
            ids.append(inst.id)

        return ids

    def wait_for_tasks(self):
        """Suspend again if any child task has not reached a terminal status.

        The current status of every child is re-read from the server, one
        request per child, rather than assumed from the previous run.
        """
        statuses = {tid: self.server.task[tid].read()['status'] for tid in self.Context.child_task_ids}
        logger.debug("Child statuses: %s", statuses)

        # Task might be rescheduled - check statuses before processing
        if set(statuses.values()) - set(self.terminate_statuses):
            self.wait_subtasks()

    def process_results(self, error=None):
        """Delegate to ``result.process_results``.

        *error* is the formatted traceback of a failure in ``on_execute``,
        or None when starting and waiting succeeded.
        """
        result.process_results(self, error)

    def wait_subtasks(self):
        """Suspend the task until every child reaches a terminal status.

        Raises ``sdk2.WaitTask`` (the framework's suspend signal) when there
        are children. With no children there is nothing to wait for, so the
        method returns and the caller proceeds to result processing.
        """
        if not self.Context.child_task_ids:
            logger.debug("No child tasks to wait for")
            return
        raise sdk2.WaitTask(
            tasks=self.Context.child_task_ids,
            statuses=self.terminate_statuses,
            wait_all=True,
        )
