package ru.yandex.crypta.graph2.model.matching.merge.algo.split;


import java.util.Collections;
import java.util.Comparator;
import java.util.Set;

import org.jgrapht.alg.ConnectivityInspector;
import org.jgrapht.graph.DefaultListenableGraph;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.SetF;
import ru.yandex.crypta.graph2.model.matching.component.Component;
import ru.yandex.crypta.graph2.model.matching.component.score.ComponentScoringStrategy;
import ru.yandex.crypta.graph2.model.matching.edge.EdgePriority;
import ru.yandex.crypta.graph2.model.matching.graph.JGraphTUtils;
import ru.yandex.crypta.graph2.model.matching.merge.algo.score.WeightedLinkScoringStrategy;
import ru.yandex.crypta.graph2.model.soup.edge.Edge;
import ru.yandex.crypta.graph2.model.soup.edge.weight.EdgeInfoProvider;
import ru.yandex.crypta.graph2.model.soup.vertex.Vertex;


/**
 * Split algorithm that cuts a component along its non-strong inner edges.
 *
 * <p>{@link #prepareRemovalCandidates} picks the non-strong inner edges (narrowed to nil-weighted or
 * "not bad" weighted edges when enough of them exist), ordered by
 * {@code EdgePriority.sortEdgesByActivityAsc}. {@link #chooseRemovalEdges} removes the candidates,
 * computes the resulting connected sets and then glues candidates back in descending survival-weight
 * order, keeping only the two largest sets apart; the candidates joining those two sets are the cut.
 */
public class MaxMinSplitAlgorithm extends AbstractSplitAlgorithm {

    /** Edges whose weight is strictly below this value are treated as nil-weighted. */
    private static final double THRESHOLD_FOR_MIN_WEIGHT = 1e-7;
    /** Edges whose weight is strictly above this value are treated as "not bad". */
    private static final double THRESHOLD_FOR_MIN_NOT_BAD_WEIGHT = 1e-2;

    /** A narrowed candidate list is used only when it has strictly more edges than this. */
    private static final int MIN_CANDIDATES_COUNT = 8;

    private final EdgePriority edgePriority;
    private final EdgeInfoProvider edgeInfoProvider;

    public MaxMinSplitAlgorithm(ComponentScoringStrategy componentScoringStrategy,
                                EdgeInfoProvider edgeInfoProvider,
                                WeightedLinkScoringStrategy linkScoringStrategy)
    {
        super(componentScoringStrategy, linkScoringStrategy);
        this.edgeInfoProvider = edgeInfoProvider;
        this.edgePriority = new EdgePriority(edgeInfoProvider);
    }

    private boolean isNilWeighted(Edge edge) {
        return edgeInfoProvider.getEdgeWeight(edge) < THRESHOLD_FOR_MIN_WEIGHT;
    }

    private boolean isNotBadWeighted(Edge edge) {
        return edgeInfoProvider.getEdgeWeight(edge) > THRESHOLD_FOR_MIN_NOT_BAD_WEIGHT;
    }

    /** Returns the inner edges of the component that are not strong, i.e. the edges allowed to be cut. */
    private SetF<Edge> nonStrongInnerEdges(Component component) {
        EdgeInfoProvider priorityInfoProvider = edgePriority.getEdgeInfoProvider();
        return component.getInnerEdges().filter(e -> !priorityInfoProvider.isStrong(e));
    }

    /** Returns every non-strong inner edge of the component, ordered by ascending activity. */
    private ListF<Edge> chooseAllAvailableRemovalCandidates(Component component) {
        return edgePriority.sortEdgesByActivityAsc(nonStrongInnerEdges(component));
    }

