package ru.yandex.crypta.graph.api.service;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;

import javax.inject.Inject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.crypta.common.exception.Exceptions;
import ru.yandex.crypta.graph.api.model.graph.Edge;
import ru.yandex.crypta.graph.api.model.graph.Graph;
import ru.yandex.crypta.graph.api.model.graph.GraphComponent;
import ru.yandex.crypta.graph.api.model.graph.Vertex;
import ru.yandex.crypta.graph.api.model.ids.GraphId;
import ru.yandex.crypta.graph.api.service.settings.SoupGraphSettings;
import ru.yandex.crypta.graph.api.service.settings.YtSoupGraphSettings;
import ru.yandex.crypta.graph.api.service.settings.model.InfoParams;
import ru.yandex.crypta.graph.api.service.settings.model.SearchParams;
import ru.yandex.crypta.graph.soup.config.Soup;
import ru.yandex.crypta.graph.soup.config.proto.TEdgeType;
import ru.yandex.crypta.graph2.dao.yt.utils.YTreeUtils;
import ru.yandex.crypta.lib.yt.YtService;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.ytree.YTreeMapNode;
import ru.yandex.yt.ytclient.proxy.YtClient;
import ru.yandex.yt.ytclient.rpc.RpcError;
import ru.yandex.yt.ytclient.wire.UnversionedRowset;

import static java.util.stream.Collectors.toList;

/**
 * {@link GraphService} implementation that builds a graph around a given id by reading
 * adjacent edges from a soup table through the YT RPC client, level by level, up to a requested depth.
 */
public class YtSoupGraphService implements GraphService {

    private static final Logger LOG = LoggerFactory.getLogger(YtSoupGraphService.class);

    /** Maximum number of collected vertices; exceeding it aborts the traversal to avoid OOM. */
    private static final int VERTICES_SAFE_DRAW_LIMIT = 10_000;
    /** Neighbour count reported for a vertex whose counting query fails with an {@link RpcError}. */
    private static final int YT_INPUT_ROW_LIMIT = 1_000_000;

    private final SoupGraphSettings graphSettings;

    private final YtService yt;

    /**
     * Injection constructor.
     *
     * @param yt            YT access service used to obtain the RPC client
     * @param graphSettings soup graph settings (table location, edges limit)
     */
    @Inject
    public YtSoupGraphService(YtService yt, YtSoupGraphSettings graphSettings) {
        // the cast selects the protected overload below instead of recursing into this constructor
        this(yt, (SoupGraphSettings) graphSettings);
    }

    /**
     * Constructor for subclasses that supply their own {@link SoupGraphSettings}.
     *
     * @param yt            YT access service used to obtain the RPC client
     * @param graphSettings soup graph settings
     */
    protected YtSoupGraphService(YtService yt, SoupGraphSettings graphSettings) {
        this.graphSettings = graphSettings;
        this.yt = yt;
    }

    /**
     * Converts a soup table record to an {@link Edge}.
     * <p>
     * The edge endpoints are swapped when the record carries an explicit {@code reversed} flag
     * or when the soup config knows an edge type for the swapped id types.
     *
     * @param rec soup record with {@code id1}, {@code id2}, {@code id1Type}, {@code id2Type},
     *            {@code sourceType}, {@code logSource}, {@code dates} and optional
     *            {@code reversed} / {@code datesWeight} columns
     * @return edge built from the record; weight defaults to {@code 0.0} when {@code datesWeight} is absent
     */
    public static Edge mapSoupRecToEdge(YTreeMapNode rec) {
        String id1 = rec.getString("id1");
        String id2 = rec.getString("id2");
        String id1Type = rec.getString("id1Type");
        String id2Type = rec.getString("id2Type");
        String sourceType = rec.getString("sourceType");
        String logSource = rec.getString("logSource");
        List<String> dates = YTreeUtils.getStringList(rec, "dates");

        // some edges may have an explicit reversed flag
        boolean explicitReversed = rec.getBoolO("reversed").orElse(false);
        // some edges are treated as reversed if a config is found for the reversed type
        TEdgeType reversedEdgeType = Soup.CONFIG.tryGetEdgeType(id2Type, id1Type, sourceType, logSource);

        boolean reversed = explicitReversed || reversedEdgeType != null;
        double datesWeight = rec.getDoubleO("datesWeight").orElse(0.0);

        if (reversed) {
            return new Edge(
                    id2, id2Type,
                    id1, id1Type,
                    sourceType, logSource,
                    datesWeight, dates
            );
        }
        return new Edge(
                id1, id1Type,
                id2, id2Type,
                sourceType, logSource,
                datesWeight, dates
        );
    }

