# coding: utf-8

from enum import Enum
import logging

import sandbox.common.types.task as ctt
from sandbox import sdk2


class NodeStatus(str, Enum):
    """Lifecycle status of a TaskGraph node.

    Mixing in ``str`` makes each member compare equal to its plain string value
    (e.g. ``NodeStatus.WAITING == 'waiting'``). Presumably this is so that state
    restored from serialized data still matches; confirm against callers.
    """
    NOT_RUNNED = 'not_runned'  # node's ``run`` has not been invoked yet
    WAITING = 'waiting'  # ``run`` returned a task; waiting for it to reach a final status
    FAILURE = 'failure'  # ``run`` raised or returned False, or the task failed/stopped
    SUCCESS = 'success'  # ``run`` returned None/True, or the task succeeded


def get_tasks_from_sandbox(task_ids):
    """Fetch Sandbox tasks by id.

    :param task_ids: iterable of task ids (list, ``dict_keys`` view, generator, ...).
    :returns: list of ``sdk2.Task`` objects found for those ids; an empty list
        when no ids are given.
    """
    # Materialize once. ``len`` requires a sized collection, and a plain list
    # is a safer query argument than e.g. a ``dict_keys`` view.
    task_ids = list(task_ids)
    if not task_ids:
        # Nothing to look up: skip the query (and a ``limit(0)`` call) entirely.
        return []
    return list(sdk2.Task.find(id=task_ids).limit(len(task_ids)))


def get_node_status_by_task_status(task_status, node_task_statuses):
    """Translate a task status into a node status.

    :param task_status: status of a single task.
    :param node_task_statuses: mapping ``node status -> collection of task statuses``.
    :returns: the first node status (in mapping iteration order) whose collection
        contains ``task_status``, or None if no collection does.
    """
    matches = (
        candidate
        for candidate, members in node_task_statuses.items()
        if task_status in members
    )
    return next(matches, None)