    /**
     * Selects the edges that may be removed when splitting the component.
     *
     * <p>Starts from the non-strong inner edges. If more than {@value #MIN_CANDIDATES_COUNT} of them are
     * nil-weighted, only those are returned; otherwise, if more than {@value #MIN_CANDIDATES_COUNT} are
     * "not bad" weighted, only those are returned; otherwise all non-strong inner edges are returned.
     * The result is always ordered by {@code EdgePriority.sortEdgesByActivityAsc}.
     */
    @Override
    public ListF<Edge> prepareRemovalCandidates(Component component) {
        SetF<Edge> removalCandidates = nonStrongInnerEdges(component);
        SetF<Edge> nilWeightedCandidates = removalCandidates.filter(this::isNilWeighted);
        if (nilWeightedCandidates.size() > MIN_CANDIDATES_COUNT) {
            return edgePriority.sortEdgesByActivityAsc(nilWeightedCandidates);
        }
        ListF<Edge> sortedCandidates = edgePriority.sortEdgesByActivityAsc(removalCandidates);
        // Filtering a sorted list keeps its order, so there is no need to sort again.
        ListF<Edge> notBadCandidates = sortedCandidates.filter(this::isNotBadWeighted);
        if (notBadCandidates.size() > MIN_CANDIDATES_COUNT) {
            return notBadCandidates;
        }
        return sortedCandidates;
    }

    @Override
    public ListF<Edge> chooseRemovalEdges(Component component, ListF<Edge> removalCandidates) {
        return chooseRemovalEdgesInner(component, removalCandidates, true);
    }

    /**
     * Removes {@code removalCandidates} from the graph built by {@code JGraphTUtils.toGraph(component)}
     * and returns the candidates that join the two main (largest) connected sets once all other
     * candidates have been glued back by {@link ComponentDSU}.
     *
     * @param firstEffort when true and {@code linkScoringStrategy} does not find the resulting cut
     *                    interesting, the computation is retried once with every non-strong inner edge
     * @return the cut edges, or an empty list when the removal leaves fewer than two connected sets
     */
    private ListF<Edge> chooseRemovalEdgesInner(Component component, ListF<Edge> removalCandidates,
            boolean firstEffort)
    {
        DefaultListenableGraph<Vertex, Edge> graph = new DefaultListenableGraph<>(JGraphTUtils.toGraph(component));
        ConnectivityInspector<Vertex, Edge> cc = new ConnectivityInspector<>(graph);
        // The inspector listens to the graph so its connectivity cache is invalidated on each removal.
        graph.addGraphListener(cc);
        for (Edge edgeToRemove : removalCandidates) {
            graph.removeEdge(edgeToRemove);
        }
        ListF<Set<Vertex>> connectedSets = Cf.wrap(cc.connectedSets());
        // Sort (ascending by size, then by minimal vertex) so that ComponentDSU works deterministically.
        Comparator<Set<Vertex>> c = Comparator.<Set<Vertex>>comparingInt(Set::size)
                .thenComparing(set -> Collections.min(set), Vertex.COMPARATOR);
        connectedSets = connectedSets.sorted(c);
        if (firstEffort && !linkScoringStrategy.isCutInteresting(connectedSets)) {
            return chooseRemovalEdgesInner(component, chooseAllAvailableRemovalCandidates(component), false);
        }
        if (connectedSets.size() < 2) {
            return Cf.list();
        }

        ComponentDSU componentDSU = new ComponentDSU(connectedSets);
        return componentDSU.getCut(removalCandidates);
    }

    /**
     * Disjoint-set union over the connected sets left after removing the candidate edges.
     *
     * <p>Labels are the indexes of the sets in the list given to the constructor. The two largest sets
     * are tracked as the "main" components. {@link #getCut} mutates the state, so an instance is single-use.
     */
    private static class ComponentDSU {
        private final MapF<Vertex, Integer> vertex2componentLabels = Cf.hashMap();
        private final MapF<Integer, Integer> parentLabels = Cf.hashMap();
        private final MapF<Integer, Integer> sizes = Cf.hashMap();
        private int mainFirstComponent = -1;
        private int mainSecondComponent = -1;

        /** Labels every vertex with the index of its set and picks the two largest sets as mains. */
        ComponentDSU(ListF<Set<Vertex>> connectedSets) {
            int componentLabel = 0;
            for (Set<Vertex> connectedComponent : connectedSets) {
                for (Vertex vertex : connectedComponent) {
                    vertex2componentLabels.put(vertex, componentLabel);
                }
                int size = connectedComponent.size();
                sizes.put(componentLabel, size);
                if (mainFirstComponent < 0 || getSize(mainFirstComponent) < size) {
                    mainSecondComponent = mainFirstComponent;
                    mainFirstComponent = componentLabel;
                } else if (mainSecondComponent < 0 || getSize(mainSecondComponent) < size) {
                    mainSecondComponent = componentLabel;
                }
                componentLabel++;
            }
        }

