from enum import Enum
from collections import namedtuple

from sandbox import sdk2
from sandbox.common.types import task as ctt
from sandbox.common.errors import TaskError, TaskFailure


def schedule_tasks(
    self,
    task_class,
    parameters,
    task_count,
    subtask_data_list,
    checker,  # Needed to get generation array dimensions
    description='Child run',
    tasks_resource=None,
    **kwargs
):
    """Create and enqueue ``task_count`` child tasks, one per host slot.

    Every child is constructed as ``task_class(self, ...)`` and receives the
    parent's hints and tags, the total slot count, its own zero-based slot
    index and the merged task parameters.

    :param self: parent task; its ``hints`` and ``Parameters.tags`` are
        forwarded to every child.
    :param task_class: callable that builds a child task.
    :param parameters: mapping (or iterable of pairs) with child parameters.
    :param task_count: number of children to create.
    :param subtask_data_list: list that is extended in place with one
        ``checker.initial_subtask_data(subtask)`` record per created child.
    :param checker: object providing ``initial_subtask_data``.
    :param description: prefix of the child description; slot information is
        appended to it.
    :param tasks_resource: when truthy, it is assigned to the child's
        ``Requirements.tasks_resource`` and the child is saved before being
        enqueued.
    :param kwargs: extra child parameters; they override ``parameters``.
    """
    # The merged parameters are the same for every slot, so build them once
    # instead of on each iteration.
    parameters_dict = dict(parameters)
    parameters_dict.update(kwargs)

    for index in range(task_count):
        subtask = task_class(
            self,
            description=description + '\nSlot {} of {}'.format(index, task_count),
            hints=list(self.hints),
            tags=self.Parameters.tags,
            hosts_slots_count=task_count,
            hosts_slot_index=index,
            **parameters_dict
        )
        if tasks_resource:
            subtask.Requirements.tasks_resource = tasks_resource
            subtask.save()
        subtask.enqueue()
        subtask_data_list.append(checker.initial_subtask_data(subtask))


def check_and_reschedule_tasks(
    self,
    subtask_data_list,
    checker,
    description='Child rerun',
    tasks_resource=None,
    **kwargs
):
    """Check every tracked child task and re-run the ones that need it.

    For each record ``checker.check`` is called. When the result asks for a
    re-run, a new task of the same class is created for the same slot with
    the old task's parameters (overridden by ``kwargs``), enqueued, and the
    record in ``subtask_data_list`` is replaced in place by
    ``checker.get_subtask_data_from_subtask_and_check_data``.

    :param self: parent task; its ``hints`` and ``Parameters.tags`` are
        forwarded to every re-created child.
    :param subtask_data_list: list of subtask data records; updated in place.
    :param checker: object providing ``check``, ``get_subtask_from_data`` and
        ``get_subtask_data_from_subtask_and_check_data``.
    :param description: prefix of the new child's description.
    :param tasks_resource: when truthy, it is assigned to the new child's
        ``Requirements.tasks_resource`` and the child is saved before being
        enqueued.
    :param kwargs: parameter overrides for re-created children.
    :return: accumulated ``need_wait`` flag over all check results
        (``False`` for an empty list).
    :raises: whatever ``checker.check`` raises.
    """
    need_wait = False
    slots_count = len(subtask_data_list)
    # Records are replaced by index only; the list length never changes, so
    # iterating with enumerate while assigning is safe.
    for index, subtask_data in enumerate(subtask_data_list):
        check_results = checker.check(subtask_data)
        if check_results.need_rerun:
            # The old task and its parameters are needed only for a re-run,
            # so they are not fetched for healthy subtasks.
            subtask = checker.get_subtask_from_data(subtask_data)
            parameters_dict = dict(subtask.Parameters)
            parameters_dict.update(kwargs)
            new_subtask = type(subtask)(
                self,
                description=description + '\nSlot {} of {}'.format(index, slots_count),
                hints=list(self.hints),
                tags=self.Parameters.tags,
                hosts_slots_count=slots_count,
                hosts_slot_index=index,
                **parameters_dict
            )
            if tasks_resource:
                new_subtask.Requirements.tasks_resource = tasks_resource
                new_subtask.save()
            new_subtask.enqueue()
            subtask_data_list[index] = checker.get_subtask_data_from_subtask_and_check_data(
                new_subtask, check_results.check_data
            )
        need_wait |= check_results.need_wait
    return need_wait


# Outcome of checking a single subtask:
#   need_wait  - the caller still has to wait for this slot;
#   need_rerun - a replacement subtask has to be scheduled for this slot;
#   check_data - per-state attempt counters to store with the replacement
#                (None when no re-run is requested).
CheckResults = namedtuple('CheckResults', ('need_wait', 'need_rerun', 'check_data'))


