package ru.yandex.crypta.graph.api.service.transformer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.jgrapht.Graphs;
import org.jgrapht.graph.WeightedMultigraph;

import ru.yandex.crypta.graph.api.model.graph.Edge;
import ru.yandex.crypta.graph.api.model.graph.Vertex;
import ru.yandex.crypta.graph.soup.config.Soup;
import ru.yandex.crypta.graph.soup.config.proto.ELogSourceType;
import ru.yandex.crypta.graph.soup.config.proto.ESourceType;
import ru.yandex.crypta.graph.soup.config.proto.TEdgeProps;
import ru.yandex.crypta.graph.soup.config.proto.TEdgeType;
import ru.yandex.crypta.lib.proto.identifiers.EIdType;


/**
 * Graph transformer that collapses ("shrinks") every cluster of vertices connected by edges matching
 * a configured condition into a single synthetic vertex.
 *
 * <p>A cluster is a maximal set of vertices reachable from one another through matching edges. All
 * members of a cluster are replaced by one vertex whose id is the {@code |}-joined list of
 * {@code idValue(idType)} of the members. That vertex is created with {@code name} as its second
 * {@link Vertex} constructor argument (presumably the id type — TODO confirm against {@link Vertex}).
 * Edges whose endpoints collapse into the same synthetic vertex are dropped, as are matching edges
 * between shrunk vertices. Other edges are re-attached to the synthetic vertices.
 */
public class GraphShrinkingTransformer implements GraphTransformer {

    /** Match scopes of all predefined {@link ShrinkType}s, in declaration order. Unmodifiable. */
    public static final List<String> MATCH_SCOPES = Stream
            .of(ShrinkType.values())
            .map(ShrinkType::getMatchScope)
            .collect(Collectors.toUnmodifiableList());

    /** Name given to synthetic vertices when the caller does not supply one. */
    private static final String DEFAULT_NAME = "shrunk";

    /** Edges satisfying this predicate glue their endpoints into the same cluster. */
    private final Predicate<Edge> condition;
    /** Second {@link Vertex} constructor argument used for every synthetic (shrunk) vertex. */
    private final String name;

    /**
     * Creates a transformer that uses the condition of a predefined shrink type and the default name.
     *
     * @param shrinkType predefined shrinking rule
     */
    public GraphShrinkingTransformer(ShrinkType shrinkType) {
        this(shrinkType.condition);
    }

    /**
     * Creates a transformer with a custom edge condition and the default name {@code "shrunk"}.
     *
     * @param condition edges for which this returns {@code true} are collapsed
     */
    public GraphShrinkingTransformer(Predicate<Edge> condition) {
        this(condition, DEFAULT_NAME);
    }

    /**
     * Creates a transformer with a custom edge condition and synthetic-vertex name.
     *
     * @param condition edges for which this returns {@code true} are collapsed
     * @param name      second {@link Vertex} constructor argument for synthetic vertices
     */
    public GraphShrinkingTransformer(Predicate<Edge> condition, String name) {
        this.condition = condition;
        this.name = name;
    }

    /**
     * Shrinks every cluster of vertices connected by matching edges into one vertex.
     *
     * @param graph graph to shrink
     * @return the graph built by {@code JGraphTHelper} from the rewritten edges. When no edge survives,
     *     this is a one-vertex graph for one (arbitrary) cluster. When the input has no edges and
     *     nothing was shrunk, the input graph itself is returned.
     */
    @Override
    public WeightedMultigraph<Vertex, Edge> transform(WeightedMultigraph<Vertex, Edge> graph) {
        Map<Vertex, String> shrunkVertices = new HashMap<>();
        Set<Vertex> used = new HashSet<>();

        for (Vertex vertex : graph.vertexSet()) {
            if (used.contains(vertex)) {
                continue;
            }
            if (!hasLabeledEdge(vertex, graph)) {
                // Not part of any cluster; mark it so later cluster searches never revisit it.
                used.add(vertex);
                continue;
            }
            List<Vertex> cluster = getShrunkVerticesCluster(vertex, graph, used);
            shrunkVertices.putAll(renameShrunkVertices(cluster));
        }

        return constructNewGraph(graph, shrunkVertices);
    }

