import argparse
import graphviz
from collections import namedtuple
from lxml import etree as ET

"""
Produces a nice looking diagram with all editor categories and their relationships.

Usage examples

Produce a diagram with all categories:
./category_graph_visualizer

Produce a diagram with only the categories connected, directly or transitively, to the given category:
./category_graph_visualizer --category rd
"""


# Path of the editor configuration that is parsed when --config is not given.
DEFAULT_EDITOR_CONFIG = '/usr/share/yandex/maps/wiki/editor/editor.xml'

# Directed edge of the category graph: `tail` and `head` are category ids and
# `attributes` is a dict of keyword attributes passed to graphviz for the edge.
Edge = namedtuple("Edge", ["tail", "head", "attributes"])


class Graph(object):
    """Minimal in-memory directed graph keyed by category id.

    Each node maps an id to a mutable dict of drawing attributes. Edges are
    stored as an ordered list of ``Edge`` tuples, so several edges between
    the same pair of nodes are allowed.
    """

    def __init__(self):
        self._edges = []
        self._nodes = {}

    def has_node(self, cat_id):
        """Return True if a node with id ``cat_id`` has been added."""
        return cat_id in self._nodes

    def get_node(self, cat_id):
        """Return the attribute dict of node ``cat_id``.

        Raises KeyError if the node has not been added.
        """
        return self._nodes[cat_id]

    def add_node(self, cat_id):
        """Register ``cat_id`` with an empty attribute dict and return it.

        Adding an id that already exists replaces its attributes with a
        fresh empty dict.
        """
        attributes = {}
        self._nodes[cat_id] = attributes
        return attributes

    def add_edge(self, tail, head, **kwargs):
        """Append an edge ``tail -> head`` and return the created ``Edge``.

        Extra keyword arguments become the edge's ``attributes`` dict.
        """
        new_edge = Edge(tail=tail, head=head, attributes=kwargs)
        self._edges.append(new_edge)
        return new_edge

    def has_edge(self, tail, head):
        """Return True if at least one edge ``tail -> head`` exists."""
        for edge in self._edges:
            if edge.tail == tail and edge.head == head:
                return True
        return False


def make_headtaillabel(min, max):
    """Format a pair of occurrence bounds as an edge-end label.

    Equal bounds collapse to that single value. Otherwise the result is
    ``min + '..' + max``, where an upper bound of ``'unbounded'`` is
    rendered as an empty string (e.g. ``'0..'``).
    """
    if min != max:
        upper = '' if max == 'unbounded' else max
        return min + '..' + upper
    return min


def make_taillabel(role):
    """Build the tail-end label from the role's master occurrence bounds."""
    lower = role.get('master-min-occurs')
    upper = role.get('master-max-occurs')
    return make_headtaillabel(lower, upper)


def make_headlabel(role):
    """Build the head-end label from the role's own occurrence bounds."""
    lower = role.get('min-occurs')
    upper = role.get('max-occurs')
    return make_headtaillabel(lower, upper)


# Fill colour of a node by the geometry type of its category template.
_GEOMETRY_FILL_COLORS = {
    'point': '#6BFFC1',
    'polyline': '#B7C7FF',
    'polygon': '#FF96E6',
}


def add_node(G, config, cat_id):
    """Add category ``cat_id`` to ``G`` and style the node from the config.

    Does nothing if the node is already present. Otherwise the category
    element is looked up in ``config``; the node is filled with a colour
    chosen by the geometry type of the category's template (if any) and
    drawn as a box when the category is marked ``complex="true"``.

    Args:
        G: graph to add the node to.
        config: parsed editor config supporting ``xpath()``.
        cat_id: id of the category to add.

    Raises:
        IndexError: if ``config`` has no category with this id. The graph is
            left untouched in that case.
    """
    if G.has_node(cat_id):
        return

    # XPath variables instead of string concatenation, so ids containing
    # quotes cannot break (or alter) the expression.
    matches = config.xpath('categories/category[@id=$cat_id]', cat_id=cat_id)
    if not matches:
        raise IndexError(
            'category %r is not defined in the editor config' % (cat_id,))
    category = matches[0]

    # Register the node only after the lookup succeeded.
    n = G.add_node(cat_id)

    # A category without a template-id simply has no geometry to colour by.
    template_id = category.get('template-id')
    geometry = []
    if template_id is not None:
        geometry = config.xpath(
            'category-templates/category-template[@id=$template_id]/geometry',
            template_id=template_id)

    if geometry:
        fill = _GEOMETRY_FILL_COLORS.get(geometry[0].get('type'))
        if fill is not None:
            n['color'] = fill
            n['style'] = 'filled'

    if category.get('complex') == 'true':
        n['shape'] = 'box'