    /**
     * Escapes a value for embedding into a single-quoted string literal of a select query.
     * <p>
     * NOTE(review): assumes the YT query language uses C-style backslash escapes inside
     * quoted literals - confirm against the YT query language documentation.
     */
    private static String escapeQueryLiteral(Object value) {
        return String.valueOf(value)
                .replace("\\", "\\\\")
                .replace("'", "\\'");
    }

    /** Copies the overlimit size stored as the map value into its vertex key and returns that vertex. */
    private static Vertex withOverlimitCount(Map.Entry<Vertex, Integer> entry) {
        Vertex vertex = entry.getKey();
        vertex.setOverlimitCount(entry.getValue());
        return vertex;
    }

    /** Asynchronously selects at most {@code limit} rows whose {@code id1}/{@code id1Type} match the vertex. */
    private CompletableFuture<UnversionedRowset> selectRowsForVertex(YtClient client, YPath table, Vertex vertex,
                                                                     int limit) {
        return client.selectRows(String.format(
                "* FROM [%s] WHERE id1 = '%s' AND id1Type = '%s' LIMIT %d",
                table,
                escapeQueryLiteral(vertex.getIdValue()),
                escapeQueryLiteral(vertex.getIdType()),
                limit
        ));
    }

    /**
     * Asynchronously counts rows whose {@code id1}/{@code id1Type} match the vertex.
     * <p>
     * If the query fails with an {@link RpcError} the vertex is considered too large and
     * {@link #YT_INPUT_ROW_LIMIT} is returned; any other failure is propagated.
     */
    private CompletableFuture<Integer> selectRowsCountForVertex(YtClient client, YPath table, Vertex vertex) {
        return client.selectRows(String.format(
                "sum(1) as neighbours_count FROM [%s] WHERE id1 = '%s' AND id1Type = '%s' GROUP BY 1",
                table,
                escapeQueryLiteral(vertex.getIdValue()),
                escapeQueryLiteral(vertex.getIdType())
        )).thenApply(rowset -> {
            List<YTreeMapNode> rows = rowset.getYTreeRows(true);
            // no result rows means there is nothing to count
            return rows.isEmpty() ? 0 : rows.get(0).getInt("neighbours_count");
        }).exceptionally(ex -> {
            // dependent stages usually see the failure wrapped; fall back to ex itself if there is no cause
            Throwable cause = ex.getCause() != null ? ex.getCause() : ex;
            if (cause instanceof RpcError) {
                LOG.warn("Too large vertex is encountered. {} has more than {} adjacent edges", vertex,
                        YT_INPUT_ROW_LIMIT);
                return YT_INPUT_ROW_LIMIT;
            }
            throw new CompletionException(cause);
        });
    }

    /** Logs total and per-call elapsed time since {@code start} for {@code verticesCount} reads. */
    private void reportLevelTime(Instant start, String levelN, int verticesCount) {
        long elapsedTime = Duration.between(start, Instant.now()).toMillis();
        // avoid logging NaN/Infinity when nothing was read
        double perCall = verticesCount == 0 ? 0.0 : (double) elapsedTime / verticesCount;
        LOG.info(
                "Read {} times at {} took {}ms, {}ms per call",
                verticesCount,
                levelN,
                elapsedTime,
                perCall
        );
    }

    @Override
    public SoupGraphSettings getGraphSettings() {
        return graphSettings;
    }

    /**
     * Builds the graph around {@code id} using the table and edges limit from the graph settings.
     * Depth defaults to 1 when not set in {@code params}.
     */
    @Override
    public Optional<Graph> getById(GraphId id, SearchParams params, InfoParams infoParams) {
        YPath table = graphSettings.getSoupTable(params.getMatchingScope());
        int nearestEdgesLimit = graphSettings.getEdgesLimit();
        int depth = params.getDepth() == null ? 1 : params.getDepth();
        return getById(id, table, nearestEdgesLimit, depth);
    }

