from functools import wraps
from itertools import ifilter
from sandbox import sdk2
import sandbox.common.errors as errors
from sandbox.common.types.task import Status


# Default time limit for waiting on a stage's subtasks: one hour, in seconds.
DEFAULT_TIMEOUT = 60 * 60


def is_not_succeed(task):
    """Return True when ``task.status`` is not in ``Status.Group.SUCCEED``."""
    succeeded = task.status in Status.Group.SUCCEED
    return not succeeded


def get_ids(tasks):
    """Return a new list with the ``id`` of each task, in iteration order.

    ``tasks`` may be any iterable (list, generator, ``ifilter`` object, ...).
    """
    collected = []
    for each in tasks:
        collected.append(each.id)
    return collected


class SubtasksError(errors.TaskError):
    """Raised when one or more subtasks of a stage did not succeed.

    Two construction forms are accepted:

    * ``SubtasksError(stage_name, failed_ids)`` builds the message
      ``"Stage <stage_name> failed: <failed_ids>"``;
    * ``SubtasksError(message)`` uses an already formatted message as-is
      (kept for backward compatibility with existing callers).
    """

    # The original method was misspelled ``_init__``, so it was never invoked
    # and the formatting below was dead code.
    def __init__(self, stage_name, failed_ids=None):
        if failed_ids is None:
            # Legacy form: the single argument is the complete message.
            message = stage_name
        else:
            message = "Stage {0} failed: {1}".format(stage_name, failed_ids)
        super(SubtasksError, self).__init__(message)


class ParentTask(sdk2.Task):
    """Helper base class for a task that runs its subtasks in named stages.

    Each stage is identified by a name (the decorated method's name when the
    ``subtask``/``subtasks`` decorators are used). The first time a stage is
    reached, its subtasks are created, their ids are stored in
    ``Context.saved_subtasks``, they are enqueued and ``sdk2.WaitTask`` is
    raised. When the stage is reached again (presumably when Sandbox
    re-executes the task after the wait), the saved ids are used to look the
    subtasks up, their statuses are checked and the subtasks are returned.

    Usage example::

        class MyTask(ParentTask):
            @ParentTask.subtasks()
            def run_multiple_subtasks(self):
                return [
                    SomeTask1(self, description),
                    SomeTask2(self, description)
                ]

            @ParentTask.subtask()
            def run_single_subtask(self, resource1, resource2):
                return SomeTask3(self, description, resource1, resource2)

            def on_execute(self):
                task1, task2 = self.run_multiple_subtasks()
                task3 = self.run_single_subtask(
                    task1.Parameters.output_resource,
                    task2.Parameters.output_resource)
    """

    class Context(sdk2.Task.Context):
        # Stage name -> list of ids of the subtasks started for that stage.
        # NOTE(review): mutable class-level default; presumably sdk2 gives each
        # task instance its own copy of Context values — confirm.
        saved_subtasks = {}

    def _find_subtasks(self, ids):
        """Look up subtask objects by id, preserving the order of ``ids``."""
        # NOTE(review): ``first()`` presumably returns None for an unknown id,
        # which would surface later as an AttributeError — confirm.
        return [self.find(id=task_id).first() for task_id in ids]

    def _do_run(self, subtasks, timeout):
        """Enqueue ``subtasks`` and raise ``sdk2.WaitTask`` for all of them.

        The wait covers the FINISH and BREAK status groups with
        ``wait_all=True`` and the given ``timeout``. This method never returns
        normally.
        """
        for task in subtasks:
            task.enqueue()
        raise sdk2.WaitTask(
            subtasks,
            Status.Group.FINISH + Status.Group.BREAK,
            wait_all=True,
            timeout=timeout)

    @staticmethod
    def check_status(stage_name, subtasks):
        """Raise ``SubtasksError`` if any of ``subtasks`` did not succeed.

        The error message lists the ids of every subtask whose status is not
        in ``Status.Group.SUCCEED``.
        """
        failed_ids = get_ids(ifilter(is_not_succeed, subtasks))
        if failed_ids:
            raise SubtasksError("Stage {0} failed: {1}".format(
                stage_name, failed_ids))

    def run_subtasks(self, stage_name, tasks_factory, timeout=DEFAULT_TIMEOUT):
        """Run (or resume) the stage ``stage_name`` and return its subtasks.

        :param stage_name: key under which subtask ids are saved in Context.
        :param tasks_factory: zero-argument callable returning an iterable of
            subtasks; it is called only if the stage has no saved ids yet.
        :param timeout: wait timeout passed to ``sdk2.WaitTask``.
        :returns: list of subtask objects, in the order they were created.
        :raises sdk2.WaitTask: on the first run of the stage, after enqueueing.
        :raises SubtasksError: if any saved subtask did not succeed.
        """
        subtask_ids = self.Context.saved_subtasks.get(stage_name)
        if subtask_ids is None:
            # Materialize once: the result is iterated twice below (for the
            # ids and for enqueueing). A generator would be exhausted by
            # get_ids() and no task would ever be enqueued.
            subtasks = list(tasks_factory())
            ids = get_ids(subtasks)
            # Persist ids before enqueueing so a resumed run finds them.
            self.Context.saved_subtasks[stage_name] = ids
            self.Context.save()
            self.set_info("Starting stage {0}: {1}".format(stage_name, ids))
            self._do_run(subtasks, timeout)  # always raises sdk2.WaitTask
        subtasks = self._find_subtasks(subtask_ids)
        self.check_status(stage_name, subtasks)
        return subtasks

    def run_subtask(self, stage_name, task_factory, timeout=DEFAULT_TIMEOUT):
        """Single-task variant of ``run_subtasks``; returns the one subtask.

        ``task_factory`` is a zero-argument callable returning one subtask.
        """
        return self.run_subtasks(
            stage_name,
            lambda: [task_factory()],
            timeout
        )[0]

    @staticmethod
    def subtask(timeout=DEFAULT_TIMEOUT):
        """Decorator to run a single subtask.

        The decorated method should create and return one subtask. The
        decorator enqueues it, waits for it and returns the resulting task
        instance. The method's name is used as the stage name.

        NOTE: The decorated method is executed only once. Later calls return
        the task saved for that stage; their arguments are ignored.
        """
        def decorator(f):
            @wraps(f)
            def decorated(self, *args, **kwargs):
                return self.run_subtask(
                    f.__name__,
                    lambda: f(self, *args, **kwargs),
                    timeout=timeout
                )
            return decorated
        return decorator

    @staticmethod
    def subtasks(timeout=DEFAULT_TIMEOUT):
        """Decorator to run multiple subtasks.

        The decorated method should create and return an iterable of
        subtasks. The decorator enqueues them, waits for all of them to finish
        and returns the list of resulting task instances. The method's name is
        used as the stage name.

        NOTE: The decorated method is executed only once. Later calls return
        the tasks saved for that stage; their arguments are ignored.
        """
        def decorator(f):
            @wraps(f)
            def decorated(self, *args, **kwargs):
                return self.run_subtasks(
                    f.__name__,
                    lambda: f(self, *args, **kwargs),
                    timeout=timeout
                )
            return decorated
        return decorator