class TaskGraph:
    """Dependency graph of nodes, each of which runs a callable or a Sandbox task.

    Every node has a ``run`` callable and a list of dependency node names; a node
    starts only once all its dependencies are SUCCESS. Per-node progress lives in
    ``self.state`` as ``{name: {'task_id': ..., 'status': NodeStatus}}``.
    ``serialize()`` returns that dict and the constructor accepts it back;
    ``add_node`` keeps existing entries, so a restored graph resumes where it was.

    NOTE(review): presumably driven by calling ``process()`` repeatedly until
    ``is_finished()``; confirm with callers.
    """

    # Maps a node status to the Sandbox task statuses that produce it. A task in
    # any other status leaves its node WAITING.
    NODE_TASK_STATUSES = {
        NodeStatus.FAILURE: [
            ctt.Status.FAILURE,
            ctt.Status.EXCEPTION,
            ctt.Status.STOPPED,
        ],
        NodeStatus.SUCCESS: [ctt.Status.SUCCESS],
    }

    def __init__(self, state=None, get_tasks=get_tasks_from_sandbox):
        """Create a graph.

        :param state: node state previously returned by ``serialize()``. The dict
            is used and mutated in place; a new empty dict is created if omitted.
        :param get_tasks: callable taking a list of task ids and returning an
            iterable of task objects exposing ``id`` and ``status``.
        """
        self.nodes = {}
        # Fresh dict per instance. A ``dict()`` default is evaluated once at
        # definition time and would be shared by every graph built without state.
        self.state = {} if state is None else state
        self.get_tasks = get_tasks

    def add_node(self, name, depends, run, callback=None):
        """Register a node.

        :param name: unique node name.
        :param depends: names of nodes that must be SUCCESS before this one starts.
        :param run: zero-argument callable. Returning None/True marks the node
            SUCCESS, False marks it FAILURE, a task object (with ``id``) marks it
            WAITING; raising marks it FAILURE.
        :param callback: optional callable invoked as
            ``callback(status=..., task_id=...)`` when the node's task reaches a
            status listed in ``NODE_TASK_STATUSES``.
        """
        self.nodes[name] = {
            'name': name,
            'depends': depends,
            'run': run,
            'callback': callback,
        }
        # Keep restored state so a deserialized graph does not re-run nodes.
        if name not in self.state:
            self.state[name] = {
                'task_id': None,
                'status': NodeStatus.NOT_RUNNED,
            }

    def serialize(self):
        """Return the node state dict (the live object, not a copy)."""
        return self.state

    def get_actual_node_statuses(self, nodes):
        """Fetch the tasks behind ``nodes`` and compute their new node statuses.

        :param nodes: names of nodes that have a ``task_id`` in state.
        :returns: dict ``node -> NodeStatus`` containing only nodes whose task is
            in a status listed in ``NODE_TASK_STATUSES``.
        """
        # Several nodes may share one task id; keep all of them so none is left
        # WAITING forever.
        nodes_by_task_id = {}
        for node in nodes:
            nodes_by_task_id.setdefault(self.state[node]['task_id'], []).append(node)

        found_tasks = self.get_tasks(list(nodes_by_task_id))
        result = {}
        for task in found_tasks:
            if task.id not in nodes_by_task_id:
                # Ignore tasks we did not ask for instead of raising KeyError.
                continue
            new_status = get_node_status_by_task_status(task.status, TaskGraph.NODE_TASK_STATUSES)
            if new_status:
                for node in nodes_by_task_id[task.id]:
                    result[node] = new_status
        return result

    def run_tasks(self):
        """Start every NOT_RUNNED node whose dependencies have all succeeded.

        ``success_nodes`` is snapshotted per pass. A node whose ``run`` completes
        synchronously (returns None/True) becomes SUCCESS at once and may unblock
        others, so passes repeat until one completes no node synchronously.
        """
        while True:
            logging.debug('run_tasks')
            not_runned_nodes = self.get_nodes_with_status(NodeStatus.NOT_RUNNED)
            success_nodes = self.get_nodes_with_status(NodeStatus.SUCCESS)

            logging.debug('not_runned_nodes')
            logging.debug(not_runned_nodes)
            logging.debug('success_nodes')
            logging.debug(success_nodes)
            # Set for O(1) dependency membership checks.
            success_set = set(success_nodes)

            callback_runned = False
            for node in not_runned_nodes:
                node_obj = self.nodes[node]
                logging.debug('node_obj')
                logging.debug(node_obj)
                if not all(depend in success_set for depend in node_obj['depends']):
                    continue
                logging.debug('can_start')
                logging.debug(node_obj)
                if self._start_node(node, node_obj['run']):
                    callback_runned = True

            if not callback_runned:
                break

    def _start_node(self, node, run):
        """Invoke ``run`` for ``node`` and record the outcome in ``self.state``.

        :returns: True if the node completed synchronously as SUCCESS.
        """
        try:
            task = run()
        except Exception:
            self.state[node] = {
                'status': NodeStatus.FAILURE,
                'task_id': None,
            }
            # Log with traceback at error level; debug-only logging hid failures.
            logging.exception('exception while running node %s', node)
            return False

        if task is None or task is True:
            # Plain callback that did not return False: done right away.
            self.state[node] = {
                'status': NodeStatus.SUCCESS,
                'task_id': None,
            }
            return True
        if task is False:
            # Callback explicitly reported failure.
            self.state[node] = {
                'status': NodeStatus.FAILURE,
                'task_id': None,
            }
        elif task:
            # A task was returned: wait for it in update_waiting_nodes().
            self.state[node] = {
                'status': NodeStatus.WAITING,
                'task_id': task.id,
            }
        # Any other falsy return value leaves the node NOT_RUNNED (unchanged behavior).
        return False

    def process(self):
        """Advance the graph: refresh waiting nodes, then start runnable ones."""
        logging.debug('TaskGraph process')
        self.update_waiting_nodes()
        self.run_tasks()

    def update_waiting_nodes(self):
        """Refresh WAITING nodes from their tasks and fire callbacks on change.

        Callback errors are logged and swallowed so one callback cannot stop the
        remaining nodes from being updated.
        """
        waiting_nodes = self.get_nodes_with_status(NodeStatus.WAITING)
        logging.debug('waiting_nodes')
        logging.debug(waiting_nodes)
        if not waiting_nodes:
            return

        node_statuses = self.get_actual_node_statuses(waiting_nodes)
        for node, new_status in node_statuses.items():
            if not new_status:
                continue
            prev_status = self.get_node_status(node)
            self.state[node]['status'] = new_status
            if prev_status == new_status:
                continue
            node_callback = self.nodes[node]['callback']
            if not node_callback:
                continue
            try:
                node_callback(
                    status=new_status,
                    task_id=self.state[node]['task_id'],
                )
            except Exception:
                # The node name needs a %s placeholder; passing it as a bare
                # extra arg broke formatting of the log record.
                logging.exception('Error in callback for node %s', node)

    def get_waiting_task_ids(self):
        """Return the task ids of all WAITING nodes."""
        return [self.state[node]['task_id'] for node in self.get_nodes_with_status(NodeStatus.WAITING)]

    def get_nodes_with_status(self, status):
        """Return names of registered nodes whose state status equals ``status``."""
        return [node for node in self.nodes if self.state[node]['status'] == status]

    def get_node_status(self, node):
        """Return the stored status of ``node``."""
        return self.state[node]['status']

    def get_node_task_id(self, node):
        """Return the stored task id of ``node`` (None if it has no task)."""
        return self.state[node]['task_id']

    def has_node(self, node):
        """Return True if ``node`` has state (restored or added), even if not registered via add_node."""
        return node in self.state

    def is_finished(self):
        """Return True when no registered node is WAITING or NOT_RUNNED.

        NOTE: a node whose dependency failed is never started and stays
        NOT_RUNNED, so such a graph will not report finished.
        """
        return not (
            self.get_nodes_with_status(NodeStatus.WAITING)
            or self.get_nodes_with_status(NodeStatus.NOT_RUNNED)
        )
