#include "split.h"
#include "algo.h"

#include <util/generic/hash_set.h>

#include <utility>

namespace NCrypta::NGraphEngine {

    // Groups vertex indices by the two-edge-connected component label produced by
    // NAlgo::FindTwoEdgeConnectedComponents: element k of the result holds every vertex labelled k.
    // Assumes labels are dense (0..k-1) and that there is one label per vertex — TODO confirm in algo.h.
    // If the labelling is empty, an empty vector is returned instead of dereferencing an empty range.
    TVector<THashSet<ui64>> ComputeInseparableSets(const NInnerView::TGraph& graph) {
        const auto componentOfVertex = NAlgo::FindTwoEdgeConnectedComponents(graph);

        // MaxElement on an empty range returns end(), which must not be dereferenced.
        ui64 numComponents = 0;
        if (!componentOfVertex.empty()) {
            numComponents = *MaxElement(componentOfVertex.begin(), componentOfVertex.end()) + 1;
        }

        TVector<THashSet<ui64>> inseparableSets(numComponents);
        for (ui64 i = 0; i < graph.Vertices.size(); ++i) {
            inseparableSets[componentOfVertex[i]].insert(i);
        }
        return inseparableSets;
    }

    // Merges, inside `components`, all vertices that share a two-edge-connected component
    // (as grouped by ComputeInseparableSets). Singleton groups are left untouched.
    void JoinInseparableComponents(const NInnerView::TGraph& graph, TDisjointSets& components) {
        for (const auto& group : ComputeInseparableSets(graph)) {
            if (group.size() <= 1) {
                continue;  // nothing to merge a lone vertex with
            }
            const ui64 anchor = *group.begin();
            for (const ui64 member : group) {
                components.UnionSets(anchor, member);
            }
        }
    }

    // Collects the edges whose Vertex1 < Vertex2 from all adjacency lists and sorts them
    // by descending weight. Presumably each undirected edge is stored in both endpoints' lists,
    // so this keeps one copy per edge; edges with Vertex1 >= Vertex2 (incl. self-loops) are dropped.
    TVector<NInnerView::TEdge> GetOrderedEdges(const NInnerView::TGraph& graph) {
        TVector<NInnerView::TEdge> ordered;
        for (const auto& adjacency : graph.Edges) {
            for (const auto& candidate : adjacency) {
                if (!(candidate.Vertex1 < candidate.Vertex2)) {
                    continue;
                }
                ordered.push_back(candidate);
            }
        }

        const auto heavierFirst = [](const NInnerView::TEdge& lhs, const NInnerView::TEdge& rhs) {
            return lhs.Weight > rhs.Weight;
        };
        Sort(ordered.begin(), ordered.end(), heavierFirst);
        return ordered;
    }

    // Visits edges from heaviest to lightest (GetOrderedEdges) and merges the endpoints of an edge when:
    //   * the edge is marked strong (always), or
    //   * its weight exceeds `threshold` and more than two disjoint sets exist at that moment.
    void JoinComponentsWithStrongEdgesBetween(const NInnerView::TGraph& graph, TDisjointSets& components,
                                              double threshold = 0.4) {
        const TVector<NInnerView::TEdge> byDescendingWeight = GetOrderedEdges(graph);
        for (const auto& edge : byDescendingWeight) {
            if (edge.IsStrong) {
                components.UnionSets(edge.Vertex1, edge.Vertex2);
                continue;
            }
            // Weight-based merges stop once only two sets remain.
            if (edge.Weight > threshold && components.SetCount() > 2) {
                components.UnionSets(edge.Vertex1, edge.Vertex2);
            }
        }
    }

    // For every adjacency list that holds exactly one edge (a leaf vertex), merges the two
    // endpoints of that edge.
    void JoinLeaves(const NInnerView::TGraph& graph, TDisjointSets& components) {
        for (const auto& adjacency : graph.Edges) {
            if (adjacency.size() != 1) {
                continue;
            }
            const auto& onlyEdge = adjacency.front();
            components.UnionSets(onlyEdge.Vertex1, onlyEdge.Vertex2);
        }
    }

    // Attaches every small disjoint set (size <= minSingleSize) that borders exactly one other set
    // to that neighbouring set. Edges inside a set are not counted as neighbours, so a small set with
    // internal edges is still attached when it touches a single other set.
    void JoinSmallComponents(const NInnerView::TGraph& graph, TDisjointSets& components, ui64 minSingleSize = 2) {
        // Adjacency between sets, indexed by each set's canonic element at construction time.
        TVector<TVector<ui64>> componentEdges(graph.Vertices.size());

        for (const auto& edges : graph.Edges) {
            for (const auto& edge : edges) {
                const ui64 component1 = components.CanonicSetElement(edge.Vertex1);
                const ui64 component2 = components.CanonicSetElement(edge.Vertex2);
                if (component1 == component2) {
                    // Intra-set edge: it does not connect the set to a neighbour.
                    continue;
                }
                componentEdges[component1].push_back(component2);
                componentEdges[component2].push_back(component1);
            }
        }

        for (ui64 component1 = 0; component1 < componentEdges.size(); ++component1) {
            if (components.SizeOfSet(component1) > minSingleSize || componentEdges[component1].empty()) {
                continue;
            }
            // Earlier iterations may have merged sets, so re-canonicalize this set and its neighbours
            // and ignore neighbours that have since become part of this very set.
            const ui64 ownSet = components.CanonicSetElement(component1);
            THashSet<ui64> candidates;
            for (ui64 neighbour : componentEdges[component1]) {
                const ui64 neighbourSet = components.CanonicSetElement(neighbour);
                if (neighbourSet == ownSet) {
                    continue;
                }
                candidates.insert(neighbourSet);
                if (candidates.size() > 1) {
                    break;  // more than one neighbouring set: ambiguous, leave as is
                }
            }
            if (candidates.size() == 1) {
                components.UnionSets(component1, *candidates.begin());
            }
        }
    }