def add_role(G, config, role, master_cat_id, slave_cat_id, all_roles=False):
    """Add the edge for ``role`` from the master to the slave category.

    With ``all_roles`` false, at most one unlabelled edge is kept per
    (master, slave) pair: if one already exists the call is a no-op.
    With ``all_roles`` true, every role gets its own edge, labelled with
    the role id and the occurrence bounds at both ends.

    A role with ``geom-part="true"`` is drawn in blue; a role with
    ``table-row="true"`` recolours the slave node. ``config`` is accepted
    for signature symmetry with the other helpers and is not used here.
    """
    if not all_roles and G.has_edge(master_cat_id, slave_cat_id):
        return

    edge_attrs = {}
    if all_roles:
        edge_attrs['label'] = role.get('id')
        edge_attrs['taillabel'] = make_taillabel(role)
        edge_attrs['headlabel'] = make_headlabel(role)
    edge = G.add_edge(master_cat_id, slave_cat_id, **edge_attrs)

    is_geom_part = role.get('geom-part') == "true"
    if is_geom_part:
        edge.attributes['color'] = 'blue'

    is_table_row = role.get('table-row') == 'true'
    if is_table_row:
        slave_node = G.get_node(slave_cat_id)
        slave_node['color'] = '#FFCC00'
        slave_node['style'] = 'filled'


def add_slaves(G, config, master_category_ids, all_roles=False):
    """Add every category reachable downwards from the given masters.

    For each category whose id is in ``master_category_ids``, the target
    category of each of its roles is added to ``G`` together with the role
    edge. Categories added for the first time form the next set of masters,
    and the process repeats until no new category is discovered.
    """
    frontier = master_category_ids
    while True:
        discovered = []

        for category in config.xpath('categories/category'):
            master_id = category.get('id')
            if master_id not in frontier:
                continue

            for role in category.xpath('relations/role'):
                slave_id = role.get('category-id')

                if not G.has_node(slave_id):
                    add_node(G, config, slave_id)
                    discovered.append(slave_id)

                add_role(G, config, role, master_id, slave_id, all_roles)

        if not discovered:
            break
        # Next round: expand from the categories seen for the first time.
        frontier = discovered


def add_masters(G, config, slave_category_ids, all_roles=False):
    """Add every category reachable upwards from the given slaves.

    Every role of every category is inspected; when the role targets a
    category whose id is in ``slave_category_ids``, the owning category is
    added to ``G`` together with the role edge. Categories added for the
    first time form the next set of slaves, and the process repeats until
    no new category is discovered.
    """
    frontier = slave_category_ids
    while True:
        discovered = []

        for category in config.xpath('categories/category'):
            master_id = category.get('id')

            for role in category.xpath('relations/role'):
                slave_id = role.get('category-id')
                if slave_id not in frontier:
                    continue

                if not G.has_node(master_id):
                    add_node(G, config, master_id)
                    discovered.append(master_id)

                add_role(G, config, role, master_id, slave_id, all_roles)

        if not discovered:
            break
        # Next round: look for masters of the categories just added.
        frontier = discovered


def run():
    """Command-line entry point: build the category graph and output it.

    Parses the command-line options, loads the editor config (resolving
    XInclude directives), builds either the full category graph or, with
    ``--category``, only the part connected to that category, and then
    emits the graph:

    * ``--stdout``: print the dot source to standard output;
    * ``--dot``: save the dot source to ``categories[_<category>].dot``;
    * otherwise: render that file via ``graphviz.Digraph.render``.
    """
    parser = argparse.ArgumentParser(description='Print category graph')
    parser.add_argument('--category', help='Print graph of all categories connected to the specified one')
    parser.add_argument('--config', help='Path to editor config', default=DEFAULT_EDITOR_CONFIG)
    parser.add_argument('--dot', help='Output in dot format', action='store_true')
    parser.add_argument('--stdout', help='Output to stdout', action='store_true')
    parser.add_argument('--all-roles', help='Create edges for all roles', action='store_true')
    args = parser.parse_args()

    config = ET.parse(args.config)
    config.xinclude()

    G = Graph()

    if args.category is not None:
        # Restricted view: the category itself plus everything reachable
        # through roles in either direction.
        add_node(G, config, args.category)
        add_slaves(G, config, [args.category], args.all_roles)
        add_masters(G, config, [args.category], args.all_roles)
    else:
        categories = config.xpath('categories/category')
        for category in categories:
            cat_id = category.get('id')
            if not G.has_node(cat_id):
                add_node(G, config, cat_id)

            for role in category.xpath('relations/role'):
                slave_cat_id = role.get('category-id')
                if not G.has_node(slave_cat_id):
                    add_node(G, config, slave_cat_id)
                add_role(G, config, role, cat_id, slave_cat_id, args.all_roles)

    dot = graphviz.Digraph(strict=False)
    dot.graph_attr['rankdir'] = 'LR'
    dot.node_attr['fontsize'] = '18'

    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    for node_name, attrs in G._nodes.items():
        dot.node(node_name, **attrs)

    for edge in G._edges:
        dot.edge(edge.tail, edge.head, **edge.attributes)

    filename = 'categories_{}.dot'.format(args.category) if args.category else 'categories.dot'

    if args.stdout:
        print(dot.source)
    elif args.dot:
        dot.save(filename)
    else:
        dot.render(filename)
