# -*- coding: utf-8 -*-

import logging

from sandbox import sdk2
from sandbox.common.types import task as ctt
from sandbox.common.errors import TaskStop


# Maximum number of tasks requested from `sdk2.Task.find` in one query
# (`find` needs an explicit limit). Equals 1024.
WAIT_TASKS_LIMIT = 2 ** 10


class TaskDependenciesOptions(object):
    """
    Plain holder for the options consumed by ``TaskDependenciesManager``.

    :param wait_output_parameters: mapping of awaited task id to a waking
        parameter index (or a list of indices); ``None`` means nothing to wait for.
    :param wait_timeout: timeout passed to ``sdk2.WaitOutput``; ``None`` lets the
        manager fall back to another value.
    :param waking_output_parameter: index of this task's own waking output
        parameter; ``None`` means nothing to set.
    """

    def __init__(self, wait_output_parameters=None, wait_timeout=None, waking_output_parameter=None):
        (
            self.wait_output_parameters,
            self.wait_timeout,
            self.waking_output_parameter,
        ) = (wait_output_parameters, wait_timeout, waking_output_parameter)


class TaskDependenciesManager(object):
    """
    Coordinates a task with the tasks it depends on via "waking" output parameters.

    Dependencies come from the ``wait_output_parameters`` option: a mapping
    ``{task_id: param_idx}`` or ``{task_id: [param_idx, ...]}``. Each index ``N``
    names the output parameter ``waking_output_parameter_N`` of the awaited task.
    That parameter is expected to hold a dict with a ``'status'`` key, as written
    by :meth:`awake`.
    """

    def __init__(self, task, options):
        """
        :param task: the task being managed.
        :type task: sandbox.sdk2.Task
        :param options: keyword arguments forwarded to ``TaskDependenciesOptions``.
        :type options: dict
        """
        self.task = task
        self.options = TaskDependenciesOptions(**options)

    def wait(self):
        """
        Waits for the output parameters this task depends on.

        If no dependencies are configured, returns immediately. Otherwise, inside
        the ``wait_task_dependencies`` memoize stage, it raises ``sdk2.WaitOutput``
        for all awaited parameters. The check after the stage is reachable only
        when the stage body is skipped, i.e. after the task is resumed.

        :raises TaskStop: if an awaited task did not set its waking parameter,
            or reported a status other than SUCCESS.
        """
        wait_output = self._get_wait_output_parameters()

        if not wait_output:
            logging.debug('no tasks to wait')
            return

        with self.task.memoize_stage.wait_task_dependencies(commit_on_entrance=False):
            # Fall back to the task's kill_timeout when no explicit wait timeout is set.
            timeout = self.options.wait_timeout or self.task.Parameters.kill_timeout

            raise sdk2.WaitOutput(wait_output, wait_all=True, timeout=timeout)

        self._check_awaited_tasks(wait_output)

    def awake(self, status):
        """
        "Wakes up" dependent tasks.

        Sets this task's waking output parameter to ``{'status': status}``. It does
        nothing if no waking parameter is configured, or if the parameter is
        already set (not ``None``).

        :param status: status value to publish (compared against
            ``ctt.Status.SUCCESS`` by waiting tasks).
        """
        param_idx = self.options.waking_output_parameter

        if not param_idx:
            logging.debug('no waking output parameter to set')
            return

        param_name = self._format_waking_output_parameter(param_idx)

        # Just in case: never overwrite a value that has already been published.
        if getattr(self.task.Parameters, param_name) is not None:
            logging.debug('waking output parameter "{}" is already set'.format(param_name))
            return

        logging.debug('set waking output parameter "{}"'.format(param_name))

        setattr(self.task.Parameters, param_name, dict(status=status))

    def _get_wait_output_parameters(self):
        """
        Builds the ``WaitOutput`` specification from the ``wait_output_parameters`` option.

        :return: ``{task_id (int): [param_name, ...]}``, or ``None`` when nothing
            is configured.
        """
        raw_wait_output = self.options.wait_output_parameters

        if not raw_wait_output:
            return None

        wait_output = {}

        for task_id, param in raw_wait_output.items():
            # A single index is shorthand for a one-element list.
            indices = param if isinstance(param, (list, tuple)) else [param]
            # Build a real list: on Python 3 `map` returns a one-shot iterator,
            # but this value is handed to WaitOutput and iterated again later.
            wait_output[int(task_id)] = [self._format_waking_output_parameter(idx) for idx in indices]

        return wait_output

    def _check_awaited_tasks(self, wait_output):
        """
        Verifies that every awaited task published a successful waking status.

        :param wait_output: mapping returned by ``_get_wait_output_parameters``.
        :raises TaskStop: if a waking parameter is not a dict, or its ``'status'``
            is missing or differs from ``ctt.Status.SUCCESS``.
        """
        # Pass a real list of ids rather than a dict view.
        awaited_ids = list(wait_output)
        tasks = list(sdk2.Task.find(id=awaited_ids).limit(WAIT_TASKS_LIMIT))

        found_ids = set()

        for task in tasks:
            found_ids.add(task.id)

            for param_name in wait_output[task.id]:
                output = getattr(task.Parameters, param_name)

                if not isinstance(output, dict):
                    raise TaskStop('Expected task #{} to set waking output parameter "{}"'.format(task.id, param_name))

                # A dict without 'status' cannot indicate success.
                if output.get('status') != ctt.Status.SUCCESS:
                    raise TaskStop('Awaited task #{} is failed'.format(task.id))

        missing_ids = sorted(set(awaited_ids) - found_ids)
        if missing_ids:
            # Previously skipped silently; make it visible without changing the outcome.
            logging.warning('awaited tasks not found, not checked: {}'.format(missing_ids))

    @staticmethod
    def _format_waking_output_parameter(param_idx):
        """Returns the output parameter name for waking parameter index ``param_idx``."""
        return 'waking_output_parameter_{}'.format(param_idx)