class TaskChecker(object):
    """Classify subtask states and count re-run attempts per failure kind.

    A subtask data record is a list ``[task_id, counter, ...]`` with one
    counter per member of ``TaskCheckStates.Limited``, in declaration order.
    Each counter holds how many times the slot ended up in that state; once a
    counter exceeds the configured limit, ``check`` raises.
    """

    class TaskCheckStates(object):
        """Namespace for the states a subtask can be classified into."""

        class Limited(Enum):
            # States that trigger a re-run and are bounded by a ``*_limit``.
            failure = 1
            abandoned = 2
            other_break = 3

        class NonLimited(Enum):
            # States that never trigger a re-run.
            executing = 1
            success = 2

        class __metaclass__(type):
            # Python 2 metaclass hook: iterating over ``TaskCheckStates``
            # itself yields all Limited members followed by all NonLimited.
            def __iter__(self):
                for item in self.Limited:
                    yield item
                for item in self.NonLimited:
                    yield item

    def __init__(self, **kwargs):
        """Collect limits from keyword arguments named ``<state>_limit``.

        Keyword arguments without the ``_limit`` suffix are ignored.

        :raises AssertionError: if the provided limit names are not exactly
            the names of ``TaskCheckStates.Limited`` members.
        """
        suffix = '_limit'
        self.limits = {}
        for key in kwargs:
            if key.endswith(suffix):
                self.limits[key[:-len(suffix)]] = kwargs[key]

        # Compare names, not just counts: a misspelled limit would otherwise
        # pass here and surface later as a KeyError inside check().
        # Raised explicitly so that the validation survives ``python -O``.
        expected = set(member.name for member in self.TaskCheckStates.Limited)
        provided = set(self.limits)
        if provided != expected:
            raise AssertionError(
                'Wrong set of limits provided (missing: {}, unexpected: {})'.format(
                    sorted(expected - provided), sorted(provided - expected)
                )
            )

    def get_subtask_from_data(self, subtask_data):
        """Return the task whose id is stored first in ``subtask_data``."""
        return sdk2.Task[subtask_data[0]]

    def initial_subtask_data(self, subtask):
        """Return a fresh record for ``subtask``: its id plus zeroed counters."""
        return [subtask.id] + [0] * len(self.TaskCheckStates.Limited)

    def check(self, subtask_data):
        """Classify the subtask and decide whether to wait and/or re-run it.

        :param subtask_data: record ``[task_id, counter, ...]``.
        :return: ``CheckResults``; ``check_data`` holds the updated counters
            when a re-run is requested and ``None`` otherwise.
        :raises TaskFailure: the ``failure`` counter exceeded its limit.
        :raises TaskError: the ``abandoned`` or ``other_break`` counter
            exceeded its limit.
        """
        limited_states = self.TaskCheckStates.Limited
        non_limited_states = self.TaskCheckStates.NonLimited

        # Resolve the task through the overridable hook so that this method
        # agrees with callers that use get_subtask_from_data directly.
        task = self.get_subtask_from_data(subtask_data)
        generation_dict = dict(zip(
            (member.name for member in limited_states),
            self.get_check_data_from_subtask_data(subtask_data)
        ))
        state = self.get_state(task)

        if state in limited_states:
            generation_dict[state.name] += 1
            if generation_dict[state.name] > self.limits[state.name]:
                if state == limited_states.failure:
                    raise TaskFailure('Task {} is in FAILURE status, maximum run attempts reached'.format(task.id))
                elif state == limited_states.abandoned:
                    raise TaskError('Task {} has abandoned host, maximum run attempts reached'.format(task.id))
                elif state == limited_states.other_break:
                    raise TaskError(
                        'Task {} is in BREAK status ({}) and it did not abandon its host, maximum amount reached'.format(
                            task.id, repr(task.status)
                        )
                    )
            return CheckResults(
                need_wait=True,
                need_rerun=True,
                check_data=[generation_dict[task_state.name] for task_state in limited_states]
            )
        elif state in non_limited_states:
            if state == non_limited_states.executing:
                return CheckResults(
                    need_wait=True,
                    need_rerun=False,
                    check_data=None
                )
            elif state == non_limited_states.success:
                return CheckResults(
                    need_wait=False,
                    need_rerun=False,
                    check_data=None
                )

    def get_subtask_data_from_subtask_and_check_data(self, subtask, check_data):
        """Build a record for ``subtask`` carrying over the given counters.

        ``check_data`` may be any sequence of counters, not only a list.
        """
        return [subtask.id] + list(check_data)

    def get_check_data_from_subtask_data(self, subtask_data):
        """Return the counters part of a record (everything after the id)."""
        return subtask_data[1:]

    def get_state(self, task):
        """Map ``task.status`` onto a ``TaskCheckStates`` member.

        The branches are tried in order: FAILURE, BREAK group with a truthy
        ``task.Context.abandoned_host``, EXECUTE or QUEUE groups, SUCCESS.
        """
        if task.status == ctt.Status.FAILURE:
            return self.TaskCheckStates.Limited.failure
        elif task.status in ctt.Status.Group.BREAK and task.Context.abandoned_host:
            return self.TaskCheckStates.Limited.abandoned

        elif task.status in ctt.Status.Group.EXECUTE | ctt.Status.Group.QUEUE:
            return self.TaskCheckStates.NonLimited.executing
        elif task.status == ctt.Status.SUCCESS:
            return self.TaskCheckStates.NonLimited.success
        else:
            # NOTE(review): every status not matched above lands here, not
            # only BREAK-group ones - confirm that is intended.
            return self.TaskCheckStates.Limited.other_break


def get_task_list_from_data_iterable(iterable):
    """Resolve each subtask data record in ``iterable`` to its task object.

    The first element of every record is used as the key into ``sdk2.Task``.
    The result is a list in the same order as the input.
    """
    tasks = []
    for record in iterable:
        task_id = record[0]
        tasks.append(sdk2.Task[task_id])
    return tasks
