import logging

import graphviz

from tasklet.api import spy_pb2


logger = logging.getLogger(__name__)

# Fill colour of a tasklet node keyed by state name (as returned by
# ``spy_pb2.Event.State.Name``). ``draw_graph`` falls back to 'white' for
# any state name missing from this table.
STATUS_COLORS = dict(
    SCHEDULED='white',
    LAUNCHED='lightskyblue1',
    SUCCESS='aquamarine2',
    FAILURE='pink',
)


# Characters with structural meaning inside a graphviz ``record`` label
# (field separators, sub-record braces, port markers). They must be
# backslash-escaped to be displayed literally.
_RECORD_ESCAPES = str.maketrans({char: '\\' + char for char in '{}|<>'})


def _escape_record_field(text):
    """Return ``text`` with graphviz record-label control characters escaped."""
    # assumes ``text`` is a str (tasklet ids / event metadata) — TODO confirm
    return text.translate(_RECORD_ESCAPES)


def draw_graph(log, run_id):
    """Render the tasklet graph described by an event log as SVG.

    Every ``SCHEDULED`` event adds a box node for ``event.tasklet_id``
    (labelled with ``event.tasklet_name`` and its sequence number) plus an
    edge from ``event.parent_id`` to it. Each node is filled according to the
    *last* state recorded for that tasklet anywhere in ``log`` (see
    ``STATUS_COLORS``). A ``record``-shaped ``legend`` node lists, per
    scheduled tasklet, its sequence number, its id, the metadata of its
    SCHEDULED event and the metadata of its SUCCESS event (empty if none).

    Args:
        log: iterable of ``spy_pb2.Event`` messages in log order. It is
            materialised once, so one-shot iterables (generators) work.
        run_id: currently unused; kept for interface compatibility.

    Returns:
        The result of ``graph.pipe(format='svg')`` for the built graph.
    """
    # The log is scanned twice (final states, then graph construction);
    # materialise it so a generator is not exhausted by the first pass.
    events = list(log)

    scheduled_state = spy_pb2.Event.State.Value('SCHEDULED')
    success_state = spy_pb2.Event.State.Value('SUCCESS')

    # Last-seen state name per tasklet; later events overwrite earlier ones.
    current_status = {
        event.tasklet_id: spy_pb2.Event.State.Name(event.state)
        for event in events
    }

    graph = graphviz.Digraph(engine='dot')

    # Legend columns, one row per SCHEDULED event (kept in lock-step).
    row_ids = []
    scheduled_metadata = []
    success_metadata = []
    row_of_tasklet = {}

    for event in events:
        if event.state == scheduled_state:
            tasklet_id = event.tasklet_id
            row = len(row_ids)

            graph.node(
                tasklet_id,
                label='{}\n{}'.format(event.tasklet_name, row),
                style='filled',
                color='black',
                fillcolor=STATUS_COLORS.get(current_status.get(tasklet_id), 'white'),
                shape='box',
            )
            graph.edge(event.parent_id, tasklet_id)

            row_of_tasklet[tasklet_id] = row
            row_ids.append(tasklet_id)
            scheduled_metadata.append(event.metadata)
            success_metadata.append('')
        elif event.state == success_state:
            row = row_of_tasklet.get(event.tasklet_id)
            if row is None:
                # SUCCESS without a preceding SCHEDULED (e.g. a truncated
                # log): skip it instead of aborting the whole render.
                logger.warning(
                    'SUCCESS event for unscheduled tasklet %s; skipping',
                    event.tasklet_id,
                )
                continue
            success_metadata[row] = event.metadata

    legend_columns = [
        [str(index) for index in range(len(row_ids))],
        row_ids,
        scheduled_metadata,
        success_metadata,
    ]
    # Each column becomes a vertical sub-record "{ a|b|c }"; the columns are
    # laid out side by side. Cell text is escaped so free-form metadata
    # cannot alter the record structure.
    graph.node(
        'legend',
        '|'.join(
            '{{ {} }}'.format('|'.join(_escape_record_field(cell) for cell in column))
            for column in legend_columns
        ),
        shape='record',
    )
    # White (effectively invisible) edge from the legend to the '' node —
    # presumably the parent_id of root tasklets — likely to steer layout.
    graph.edge('legend', '', color='white')

    return graph.pipe(format='svg')
