package ru.yandex.crypta.graph.api.service.transformer;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.jgrapht.Graphs;
import org.jgrapht.alg.interfaces.SpanningTreeAlgorithm;
import org.jgrapht.alg.spanning.KruskalMinimumSpanningTree;
import org.jgrapht.graph.WeightedMultigraph;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.crypta.graph.api.model.graph.Edge;
import ru.yandex.crypta.graph.api.model.graph.Vertex;

/**
 * Graph transformer that reduces a weighted multigraph to the edges of its
 * Kruskal minimum spanning tree, plus an optional set of edges that are
 * always added back.
 *
 * <p>When a non-empty set of vertices to keep is configured, branches of the
 * resulting graph that end in a vertex outside that set are pruned away:
 * vertices of degree one or less that are not in the keep set are removed
 * repeatedly until no such vertex remains.
 */
public class SpanningTreeTransformer implements GraphTransformer {

    /** Vertices that are never selected for pruning. */
    private final Set<Vertex> keepVertices;

    /** Edges that are added to the spanning tree edges before pruning. */
    private final Set<Edge> keepEdges;

    /**
     * Creates a transformer with explicit keep sets. The sets are stored by
     * reference and are only read by this class.
     *
     * @param keepVertices vertices that must not be pruned; an empty set disables pruning
     * @param keepEdges    edges that are added to the spanning tree edges
     * @throws NullPointerException if either argument is {@code null}
     */
    public SpanningTreeTransformer(Set<Vertex> keepVertices,
                                   Set<Edge> keepEdges) {
        // Fail at construction time instead of with an anonymous NPE inside transform().
        this.keepVertices = Objects.requireNonNull(keepVertices, "keepVertices");
        this.keepEdges = Objects.requireNonNull(keepEdges, "keepEdges");
    }

    /**
     * Creates a transformer that keeps no extra edges and performs no pruning.
     */
    public SpanningTreeTransformer() {
        this(Cf.set(), Cf.set());
    }

    /**
     * Builds a new graph from the minimum spanning tree edges of {@code graph}
     * combined with the configured keep edges, optionally pruned.
     *
     * <p>With an empty vertex keep set the combined edges are returned as a
     * graph unchanged. Otherwise every edge that touches a pruned vertex is
     * dropped; note that this may also drop a configured keep edge whose
     * endpoint was pruned.
     *
     * @param graph the graph to reduce; it is not modified by this method
     * @return a graph built by {@code JGraphTHelper.toWeightedGraph} from the remaining edges
     */
    @Override
    public WeightedMultigraph<Vertex, Edge> transform(WeightedMultigraph<Vertex, Edge> graph) {
        SpanningTreeAlgorithm<Edge> sta = new KruskalMinimumSpanningTree<>(graph);

        // Copy before adding: the set returned by the algorithm belongs to its
        // result object and is not documented to be modifiable.
        Set<Edge> spanningEdges = new HashSet<>(sta.getSpanningTree().getEdges());
        spanningEdges.addAll(keepEdges);
        WeightedMultigraph<Vertex, Edge> spanningGraph = JGraphTHelper.toWeightedGraph(spanningEdges);

        if (keepVertices.isEmpty()) {
            return spanningGraph;
        }

        Set<Vertex> removableVertices = getCornerVertices(spanningGraph, keepVertices);
        List<Edge> filteredEdges = spanningEdges.stream()
                .filter(edge -> !removableVertices.contains(edge.getVertex1())
                        && !removableVertices.contains(edge.getVertex2()))
                .collect(Collectors.toList());

        // NOTE(review): the result is rebuilt from edges only, so a kept vertex whose
        // neighbours were all pruned presumably does not appear in it - confirm
        // against JGraphTHelper.toWeightedGraph.
        return JGraphTHelper.toWeightedGraph(filteredEdges);
    }

    /**
     * Collects the vertices that can be pruned from {@code newGraph}: vertices
     * outside {@code keepVertices} whose degree is one or less, found
     * repeatedly as earlier removals lower the degree of their neighbours.
     *
     * @param newGraph     the graph to analyse; it is not modified
     * @param keepVertices vertices that are never selected for removal
     * @return the set of vertices selected for removal
     */
    private Set<Vertex> getCornerVertices(WeightedMultigraph<Vertex, Edge> newGraph, Set<Vertex> keepVertices) {
        // Degree counted per incident edge, so parallel edges count separately.
        Map<Vertex, Integer> vertexDegrees = new HashMap<>();
        for (Edge edge : newGraph.edgeSet()) {
            vertexDegrees.merge(edge.getVertex1(), 1, Integer::sum);
            vertexDegrees.merge(edge.getVertex2(), 1, Integer::sum);
        }

        Set<Vertex> removableVertices = new HashSet<>();
        Deque<Vertex> pending = new ArrayDeque<>();

        for (Map.Entry<Vertex, Integer> entry : vertexDegrees.entrySet()) {
            if (entry.getValue() <= 1 && !keepVertices.contains(entry.getKey())) {
                pending.push(entry.getKey());
            }
        }

        while (!pending.isEmpty()) {
            Vertex vertex = pending.pop();
            if (!removableVertices.add(vertex)) {
                // The vertex was enqueued more than once and is already handled.
                continue;
            }

            for (Edge edge : newGraph.edgesOf(vertex)) {
                Vertex targetVertex = Graphs.getOppositeVertex(newGraph, edge, vertex);
                if (removableVertices.contains(targetVertex)) {
                    // Only neighbours still present lose an edge.
                    continue;
                }
                int degree = vertexDegrees.compute(targetVertex, (v, d) -> d == null ? 0 : d - 1);
                if (degree <= 1 && !keepVertices.contains(targetVertex)) {
                    pending.push(targetVertex);
                }
            }
        }
        return removableVertices;
    }

}
