
import logging
from collections import defaultdict

from django.db import transaction

from wiki.intranet.models.intranet_extensions import Group

logger = logging.getLogger(__name__)


class Tree(object):
    """Nodes that share one stored ``tree_id``, grouped by their parent id."""

    def __init__(self, tree_id):
        # parent_id -> list of child TreeNode objects; root nodes sit under None.
        children_by_parent = defaultdict(list)
        self.id = tree_id
        self.adj = children_by_parent


class Forest(object):
    """In-memory index of every nested-set tree loaded from ``Group`` rows.

    Nodes are indexed twice:

    * ``self.adj`` maps ``parent_id`` to the list of child ``TreeNode``
      objects across the whole forest (roots are stored under ``None``);
    * ``self.trees`` maps the stored ``tree_id`` to a ``Tree`` holding the
      same node objects grouped by parent id within that tree.

    ``get_corrupted_trees`` walks ``self.adj`` from each root and compares the
    stored ``left``/``right``/``level`` values with the numbering a pre-order
    traversal would produce. ``rebuild_tree`` rewrites those values in memory
    and flags changed nodes as dirty, and ``apply_changes`` persists the dirty
    nodes.
    """

    def __init__(self):
        self.trees = {}
        self.adj = defaultdict(list)

    def get_free_tree_idx(self):
        """Return the smallest positive tree id not yet used in ``self.trees``."""
        used_ids = set(self.trees)
        candidate = 1
        while candidate in used_ids:
            candidate += 1
        return candidate

    def stats(self):
        """Print ``(tree_id, number of parent-id buckets)`` for each tree."""
        print([(tree_id, len(tree.adj)) for tree_id, tree in self.trees.items()])

    def make_node(self, src):
        """Create a ``TreeNode`` from a ``Group``-like row and index it.

        ``src`` must expose ``id``, ``lft``, ``rght``, ``tree_id``, ``level``
        and ``parent_id``.
        """
        node = TreeNode()
        node.id = src.id
        node.left = src.lft
        node.right = src.rght
        node.tree_id = src.tree_id
        node.level = src.level
        node.parent_id = src.parent_id

        if node.tree_id not in self.trees:
            self.trees[node.tree_id] = Tree(node.tree_id)

        self.trees[node.tree_id].adj[node.parent_id].append(node)
        self.adj[node.parent_id].append(node)

    def reload_trees(self):
        """Regroup every node in ``self.adj`` by its current ``tree_id``.

        Call this after a rebuild has changed node ``tree_id`` values so that
        ``self.trees`` reflects them again.
        """
        self.trees = {}
        for nodes in self.adj.values():
            for node in nodes:
                if node.tree_id not in self.trees:
                    self.trees[node.tree_id] = Tree(node.tree_id)
                self.trees[node.tree_id].adj[node.parent_id].append(node)

    def get_corrupted_trees(self):
        """Return the set of ``tree_id`` values whose numbering is inconsistent.

        A tree counts as corrupted when the walk from one of its roots fails
        ``tree_checker``, or when ``self.trees`` lists more than one root for it.
        """
        corrupted = set()
        # .get() avoids inserting keys into the defaultdicts during a check.
        for root in self.adj.get(None, []):
            logger.debug('check tree with root %s', root)
            if not self.tree_checker(root.tree_id, root):
                corrupted.add(root.tree_id)

        # A tree must have only one root.
        for tree_id, tree in self.trees.items():
            if len(tree.adj.get(None, [])) > 1:
                logger.debug('tree %s has more than one root node', tree_id)
                corrupted.add(tree_id)

        return corrupted

    def tree_checker(self, tree_id, node, left=1, level=0, max_level=None):
        """Recursively verify ``node``'s subtree against nested-set numbering.

        :param tree_id: tree id every node in the subtree must carry.
        :param node: subtree root to check.
        :param left: ``left`` value ``node`` is expected to have.
        :param level: depth ``node`` is expected to have.
        :param max_level: stop descending at this depth. A falsy value (the
            default ``None``, or ``0``) means no limit.
        :returns: ``True`` if the subtree is consistent, ``False`` otherwise.
        """
        if max_level and level >= max_level:
            return True

        if node.tree_id != tree_id:
            logger.debug('node %s is from different tree', node.id)
            return False

        children = sorted(self.adj.get(node.id, []), key=lambda child: child.left)
        offset = '    ' * level
        if children:
            # An inner node closes right after its last child.
            expected_right = children[-1].right + 1
            logger.debug('%s Node: %s, Children: %s', offset, node, children)
        else:
            # A leaf occupies exactly two consecutive numbers.
            expected_right = left + 1
            logger.debug('%s Node: %s, End Node', offset, node)

        if node.level != level or node.left != left or node.right != expected_right:
            logger.debug('%s Corrupt, expected L%s:%s<->%s', offset, level, left, expected_right)
            return False

        child_left = left + 1
        for child in children:
            # Propagate max_level so the depth limit applies to the whole walk.
            if not self.tree_checker(tree_id, child, child_left, level + 1, max_level):
                return False
            child_left = child.right + 1

        return True

    def rebuild_tree(self, tree_id, with_reload=True):
        """Renumber every root stored under ``tree_id`` and its descendants.

        The first root keeps ``tree_id``. Each additional root is moved, with
        its subtree, to a fresh id from ``get_free_tree_idx``.

        :raises KeyError: if ``tree_id`` is not in ``self.trees``.
        """
        roots = self.trees[tree_id].adj.get(None, [])
        for index, root in enumerate(roots):
            target_id = tree_id if index == 0 else self.get_free_tree_idx()
            self._rebuild_helper(root, 1, target_id)

        if with_reload:
            self.reload_trees()

    def _rebuild_helper(self, node, left, tree_id, level=0):
        """Assign pre-order numbering to ``node``'s subtree.

        Returns the next free ``left`` value after the subtree.
        """
        right = left + 1
        for child in self.adj.get(node.id, []):
            right = self._rebuild_helper(child, right, tree_id, level + 1)

        logger.debug('update node %s left %s right %s tree id %s', node.id, left, right, tree_id)

        # Register the id so get_free_tree_idx() will not hand it out again.
        if tree_id not in self.trees:
            self.trees[tree_id] = Tree(tree_id)

        new_values = (left, right, tree_id, level)
        if (node.left, node.right, node.tree_id, node.level) != new_values:
            node.left, node.right, node.tree_id, node.level = new_values
            # Only flag real changes, so unchanged rows are not rewritten.
            # An already-pending flag is left untouched.
            node.dirty = True

        return right + 1

    def apply_changes(self):
        """Write every dirty node back to ``Group`` in a single transaction.

        :raises RuntimeError: if the forest still contains corrupted trees.
        """
        if self.get_corrupted_trees():
            raise RuntimeError('Forest contains corrupted trees')

        # Every node appears exactly once in self.adj, whatever the state of self.trees.
        dirty_nodes = [node for nodes in self.adj.values() for node in nodes if node.dirty]
        if not dirty_nodes:
            return

        # One atomic block: a rebuild can move subtrees between tree ids, so
        # committing tree by tree could persist a half-renumbered forest.
        with transaction.atomic():
            for node in dirty_nodes:
                Group.objects.filter(pk=node.id).update(
                    lft=node.left, rght=node.right, tree_id=node.tree_id, level=node.level,
                )

        # Clear flags only after the atomic block exits without error, so a
        # rolled-back write stays pending for a retry.
        for node in dirty_nodes:
            node.dirty = False


class TreeNode:
    """Mutable in-memory copy of one ``Group`` row's tree fields."""

    id = None
    left = None  # copied from the row's ``lft``
    right = None  # copied from the row's ``rght``
    tree_id = None
    level = None
    parent_id = None
    dirty = False  # True when the database record must be updated after a rebuild

    def __str__(self):
        return '{{#{}:L{} {}<->{}}}'.format(self.id, self.level, self.left, self.right)

    def __repr__(self):
        return str(self)


def plant_forest():
    """Build and return a ``Forest`` populated from every ``Group`` row."""
    forest = Forest()
    all_groups = Group.objects.all()
    for group_row in all_groups:
        forest.make_node(group_row)
    return forest