    /**
     * Builds the graph around {@code id} by reading adjacent edges level by level.
     * <p>
     * All vertices of a level are queried concurrently; the next level starts only after the
     * current one has completed. A vertex with at least {@code nearestEdgesLimit} adjacent rows is
     * not expanded; its neighbour count is stored as the vertex overlimit count instead.
     *
     * @param id                root vertex id
     * @param table             soup table to query
     * @param nearestEdgesLimit per-vertex row limit
     * @param depth             number of levels to read
     * @return a graph consisting of a single component with all collected vertices and edges
     */
    public Optional<Graph> getById(GraphId id, YPath table, int nearestEdgesLimit, int depth) {
        Instant startTotal = Instant.now();

        YtClient client = yt.getHahnRpc();

        // vertex -> overlimit size. Using map instead of set to atomically update overlimit size when it's found
        Map<Vertex, Integer> vertexSet = new ConcurrentHashMap<>();
        Set<Edge> edgeSet = ConcurrentHashMap.newKeySet();

        Set<Vertex> currentLevel = ConcurrentHashMap.newKeySet();

        Vertex rootVertex = new Vertex(id.getIdValue(), id.getIdType());
        currentLevel.add(rootVertex);
        vertexSet.put(rootVertex, 0);

        int totalReads = 0;

        for (int levelN = 1; levelN <= depth; levelN++) {
            if (currentLevel.isEmpty()) {
                // nothing left to expand, deeper levels would be empty as well
                break;
            }

            Set<Vertex> nextLevel = ConcurrentHashMap.newKeySet();
            int readsN = currentLevel.size();
            LOG.info("Reading level {}(size={}) for vertex {}({})",
                    levelN, readsN,
                    rootVertex.getIdValue(), rootVertex.getIdType()
            );

            Instant start = Instant.now();

            List<CompletableFuture<Void>> futures = new ArrayList<>(readsN);
            for (Vertex vertex : currentLevel) {
                futures.add(expandVertex(client, table, vertex, nearestEdgesLimit, vertexSet, edgeSet, nextLevel));
            }

            // wait for a level to execute
            awaitAll(futures);

            reportLevelTime(start, levelN + " level", readsN);
            totalReads += readsN;

            currentLevel = nextLevel;
        }

        reportLevelTime(startTotal, "all levels", totalReads);

        List<Vertex> verticesWithOverlimitSize = vertexSet.entrySet().stream()
                .map(YtSoupGraphService::withOverlimitCount)
                .collect(toList());

        List<GraphComponent> singleComponent = new ArrayList<>();
        singleComponent.add(new GraphComponent(
                null,
                verticesWithOverlimitSize,
                new ArrayList<>(edgeSet)
        ));

        return Optional.of(new Graph(singleComponent));
    }

    /**
     * Reads rows adjacent to {@code vertex} and either registers its children and edges, or,
     * when the row limit is reached, records the vertex neighbour count as its overlimit size.
     * The counting query is chained with {@code thenCompose} so no callback blocks on another future.
     */
    private CompletableFuture<Void> expandVertex(YtClient client, YPath table, Vertex vertex, int nearestEdgesLimit,
                                                 Map<Vertex, Integer> vertexSet, Set<Edge> edgeSet,
                                                 Set<Vertex> nextLevel) {
        return selectRowsForVertex(client, table, vertex, nearestEdgesLimit).thenCompose(rowSet -> {
            List<YTreeMapNode> recs = rowSet.getYTreeRows(true);

            // Only read adjacent vertices if their number is less than limit
            if (recs.size() >= nearestEdgesLimit) {
                return selectRowsCountForVertex(client, table, vertex)
                        .thenAccept(neighboursCount -> vertexSet.put(vertex, neighboursCount));
            }

            for (YTreeMapNode rec : recs) {
                Vertex child = new Vertex(
                        rec.getString("id2"),
                        rec.getString("id2Type")
                );

                // single atomic operation: only a vertex seen for the first time goes to the next level
                if (vertexSet.putIfAbsent(child, 0) == null) {
                    nextLevel.add(child);
                }

                if (vertexSet.size() > VERTICES_SAFE_DRAW_LIMIT) {
                    // to avoid OOM
                    throw Exceptions.illegal("Soup area is too large, can't draw");
                }

                edgeSet.add(mapSoupRecToEdge(rec));
            }
            return CompletableFuture.<Void>completedFuture(null);
        });
    }

    /**
     * Blocks until all given futures complete; rethrows the underlying failure as unchecked.
     * Restores the interrupt flag if the waiting thread is interrupted.
     */
    private void awaitAll(List<CompletableFuture<Void>> futures) {
        CompletableFuture<Void> combinedFuture =
                CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]));
        try {
            combinedFuture.get();
        } catch (InterruptedException e) {
            // InterruptedException has no cause; keep the interrupt status for callers up the stack
            Thread.currentThread().interrupt();
            LOG.error("Can't get soup edges", e);
            throw Exceptions.unchecked(e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            LOG.error("Can't get soup edges", cause);
            throw Exceptions.unchecked(cause);
        }
    }

}
