# -*- coding: utf-8 -*-
import logging

import sandbox.common.types.task as ctt
from sandbox import sdk2
from sandbox.projects.common import error_handlers as eh


# Mixin to restart child tasks
class ChildTaskRestarter(sdk2.Task):
    """
    Task that starts a set of child tasks and restarts the unsuccessful ones.

    The children to start are described in ``Context.child_tasks``. Every run
    ``N`` records the ids of the children it started in ``Context.run_N``
    (child name -> task id). After the children of a run finish, those whose
    status is in ``ctt.Status.Group.BREAK`` or equals ``ctt.Status.FAILURE``
    are started again in the next run, until ``Context.max_runs`` runs have
    been performed. If some children are still unsuccessful after the last
    run, the task is failed through ``eh.check_failed``.
    """

    class Context(sdk2.Task.Context):
        # Child name -> child description dict. Keys read by this class:
        #   "params"    - parameters passed to the child (required);
        #   "task_type" - type of the child, defaults to the child name;
        #   "sdk2"      - truthy when the child has to be created as an sdk2 task.
        # The id of the most recently started child is written back under "task_id".
        child_tasks = {}
        # 1-based number of the current run.
        run = 1
        # Upper bound on the total number of runs, the initial one included.
        max_runs = 3

    def on_execute(self):
        """Run the start / check / restart cycle for the child tasks."""
        self.run_child_tasks()

    def run_child_tasks(self):
        """
        Start the current run if needed, check its results and restart failures.

        The method is re-entered after every wait: ``Context.run`` and the
        ``Context.run_N`` fields tell it which step has to be performed next.

        :raises sdk2.WaitTask: while children of the current run are not finished
            or right after a new run has been started.
        """
        current_run_key = "run_{}".format(self.Context.run)
        # NOTE(review): relies on the Context yielding a falsy value for a field
        # that has never been set - confirm against the sdk2 Context behaviour.
        if not getattr(self.Context, current_run_key):
            self._execute_run(self.Context.run, self.Context.child_tasks)

        tasks = self._check_run(self.Context.run)
        logging.debug("tasks: %s", tasks)

        # Strict comparison: `run` starts at 1, so `<=` would allow
        # max_runs + 1 runs in total.
        if tasks and self.Context.run < self.Context.max_runs:
            self.Context.run += 1
            self.Context.save()
            # `tasks` is not empty here, so this call always ends with WaitTask.
            self._execute_run(self.Context.run, tasks)

        if tasks:
            eh.check_failed("Failed to restart child tasks. Execute {} runs".format(self.Context.run))

    def _execute_run(self, run_number, tasks):
        """
        Start a child task for every entry of ``tasks`` and wait for all of them.

        The ids of the started children are stored in ``Context.run_<run_number>``
        and in ``Context.child_tasks[<name>]["task_id"]``.

        :param run_number: number of the run being started
        :param tasks: child name -> child description dict (see ``Context.child_tasks``)
        :raises sdk2.WaitTask: when at least one child has been started
        """
        run_key = "run_{}".format(run_number)
        setattr(self.Context, run_key, {})
        run_context = getattr(self.Context, run_key)
        logging.info("Context")
        # Log the context of the run being started instead of the hard-coded first one.
        logging.info(run_context)

        for task_name, task_info in tasks.items():
            task_type = task_info.get("task_type", task_name)
            child_task_id = self._create_subtask(task_type, task_name, task_info)
            self.Context.child_tasks[task_name]["task_id"] = child_task_id
            run_context[task_name] = child_task_id
        self.Context.save()

        if not run_context:
            # Nothing has been started, hence there is nothing to wait for.
            logging.info("Run #%s has no child tasks to start", run_number)
            return

        finish_statuses = tuple(ctt.Status.Group.FINISH) + tuple(ctt.Status.Group.BREAK)
        raise sdk2.WaitTask(list(run_context.values()), finish_statuses)

    def _check_run(self, run_number):
        """
        Collect the children of the given run that did not finish successfully.

        :param run_number: number of the run to check
        :return: child name -> child description dict for every child whose status
            is in ``ctt.Status.Group.BREAK`` or equals ``ctt.Status.FAILURE``
        :raises sdk2.WaitTask: for the first child found in a status outside of the
            FINISH and BREAK groups
        """
        finish_statuses = tuple(ctt.Status.Group.FINISH) + tuple(ctt.Status.Group.BREAK)
        run_context = getattr(self.Context, "run_{}".format(run_number))
        logging.debug("run Context")
        logging.debug(run_context)

        broken_child_tasks = {}
        for task, task_id in run_context.items():
            child_task = sdk2.Task[task_id]
            if child_task.status not in finish_statuses:
                # The child is still in progress: go back to waiting for it.
                self.Context.save()
                raise sdk2.WaitTask(task_id, finish_statuses)
            if child_task.status in ctt.Status.Group.BREAK or child_task.status == ctt.Status.FAILURE:
                broken_child_tasks[task] = self.Context.child_tasks[task]
        return broken_child_tasks

    def _create_subtask_sdk1(self, task_type, description, parameters, requirements=None):
        """
        Create, configure and start a child task through ``self.server``.

        The child gets the owner of this task and SERVICE/LOW priority.

        :param task_type: type of the task to create
        :param description: description of the created task
        :param parameters: value passed as the "context" of the created task
        :param requirements: optional value for the "requirements" field of the task
        :return: id of the created task
        """
        task = self.server.task({"type": task_type, "context": parameters, "children": True})
        update = {
            "description": description,
            "owner": self.owner,
            "priority": {"class": ctt.Priority.Class.SERVICE, "subclass": ctt.Priority.Subclass.LOW},
        }
        if requirements:
            update["requirements"] = requirements
        self.server.task[task["id"]].update(update)
        self.server.batch.tasks.start.update([task["id"]])
        return task["id"]

    def _create_subtask(self, task_type, task_name, task_info):
        """
        Start one child task and return its id.

        The child is created as an sdk2 task when ``task_info["sdk2"]`` is truthy
        and through ``_create_subtask_sdk1`` otherwise.

        :param task_type: type of the task to create
        :param task_name: child name, used in the description only
        :param task_info: child description dict; ``task_info["params"]`` is required
        :return: id of the started task
        """
        description = "{0}: child task {1} Run #{2}".format(
            self.Parameters.description, task_name, self.Context.run
        )
        if not task_info.get("sdk2"):
            return self._create_subtask_sdk1(task_type, description, task_info["params"])

        child_task_class = sdk2.Task[task_type]
        logging.debug(task_info)
        child_task = child_task_class(
            self,
            description=description,
            priority=ctt.Priority(ctt.Priority.Class.SERVICE, ctt.Priority.Subclass.LOW),
            **task_info["params"]
        ).enqueue()
        return child_task.id