    /**
     * Rewrites every edge of {@code graph} in terms of the synthetic cluster vertices.
     *
     * @param graph          source graph
     * @param shrunkVertices maps each shrunk vertex to the id of the cluster vertex replacing it
     * @return the rebuilt graph
     */
    private WeightedMultigraph<Vertex, Edge> constructNewGraph(WeightedMultigraph<Vertex, Edge> graph,
                                                               Map<Vertex, String> shrunkVertices) {
        Set<Edge> newEdges = new HashSet<>();

        for (Edge edge : graph.edgeSet()) {
            Vertex v1 = edge.getVertex1();
            Vertex v2 = edge.getVertex2();
            // Cluster ids are never null, so a null lookup means "this endpoint is not shrunk".
            String cluster1 = shrunkVertices.get(v1);
            String cluster2 = shrunkVertices.get(v2);

            if (cluster1 == null && cluster2 == null) {
                // Untouched part of the graph: keep the edge as-is.
                newEdges.add(edge);
                continue;
            }

            Vertex newV1 = cluster1 == null ? v1 : new Vertex(cluster1, this.name);
            Vertex newV2 = cluster2 == null ? v2 : new Vertex(cluster2, this.name);

            // Between two shrunk endpoints, drop matching edges and edges inside one cluster (would be
            // self-loops). Non-matching edges between distinct clusters are kept.
            if (cluster1 != null && cluster2 != null && (condition.test(edge) || newV1.equals(newV2))) {
                continue;
            }

            newEdges.add(new Edge(newV1, newV2, edge.getSourceType(), edge.getLogSource(),
                    edge.getWeight(), edge.getDates()));
        }

        if (newEdges.isEmpty()) {
            if (shrunkVertices.isEmpty()) {
                // The graph has no edges and nothing was shrunk: there is nothing to rebuild.
                return graph;
            }
            // Everything collapsed away; keep one (arbitrary, HashMap order) cluster vertex.
            Vertex v = new Vertex(shrunkVertices.values().iterator().next(), this.name);
            return JGraphTHelper.toOneVertexGraph(v);
        }

        return JGraphTHelper.toSimpleGraph(newEdges);
    }

    /**
     * Tells whether {@code v} has at least one incident edge satisfying the condition.
     *
     * @param v     vertex to inspect
     * @param graph graph containing {@code v}
     * @return {@code true} if some edge of {@code v} matches
     */
    private boolean hasLabeledEdge(Vertex v, WeightedMultigraph<Vertex, Edge> graph) {
        return graph.edgesOf(v).stream().anyMatch(condition);
    }

    /**
     * Maps every vertex of a cluster to the cluster's synthetic id, {@code idValue(idType)} of each
     * member joined with {@code |} in list order.
     *
     * @param vertices members of one cluster, without duplicates
     * @return member -> synthetic id
     */
    private static Map<Vertex, String> renameShrunkVertices(List<Vertex> vertices) {
        String newName = vertices.stream()
                .map(v -> v.getIdValue() + "(" + v.getIdType() + ")")
                .collect(Collectors.joining("|"));

        Map<Vertex, String> renamed = new HashMap<>();
        for (Vertex v : vertices) {
            renamed.put(v, newName);
        }
        return renamed;
    }

    /**
     * Collects, breadth-first, all vertices reachable from {@code vertex} through matching edges.
     *
     * <p>Every visited vertex is added to {@code used}. Each vertex appears in the result exactly once,
     * in visiting order. All parallel edges are examined, so a matching edge is never missed because
     * a non-matching edge joins the same pair of vertices.
     *
     * @param vertex start vertex
     * @param graph  graph to traverse
     * @param used   vertices already assigned or excluded; updated in place
     * @return cluster members in BFS order
     */
    private List<Vertex> getShrunkVerticesCluster(Vertex vertex, WeightedMultigraph<Vertex, Edge> graph,
                                                  Set<Vertex> used) {
        List<Vertex> cluster = new ArrayList<>();
        Queue<Vertex> queue = new LinkedList<>();
        queue.add(vertex);

        while (!queue.isEmpty()) {
            Vertex v = queue.poll();
            if (!used.add(v)) {
                // Already visited (a vertex may be queued from several neighbours).
                continue;
            }
            cluster.add(v);

            for (Edge edge : graph.edgesOf(v)) {
                Vertex neighbour = Graphs.getOppositeVertex(graph, edge, v);
                if (!used.contains(neighbour) && condition.test(edge)) {
                    queue.add(neighbour);
                }
            }
        }

        return cluster;
    }

    /** Looks up the Soup edge type of {@code edge} from its id types, source type and log source. */
    private static TEdgeType edgeTypeOf(Edge edge) {
        return Soup.CONFIG.getEdgeType(
                edge.getId1Type(), edge.getId2Type(),
                edge.getSourceType(), edge.getLogSource()
        );
    }