    // Attaches every small set (size < minSingleSize) to the big set (size >= minSingleSize) it is most
    // strongly connected to, where the strength of a path is the product of its edge weights.
    //
    // Implementation: multi-source Dijkstra. A vertex's key is minus the best weight product reaching it
    // from any big-set vertex, so a smaller key means a stronger connection. Every big-set vertex starts
    // with key -1 (the product's neutral element, negated), so big-set vertices are popped first.
    // Assuming edge weights lie in [0, 1] (TODO confirm with the graph builder), multiplying by a weight
    // never decreases a key, which keeps Dijkstra's greedy order valid.
    //
    // Set elements are {distance, vertexIndex, parentVertexIndex}; parentVertexIndex is the canonic
    // element of the big set (taken before any merging here) that the vertex gets attached to.
    void JoinSmallComponentsBetweenBigComponents(const NInnerView::TGraph& graph, TDisjointSets& components, ui64 minSingleSize) {
        // Reachable keys lie in [-1, 0] (given weights in [0, 1]), so +1 marks "not reached yet".
        const double INF = 1;

        TSet<TDijkstraSetElement> setForDijkstra;
        // The element currently stored in setForDijkstra for each queued vertex, so a stale entry can be
        // erased when the vertex's key improves.
        THashMap<ui64, TDijkstraSetElement> queuedElement;
        TVector<double> distanceFromClosestBigComponent(graph.Vertices.size(), INF);

        for (ui64 vertexIndex = 0; vertexIndex < graph.Vertices.size(); ++vertexIndex) {
            if (components.SizeOfSet(vertexIndex) >= minSingleSize) {
                const TDijkstraSetElement start{-INF, vertexIndex, components.CanonicSetElement(vertexIndex)};
                setForDijkstra.insert(start);
                distanceFromClosestBigComponent[vertexIndex] = -INF;
                queuedElement[vertexIndex] = start;
            }
        }

        while (!setForDijkstra.empty()) {
            const TDijkstraSetElement current = *setForDijkstra.begin();
            setForDijkstra.erase(setForDijkstra.begin());

            if (components.SizeOfSet(current.vertexIndex) < minSingleSize) {
                components.UnionSets(current.vertexIndex, current.parentVertexIndex);
            }
            // Assumes every edge in graph.Edges[v] has Vertex1 == v, so Vertex2 is the neighbour.
            for (const auto& edge : graph.Edges[current.vertexIndex]) {
                const auto vertexTo = edge.Vertex2;
                const double candidate = current.distance * edge.Weight;
                if (candidate < distanceFromClosestBigComponent[vertexTo]) {
                    // Look the stale entry up without default-inserting into the map.
                    const auto queued = queuedElement.find(vertexTo);
                    if (queued != queuedElement.end()) {
                        setForDijkstra.erase(queued->second);
                    }
                    const TDijkstraSetElement updated{candidate, vertexTo, current.parentVertexIndex};
                    distanceFromClosestBigComponent[vertexTo] = candidate;
                    setForDijkstra.insert(updated);
                    queuedElement[vertexTo] = updated;
                }
            }
        }
    }

    // Returns, for each of the first `size` elements, a dense label (0..k-1) identifying its disjoint set.
    // Labels are handed out in order of the first element (by index) that belongs to each set.
    TVector<ui64> ConvertComponentsToSplit(int size, const TDisjointSets& components) {
        TVector<ui64> labels(size);
        THashMap<ui64, ui64> denseLabelByCanonic(components.SetCount());

        for (ui64 vertex = 0; vertex < labels.size(); ++vertex) {
            const auto canonic = components.CanonicSetElement(vertex);
            const auto found = denseLabelByCanonic.find(canonic);
            if (found == denseLabelByCanonic.end()) {
                const ui64 fresh = denseLabelByCanonic.size();
                denseLabelByCanonic[canonic] = fresh;
                labels[vertex] = fresh;
            } else {
                labels[vertex] = found->second;
            }
        }
        return labels;
    }

    // Splits the inner-view graph into groups and returns a dense group label for every vertex.
    // The stages below only ever merge disjoint sets; they never split them.
    TVector<ui64> Split(const NInnerView::TGraph& graph) {
        const auto vertexCount = graph.Vertices.size();
        TDisjointSets components(vertexCount);

        JoinLeaves(graph, components);
        JoinInseparableComponents(graph, components);
        JoinComponentsWithStrongEdgesBetween(graph, components);
        // JoinSmallComponents(graph, components) is currently disabled.
        JoinSmallComponentsBetweenBigComponents(graph, components);

        return ConvertComponentsToSplit(vertexCount, components);
    }

    // Splits `graph` and regroups its vertices, converted to the custom view, by split label.
    // An empty graph yields an empty split instead of dereferencing an empty range.
    TCustomSplit Split(const TCommonGraph& graph) {
        const auto split = Split(graph.GetInnerGraph());

        // Labels are dense (0..k-1), so the largest label + 1 is the number of groups.
        // MaxElement on an empty range returns end(), which must not be dereferenced.
        ui64 splitSize = 0;
        if (!split.empty()) {
            splitSize = *MaxElement(split.begin(), split.end()) + 1;
        }
        TVector<TVector<NCustomView::TVertex>> customSplit(splitSize);

        for (ui64 vertex = 0; vertex < split.size(); ++vertex) {
            customSplit[split[vertex]].push_back(graph.ConvertToCustomView(vertex));
        }
        // Move the groups into the result rather than copying every vertex.
        return TCustomSplit{std::move(customSplit)};
    }
}
