import constants as c


def merge_two_shards(sh1, sh2):
    """Return a new Shard holding the ids of both ``sh1`` and ``sh2``.

    The new shard is bound to ``sh1.graph``; ``sh2.graph`` is ignored, so
    both shards are presumably expected to share the same graph.
    Neither input shard is modified.
    """
    combined_ids = sh1.ids.union(sh2.ids)
    return Shard(combined_ids, sh1.graph)


def to_yamake(rulenames):
    """Render ya.make text that PEERDIRs every rule in ``rulenames``.

    Each name is prefixed with ``c.RULE_PREFIX`` and written on its own line,
    indented by eight spaces, inside a ``PACKAGE()`` / ``PEERDIR()`` section
    guarded by ``IF (NOT AUTOCHECK)``. Names are emitted in sorted order so
    the output is deterministic. The returned text has no trailing newline.
    """
    indent = ' ' * 8
    rulepaths = '\n'.join(indent + c.RULE_PREFIX + rule for rule in sorted(rulenames))
    template = '''OWNER(
    g:begemot
)

IF (NOT AUTOCHECK)

    PACKAGE()

    PEERDIR(
{rulepaths}
    )

    END()

ENDIF()'''
    return template.format(rulepaths=rulepaths)


class Shard:
    """A subset of rule ids taken from a shared dependency graph.

    The graph object is shared between shards and is only read here. From the
    calls below it must provide:

    * ``sizes[rid]``, ``times[rid]`` and ``id2name[rid]`` lookups;
    * ``get_children(rid)`` / ``get_parents(rid)`` returning sets of ids
      (they are intersected with ``self.ids`` via ``&``).

    Critical-path results (``get_longest_path*``) are cached lazily and the
    cache is dropped by ``merge_with``. NOTE(review): mutating ``self.ids``
    directly bypasses that invalidation.
    """

    def __init__(self, ids, graph):
        self.ids = set(ids)
        self.graph = graph  # just a 'const' pointer: shared, never copied or modified
        # Filled lazily by _count_path_times(); None means "not computed / stale".
        self._path_times = None
        self._longest_parents = None

    def get_size(self):
        """Return the sum of ``graph.sizes`` over all ids in the shard (0 if empty)."""
        return sum(self.graph.sizes[rid] for rid in self.ids)

    def get_names(self):
        """Return the names of all ids in the shard, in set iteration order."""
        return [self.graph.id2name[rid] for rid in self.ids]

    def merge_with(self, other):
        """Add all ids of ``other`` to this shard in place and drop cached paths."""
        self.ids |= other.ids
        self._path_times = None
        self._longest_parents = None

    def _dfs(self, rid):
        """Post-order DFS from ``rid`` over children that belong to the shard.

        Marks visited ids in ``self._used`` and appends each id to
        ``self._order`` once all of its in-shard children are finished.
        An explicit stack is used instead of recursion so that long dependency
        chains do not hit Python's recursion limit; the visiting order is the
        same as that of a plain recursive DFS.
        """
        self._used.add(rid)
        stack = [(rid, iter(self.graph.get_children(rid) & self.ids))]
        while stack:
            node, children = stack[-1]
            for child_id in children:
                if child_id not in self._used:
                    self._used.add(child_id)
                    stack.append((child_id, iter(self.graph.get_children(child_id) & self.ids)))
                    break  # descend first; resume this node's iterator afterwards
            else:
                # Every child of ``node`` is finished: emit it (post-order).
                stack.pop()
                self._order.append(node)

    def _topological_sort(self):
        """Return the shard's ids ordered so that every parent precedes its children.

        Only edges between ids of this shard are considered. Assumes the
        induced subgraph is acyclic; NOTE(review): with a cycle the result is
        not a valid topological order.
        """
        self._order = list()
        self._used = set()
        for rid in self.ids:
            if rid not in self._used:
                self._dfs(rid)
        return list(reversed(self._order))

    def _count_path_times(self):
        """Fill the critical-path caches.

        For every id, ``_path_times[rid]`` is ``graph.times[rid]`` plus the
        largest path time among its in-shard parents, and
        ``_longest_parents[rid]`` is that parent (``None`` when the id has no
        parent inside the shard). Ties between parents are broken in favour of
        the larger id, because ``(time, id)`` tuples are compared.
        """
        self._path_times = dict()
        self._longest_parents = dict()
        for rid in self._topological_sort():
            parents = self.graph.get_parents(rid) & self.ids
            max_parent_time = 0
            longest_parent_id = None
            if parents:
                # Parents come earlier in topological order, so their times are known.
                max_parent_time, longest_parent_id = max(
                    (self._path_times[par_id], par_id) for par_id in parents)
            self._path_times[rid] = self.graph.times[rid] + max_parent_time
            self._longest_parents[rid] = longest_parent_id

    def get_longest_path(self):
        """Return the ids of the heaviest chain inside the shard, in dependency order.

        The chain starts at an id with no parent inside the shard and ends at
        the id with the largest accumulated time. Returns ``[]`` for an empty
        shard.
        """
        if self._path_times is None:
            self._count_path_times()
        if not self._path_times:
            return []
        _, end_rid = max((ptime, rid) for rid, ptime in self._path_times.items())
        longest_path = [end_rid]
        parent_id = self._longest_parents[end_rid]
        while parent_id is not None:
            longest_path.append(parent_id)
            parent_id = self._longest_parents[parent_id]
        longest_path.reverse()
        return longest_path

    def get_longest_path_time(self):
        """Return the accumulated time of the heaviest chain (0 for an empty shard)."""
        if self._path_times is None:
            self._count_path_times()
        if not self._path_times:
            return 0
        return max(self._path_times.values())
