import logging

from collections import namedtuple

import sandbox.common.types.task as ctt
from sandbox.common import rest
from sandbox.common.errors import TaskFailure, TaskError
from sandbox.common.utils import get_task_link


# Module-level logger used by the helpers in this module.
logger = logging.getLogger(__name__)


class SubtasksByStatus(namedtuple("_SubtasksByStatus", ['failure', 'broken', 'succeed', 'queue', 'other'])):
    """Child task ids grouped by status.

    As filled in by ``split_subtasks``, each field is a list of child task ids:

    * ``failure`` -- status in FINISH but not in SUCCEED;
    * ``broken``  -- status in the BREAK group;
    * ``succeed`` -- status in the SUCCEED group;
    * ``queue``   -- status in the QUEUE group;
    * ``other``   -- status in none of the groups above (and not DRAFT).
    """

    # Keep instances as lightweight as the underlying tuple: no per-instance
    # __dict__, and no accidental ad-hoc attributes.
    __slots__ = ()

    @property
    def running(self):
        """Return ``queue + other``: queued ids followed by "other"-status ids."""
        return self.queue + self.other


def split_subtasks(task_id):
    """Fetch the children of *task_id* and group their ids by status.

    A child goes into the first group whose status set contains its status,
    checked in this order: SUCCEED, FAILURE (FINISH statuses that are not
    SUCCEED), BREAK, DRAFT, QUEUE. Children matching none of them go to
    ``other``.

    NOTE: children in a DRAFT status are grouped but not returned --
    ``SubtasksByStatus`` has no field for them, so they appear in no list.

    NOTE(review): only the ``items`` of a single ``children.read()`` response
    are inspected; if that endpoint paginates, later pages are not seen --
    confirm against the REST API.

    :param task_id: id of the parent task.
    :return: ``SubtasksByStatus`` whose fields are lists of child task ids.
    """
    client = rest.Client()

    succeed_statuses = frozenset(ctt.Status.Group.SUCCEED)
    # Ordered (group name, statuses) pairs. Ids are collected per group *name*
    # rather than keyed by the frozensets themselves, so two groups that happen
    # to have equal status sets can never end up sharing one list.
    groups = (
        ('succeed', succeed_statuses),
        ('failure', frozenset(ctt.Status.Group.FINISH) - succeed_statuses),
        ('broken', frozenset(ctt.Status.Group.BREAK)),
        ('draft', frozenset(ctt.Status.Group.DRAFT)),
        ('queue', frozenset(ctt.Status.Group.QUEUE)),
    )
    subtask_ids = {group_name: [] for group_name, _ in groups}
    subtask_ids['other'] = []

    data = client.task[task_id].children.read()
    for item in data['items']:
        status = item['status']
        group_name = next(
            (candidate for candidate, statuses in groups if status in statuses),
            'other',
        )
        subtask_ids[group_name].append(item['id'])

    return SubtasksByStatus(
        failure=subtask_ids['failure'],
        broken=subtask_ids['broken'],
        succeed=subtask_ids['succeed'],
        queue=subtask_ids['queue'],
        other=subtask_ids['other'],
    )


def check_and_handle_subtask_failure(subtasks, set_info=None, kill_only_queued_on_failure=True):
    """Abort the current task if any subtask has failed.

    Does nothing when ``subtasks.failure`` is empty. Otherwise, it first tries
    (best effort) to stop the remaining children. Which children are stopped
    depends on *kill_only_queued_on_failure*: only ``subtasks.queue`` when it is
    true, or ``subtasks.running`` (queue + other) when it is false. It then
    reports the failed ids via ``set_subtasks_info`` and raises.

    :param subtasks: a ``SubtasksByStatus`` (e.g. from ``split_subtasks``).
    :param set_info: optional callable used to publish the failed subtasks.
    :param kill_only_queued_on_failure: stop only queued children (default),
        or queued and "other"-status children.
    :raises TaskFailure: if at least one subtask has failed.
    """
    if not subtasks.failure:
        return

    to_stop = subtasks.queue if kill_only_queued_on_failure else subtasks.running
    # Skip the REST round-trip entirely when there is nothing to stop.
    if to_stop:
        client = rest.Client()
        try:
            client.batch.tasks.stop.update(to_stop)
        except Exception:
            # Stopping is best effort: the failure must be reported regardless.
            # Exception (not BaseException) so KeyboardInterrupt/SystemExit propagate.
            logger.warning(
                "Batch stop of queued%s children failed",
                "" if kill_only_queued_on_failure else " and running",
                exc_info=True
            )
    msg = "There are FAILED subtasks"
    set_subtasks_info(msg, subtasks.failure, set_info=set_info)
    raise TaskFailure(msg)


