package ru.yandex.webmaster3.core.http.util;

import org.apache.commons.lang3.tuple.Pair;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Topological ordering utilities for directed graphs described as adjacency maps
 * ({@code node -> set of nodes that must come after it}).
 *
 * @author avhaliullin
 */
public class TopologicalSort {
    /**
     * Orders {@code nodes} so that for every edge {@code src -> dest} in {@code edges},
     * {@code src} appears before {@code dest} in the result (Kahn's algorithm).
     * <p>
     * Neither argument is modified, so {@code edges} may be an unmodifiable map. Entries whose
     * destination set is {@code null} or empty are ignored.
     * <p>
     * Only nodes from {@code nodes} with no incoming edges seed the traversal. A destination that
     * is not listed in {@code nodes} is still emitted once all its predecessors are. An edge source
     * that is not listed in {@code nodes} and has no incoming edges is never scheduled and ends up
     * reported as a cycle. NOTE(review): callers presumably list every edge endpoint in
     * {@code nodes}; confirm.
     *
     * @param nodes nodes to order; duplicates are not collapsed
     * @param edges adjacency map: node -> set of nodes it must precede
     * @return nodes in topological order
     * @throws IllegalArgumentException if some edges could never be resolved, i.e. the graph
     *                                  contains a cycle; the message lists the unresolved sources
     */
    public static <T> List<T> sort(Collection<T> nodes, Map<T, Set<T>> edges) {
        // The algorithm consumes edges as it progresses. Work on a private copy instead of
        // mutating the caller's map, which would also fail on unmodifiable maps. The inner sets
        // are only read, so they are shared rather than copied. LinkedHashMap keeps the caller's
        // key order for the error message below.
        Map<T, Set<T>> remaining = new LinkedHashMap<>();
        for (Map.Entry<T, Set<T>> entry : edges.entrySet()) {
            Set<T> dests = entry.getValue();
            if (dests != null && !dests.isEmpty()) {
                remaining.put(entry.getKey(), dests);
            }
        }

        // dest -> set of sources that still have an unresolved edge into dest.
        Map<T, Set<T>> reversedEdges = new HashMap<>();
        for (Map.Entry<T, Set<T>> entry : remaining.entrySet()) {
            T src = entry.getKey();
            for (T dest : entry.getValue()) {
                reversedEdges.computeIfAbsent(dest, ign -> new HashSet<>()).add(src);
            }
        }

        // Nodes with no incoming edges can be emitted immediately.
        Queue<T> safeNodes = new ArrayDeque<>();
        for (T node : nodes) {
            Set<T> incoming = reversedEdges.get(node);
            if (incoming == null || incoming.isEmpty()) {
                safeNodes.add(node);
            }
        }

        List<T> result = new ArrayList<>(nodes.size());
        while (!safeNodes.isEmpty()) {
            T node = safeNodes.poll();
            result.add(node);
            // Removing the entry marks all of node's outgoing edges as resolved.
            Set<T> dests = remaining.remove(node);
            if (dests != null) {
                for (T dest : dests) {
                    Set<T> destReverseEdges = reversedEdges.get(dest);
                    destReverseEdges.remove(node);
                    if (destReverseEdges.isEmpty()) {
                        // The last predecessor is done, so dest becomes safe to emit.
                        reversedEdges.remove(dest);
                        safeNodes.add(dest);
                    }
                }
            }
        }
        // Any source whose edges were never consumed was never reached, which means it is on
        // a cycle or downstream of one.
        if (!remaining.isEmpty()) {
            throw new IllegalArgumentException("Topological sort failed: there is cycle in subgraph " +
                    remaining.keySet().stream().map(Object::toString).collect(Collectors.joining(", "))
            );
        }
        return result;
    }

    /**
     * Checks whether {@code nodes} respects every edge in {@code edges}.
     * <p>
     * Scans {@code nodes} in order. For each node it looks for an outgoing edge whose
     * destination was already seen earlier in the list.
     *
     * @param nodes candidate ordering
     * @param edges adjacency map: node -> set of nodes it must precede; {@code null} values are
     *              treated as "no outgoing edges"
     * @return the first violating edge found, as {@code Pair.of(src, dest)} where {@code dest}
     *         appears before {@code src} in {@code nodes}; {@code null} if no violation is found
     */
    public static <T> Pair<T, T> verifySorted(List<T> nodes, Map<T, Set<T>> edges) {
        Set<T> visited = new HashSet<>();
        for (T node : nodes) {
            Set<T> dests = edges.get(node);
            if (dests != null) {
                for (T dst : dests) {
                    if (visited.contains(dst)) {
                        return Pair.of(node, dst);
                    }
                }
            }
            visited.add(node);
        }
        return null;
    }
}