        /**
         * Glues candidates back in descending survival-weight order (missing weight counts as 0) and
         * returns the candidates joining the two main components.
         *
         * <p>If a merge produces a set larger than one of the mains, that set becomes a main. When cut
         * edges had already been collected, the two old mains are merged and the collected cut is discarded.
         */
        ListF<Edge> getCut(ListF<Edge> removalCandidates) {
            ListF<Edge> result = Cf.arrayList();
            for (Edge edge : removalCandidates.sortedBy(e -> -e.getSurvivalWeight().getOrElse(0.))) {
                int sourceLabel = getComponentLabel(vertex2componentLabels.getO(edge.getVertex1()).get());
                int targetLabel = getComponentLabel(vertex2componentLabels.getO(edge.getVertex2()).get());
                mainFirstComponent = getComponentLabel(mainFirstComponent);
                mainSecondComponent = getComponentLabel(mainSecondComponent);
                boolean joinsMains = (sourceLabel == mainFirstComponent && targetLabel == mainSecondComponent)
                        || (sourceLabel == mainSecondComponent && targetLabel == mainFirstComponent);
                if (joinsMains) {
                    result.add(edge);
                    continue;
                }
                if (mergeComponents(sourceLabel, targetLabel)
                        && tryChangeMainLabel(sourceLabel, result.isNotEmpty()))
                {
                    result.clear();
                }
            }
            return result;
        }

        /** Attaches set {@code second} under root {@code first}, moving its size to {@code first}. */
        void mergeToFirst(int first, int second) {
            int size = getSize(first) + getSize(second);
            parentLabels.put(second, first);
            sizes.put(first, size);
            sizes.put(second, 0);
        }

        /** Union by size; returns false when both labels already belong to the same set. */
        boolean mergeComponents(int first, int second) {
            first = getComponentLabel(first);
            second = getComponentLabel(second);
            if (first == second) {
                return false;
            }
            if (getSize(first) >= getSize(second)) {
                mergeToFirst(first, second);
            } else {
                mergeToFirst(second, first);
            }
            return true;
        }

        /**
         * Makes the set containing {@code label} a main component if it is larger than the smaller main.
         *
         * @param isMainComponentConnected whether cut edges between the current mains were already
         *                                 collected; if so the two mains are merged before the swap
         * @return true if the main components changed
         */
        boolean tryChangeMainLabel(int label, boolean isMainComponentConnected) {
            mainFirstComponent = getComponentLabel(mainFirstComponent);
            mainSecondComponent = getComponentLabel(mainSecondComponent);
            label = getComponentLabel(label);
            if (label == mainFirstComponent || label == mainSecondComponent) {
                return false;
            }
            // Merges in getCut can grow the second main past the first (or leave the merged mains
            // behind the new label). Restore size(first) >= size(second) so the smaller main is replaced.
            if (getSize(mainFirstComponent) < getSize(mainSecondComponent)) {
                int larger = mainSecondComponent;
                mainSecondComponent = mainFirstComponent;
                mainFirstComponent = larger;
            }
            int size = getSize(label);
            if (getSize(mainSecondComponent) >= size) {
                return false;
            }
            if (isMainComponentConnected) {
                // The collected cut edges are kept, so the old mains become a single set.
                mergeToFirst(mainFirstComponent, mainSecondComponent);
            }
            mainSecondComponent = label;
            if (getSize(mainFirstComponent) < getSize(mainSecondComponent)) {
                mainSecondComponent = mainFirstComponent;
                mainFirstComponent = label;
            }
            return true;
        }

        /** Size of the set containing {@code label}. */
        int getSize(int label) {
            return sizes.getOrElse(getComponentLabel(label), 0);
        }

        /** Root label of the set containing {@code label}, with path compression. */
        int getComponentLabel(int label) {
            int root = label;
            int parent = parentLabels.getOrElse(root, root);
            while (parent != root) {
                root = parent;
                parent = parentLabels.getOrElse(root, root);
            }
            int current = label;
            while (current != root) {
                int next = parentLabels.getOrElse(current, current);
                parentLabels.put(current, root);
                current = next;
            }
            return root;
        }
    }
}