def set_subtasks_info(header, task_ids, set_info=None):
    """Publish *header* followed by one HTML task link per line via *set_info*.

    The message is ``"<header>:\\n"`` followed by newline-separated
    ``<a href="...">id</a>`` anchors. It is passed to ``set_info`` with
    ``do_escape=False`` so the anchors are kept as markup. When *set_info* is
    None, nothing happens.
    """
    if set_info is not None:
        anchors = [
            '<a href="{}">{}</a>'.format(get_task_link(task_id), task_id)
            for task_id in task_ids
        ]
        message = "{}:\n{}".format(header, "\n".join(anchors))
        set_info(message, do_escape=False)


def batched(iterable, batch_size):
    """Split a sliceable sequence into consecutive chunks of *batch_size* items.

    Each chunk is a slice of *iterable*, so it has the same type (a list yields
    lists, a string yields strings). The last chunk may be shorter. *iterable*
    must support ``len()`` and slicing.

    :param iterable: sequence to split.
    :param batch_size: positive chunk length.
    :return: iterator over the chunks.
    :raises ValueError: immediately if *batch_size* is less than 1. A negative
        size would otherwise silently produce no chunks at all.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))
    return _iter_slices(iterable, batch_size)


def _iter_slices(sequence, batch_size):
    """Yield consecutive ``batch_size``-long slices of *sequence* (last may be shorter)."""
    for batch_start in range(0, len(sequence), batch_size):
        yield sequence[batch_start:batch_start + batch_size]


def run_tasks(self, ids, batch_size=5):
    """Start the tasks *ids* in batches and report per-task outcomes.

    Tasks are started ``batch_size`` at a time through the REST batch-start
    call. After every batch has been submitted, the returned states are checked
    in order:

    * a ``WARNING`` state is published immediately via ``self.set_info`` as an
      HTML link;
    * an ``ERROR`` state is collected.

    All collected errors are then raised together as a single exception.

    :param self: object exposing ``set_info(text, do_escape=...)``. Presumably
        this is the calling task; note this is a module-level function, not a
        method.
    :param ids: sliceable sequence of task ids to start.
    :param batch_size: number of tasks per start request.
    :raises TaskError: if at least one task failed to start.
    """
    # NOTE(review): this assigns a class attribute, so it presumably changes the
    # default timeout for every rest.Client in the process, not just this one.
    rest.Client.DEFAULT_TIMEOUT = 100
    client = rest.Client()

    all_states = []
    for chunk in batched(ids, batch_size):
        logger.info('Run tasks batch %s', chunk)
        chunk_states = client.batch.tasks.start.update(chunk)
        logger.info('Batch states: %s', chunk_states)
        all_states.extend(chunk_states)

    error_messages = []
    for state in all_states:
        status = state['status']
        if status == u'ERROR':
            error_messages.append(
                'Failed to start task {task_link}: {message}'.format(
                    task_link=get_task_link(state['id']),
                    message=state['message'],
                )
            )
        elif status == u'WARNING':
            warning = (
                'Got warning while running task <a href="{task_link}" target="_blank">#{task_id}</a>: {message}'
            ).format(
                task_link=get_task_link(state['id']),
                task_id=state['id'],
                message=state['message'],
            )
            self.set_info(warning, do_escape=False)

    if error_messages:
        raise TaskError('\n\n'.join(error_messages))