    /** Tells whether the Soup edge props declare this edge type's strength as ARTIFICIAL. */
    private static boolean isArtificial(TEdgeType edgeType) {
        return Soup.CONFIG.getEdgeProps(edgeType).getEdgeStrength() == TEdgeProps.EEdgeStrength.ARTIFICIAL;
    }

    /** Tells whether the Soup edge usage marks this edge type as "soupy indevice". */
    private static boolean isIndevice(TEdgeType edgeType) {
        return Soup.CONFIG.getEdgeUsage(edgeType).getSoupyIndevice();
    }

    /**
     * Extra hard-coded edge kinds that {@link ShrinkType#ARTIFICIAL_INDEVICE_PLUS} also shrinks.
     *
     * @param edge edge to classify
     * @return {@code true} if the edge matches one of the listed id/source/log combinations
     */
    private static boolean additionalCheckForShrinking(Edge edge) {
        // https://st.yandex-team.ru/CRYPTR-1596#5d9f56dca2b79e001d169cd1

        String id1 = edge.getId1();
        EIdType id1Type = Soup.CONFIG.getIdType(edge.getId1Type()).getType();
        EIdType id2Type = Soup.CONFIG.getIdType(edge.getId2Type()).getType();
        ESourceType sourceType = Soup.CONFIG.getSourceType(edge.getSourceType()).getType();
        ELogSourceType logSource = Soup.CONFIG.getLogSource(edge.getLogSource()).getType();

        // yandexuid - distr_r1 from yasoft / bar navig log.
        if (id1Type == EIdType.YANDEXUID && id2Type == EIdType.DISTR_R1
                && sourceType == ESourceType.YASOFT && logSource == ELogSourceType.BAR_NAVIG_LOG) {
            return true;
        }

        // mm_device_id - mac from app metrica / metrika mobile log.
        if (id1Type == EIdType.MM_DEVICE_ID && id2Type == EIdType.MAC
                && sourceType == ESourceType.APP_METRICA && logSource == ELogSourceType.METRIKA_MOBILE_LOG) {
            return true;
        }

        // puid whose value contains "phne-" linked to a login (passport dict)
        // or a phone md5 (passport phone dump), both from the passport profile.
        if (id1Type == EIdType.PUID && sourceType == ESourceType.PASSPORT_PROFILE
                && ((id2Type == EIdType.LOGIN && logSource == ELogSourceType.PASSPORT_DICT)
                    || (id2Type == EIdType.PHONE_MD5 && logSource == ELogSourceType.PASSPORT_PHONE_DUMP))
                && id1.contains("phne-")) {
            return true;
        }

        // yandexuid - {yandexuid, uuid, xuniq_guid} from xuniq / bs xuniqs log.
        return id1Type == EIdType.YANDEXUID
                && sourceType == ESourceType.XUNIQ
                && logSource == ELogSourceType.BS_XUNIQS_LOG
                && (id2Type == EIdType.YANDEXUID || id2Type == EIdType.UUID || id2Type == EIdType.XUNIQ_GUID);
    }


    /** Predefined shrinking rules, each identified by a match scope string. */
    public enum ShrinkType {
        /** Shrinks edges whose Soup edge strength is ARTIFICIAL. */
        ARTIFICIAL("shrinking_artificial", e -> isArtificial(edgeTypeOf(e))),
        /** Shrinks edges whose Soup edge usage is "soupy indevice". */
        INDEVICE("shrinking_indevice", e -> isIndevice(edgeTypeOf(e))),
        /** Union of {@link #ARTIFICIAL} and {@link #INDEVICE}. */
        ARTIFICIAL_INDEVICE("shrinking_artificial_n_indevice", e -> {
            TEdgeType edgeType = edgeTypeOf(e);
            return isArtificial(edgeType) || isIndevice(edgeType);
        }),
        /** {@link #ARTIFICIAL_INDEVICE} plus the hard-coded edge kinds of CRYPTR-1596. */
        ARTIFICIAL_INDEVICE_PLUS("shrinking_artificial_n_indevice_plus", e -> {
            TEdgeType edgeType = edgeTypeOf(e);
            return isArtificial(edgeType) || isIndevice(edgeType) || additionalCheckForShrinking(e);
        });

        private final String matchScope;
        private final Predicate<Edge> condition;

        ShrinkType(String matchScope, Predicate<Edge> condition) {
            this.matchScope = matchScope;
            this.condition = condition;
        }

        /** Returns the match scope string identifying this rule. */
        public String getMatchScope() {
            return matchScope;
        }
    }
}
