import logging
import numbers
import six

from sandbox import sdk2
from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sdk2.helpers import ProcessLog

from sandbox.common.types import task as task_type
from sandbox.common.errors import SandboxException, TaskFailure, TaskError, VaultError
from sandbox.common.types.resource import State

from sandbox.projects.yabs.sandbox_task_tracing import trace_calls
from sandbox.projects.yabs.sandbox_task_tracing.wrappers import subprocess as sp

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@trace_calls(save_arguments=(1, 'tasks'))
def check_tasks(self, tasks, raise_on_fail=True, callback=None, wait_all=True, raise_task_failure=False):
    """
    Wait for specified task(s) switched into FINISH or BREAK statuses.

    :param self: the calling task; waiting is only triggered when it is an instance
                 of sdk2.Task or SandboxTask
    :param tasks: task or iterable of task objects or identifiers the task will wait for.
                  Task is an integer or an instance of sdk2.Task
    :param raise_on_fail: bool, raise exception if any task is in a status of the BREAK group
                          or is not in a status of the SUCCEED group
    :param callback: callable(self, task), called for every task that is already finished;
                     receives the calling task and the finished task
    :param wait_all: bool, wait for all or any of tasks specified
    :param raise_task_failure: bool, raise TaskFailure instead of TaskError for tasks
                               in a BREAK status
    :return: list of tuple(task, task's status)
    :raises TaskFailure: if a task identifier is not found, or (with raise_on_fail)
                         some task is not in a SUCCEED status
    :raises TaskError: if an item of ``tasks`` has an unexpected type, or (with raise_on_fail
                       and without raise_task_failure) some task is in a BREAK status
    :raises sdk2.WaitTask: if ``self`` is an sdk2.Task and some tasks are not finished yet
    """
    if not tasks:
        return []

    if not hasattr(tasks, '__iter__'):
        tasks = [tasks]

    # Normalize every item into a task object, resolving integer identifiers.
    task_objects = []
    for task in tasks:
        if isinstance(task, sdk2.Task):
            task_objects.append(task)
        elif isinstance(task, numbers.Integral):
            found_task = sdk2.Task.find(id=task, children=True, hidden=True).first()
            if found_task is None:
                raise TaskFailure('Task #{id} not found'.format(id=task))
            task_objects.append(found_task)
        else:
            raise TaskError(
                'Unexpected type: "{}". Items of "tasks" should be either sdk2.Task instances or integral number'
                .format(type(task))
            )

    statuses_to_wait = task_type.Status.Group.FINISH | task_type.Status.Group.BREAK

    # Partition in a single pass so each task lands in exactly one of the two lists.
    finished_tasks = []
    running_tasks = []
    for task in task_objects:
        if task.status in statuses_to_wait:
            finished_tasks.append(task)
        else:
            running_tasks.append(task)

    if callback:
        for task in finished_tasks:
            callback(self, task)

    # Checking running_tasks directly (rather than len(tasks)) also works
    # when ``tasks`` is an iterable without a length.
    if running_tasks:
        if isinstance(self, sdk2.Task):
            raise sdk2.WaitTask(running_tasks, statuses_to_wait, wait_all=wait_all)
        elif isinstance(self, SandboxTask):
            running_task_ids = [task.id for task in running_tasks]
            self.wait_tasks(running_task_ids, statuses_to_wait, wait_all=wait_all)

    # Real lists are required here: on Python 3 ``filter`` returns a lazy iterator
    # that is always truthy, which would make the checks below fire unconditionally.
    broken_tasks = [task for task in task_objects if task.status in task_type.Status.Group.BREAK]
    logger.debug('Broken tasks: %s', broken_tasks)
    if raise_on_fail and broken_tasks:
        error_class = TaskFailure if raise_task_failure else TaskError
        raise error_class('Tasks {tasks_ids} are BROKEN'.format(tasks_ids=[task.id for task in broken_tasks]))

    failed_tasks = [task for task in task_objects if task.status not in task_type.Status.Group.SUCCEED]
    logger.debug('Failed tasks: %s', failed_tasks)
    if raise_on_fail and failed_tasks:
        raise TaskFailure('Tasks {tasks_ids} are FAILED'.format(tasks_ids=[task.id for task in failed_tasks]))

    return [(task, task.status) for task in task_objects]


@trace_calls
def run_process(task_instance, command, logger_name):
    """
    Run a command with its output sent to a ProcessLog and fail on a non-zero exit code.

    :param task_instance: task passed to ProcessLog
    :param command: command passed to ``sp.popen_and_wait``; a string or a sequence of strings
    :param logger_name: name of the logger handed to ProcessLog
    :raises AssertionError: if the command finishes with a truthy (non-zero) exit code
    """
    with ProcessLog(task_instance, logger=logging.getLogger(logger_name)) as process_context:
        # stderr=sp.STDOUT presumably merges stderr into stdout, as subprocess.STDOUT does.
        code = sp.popen_and_wait(command, stdout=process_context.stdout, stderr=sp.STDOUT)
    if code:
        # Raised explicitly rather than via ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type stays the same for callers.
        raise AssertionError(
            'Command {command} failed with code {code}'.format(command=get_command_str(command), code=code)
        )


def get_command_str(command):
    """
    Return the command as a single string.

    :param command: a string, returned unchanged, or an iterable of strings,
                    joined with single spaces
    :return: string representation of the command
    """
    already_string = isinstance(command, six.string_types)
    return command if already_string else ' '.join(command)


def CustomAssert(expression, exception_string='', exception_class=SandboxException):
    """
    Raise an exception of the given class when the expression is falsy.

    :param expression: value checked for truthiness
    :param exception_string: message passed to the exception constructor
    :param exception_class: exception type to raise, SandboxException by default
    """
    if expression:
        return
    raise exception_class(exception_string)


def try_get_from_vault(self, key):
    """
    Fetch a vault item, trying the task author, then the task owner, then no explicit owner.

    :param self: object providing ``author`` and ``owner`` attributes
    :param key: name of the vault item
    :return: whatever ``sdk2.Vault.data`` returns for the first successful lookup
    :raises VaultError: presumably propagated from the last lookup if it fails too,
                        since that call is not guarded
    """
    try:
        return sdk2.Vault.data(self.author, key)
    except VaultError:
        # Use the module-level logger, as the rest of this module does.
        logger.warning('Failed to get vault item by author, trying owner')
    try:
        return sdk2.Vault.data(self.owner, key)
    except VaultError:
        logger.warning('Failed to get vault item by owner, trying anything available')
    # Last resort: lookup by key only; errors from this call are not caught.
    return sdk2.Vault.data(key)


def find_last_ready_resource(res_type, attrs=None, **kwargs):
    """
    Find the first READY resource of the given type when ordered by ``-id``.

    :param res_type: resource class, or a string used to look the class up in ``sdk2.Resource``
    :param attrs: attributes to search by; a falsy value is replaced with an empty dict
    :param kwargs: extra keyword arguments forwarded to ``find``
    :return: the found resource
    :raises TaskError: if the search yields no (truthy) resource
    """
    resource_class = sdk2.Resource[res_type] if isinstance(res_type, six.string_types) else res_type
    search_attrs = attrs or {}
    query = resource_class.find(attrs=search_attrs, state=State.READY, order="-id", **kwargs)
    found = query.first()
    if not found:
        raise TaskError(
            "Cannot find last ready %s with attributes %s and %s" % (resource_class, search_attrs, kwargs)
        )
    return found
