package ru.yandex.solomon.gateway.api.cloud.v2;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import ru.yandex.monlib.metrics.registry.MetricId;
import ru.yandex.solomon.common.RequestProducer;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.metrics.client.FindRequest;
import ru.yandex.solomon.metrics.client.FindResponse;
import ru.yandex.solomon.metrics.client.MetabaseClientException;
import ru.yandex.solomon.metrics.client.MetabaseStatus;
import ru.yandex.solomon.metrics.client.MetricsClient;
import ru.yandex.solomon.metrics.client.cache.TimeExpiredItem;
import ru.yandex.solomon.model.MetricKey;

import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * Loads the complete list of metric keys matching a selector from every destination reported
 * by the {@link MetricsClient}, reading each destination page by page.
 *
 * <p>Results are cached per (destination, selectors) pair. Concurrent requests for the same
 * pair share a single in-flight load, and results from different destinations are merged into
 * one list keyed by metric name and labels.
 *
 * @author Vladimir Gordiychuk
 */
public class PagedMetaLoader {
    /** Freshness period handed to every {@link TimeExpiredItem} stored in the cache. */
    private static final long EXPIRE_TIME_MILLIS = TimeUnit.MINUTES.toMillis(10);
    /** Cache entries are dropped when they were not accessed for this long. */
    private static final long REMOVE_AFTER_ACCESS_MILLIS = TimeUnit.MINUTES.toMillis(15L);

    /** Page size (and offset step) used for every find request. */
    private static final int METRICS_PAGE_LIMIT = 10000;
    private final MetricsClient client;
    /** How long to keep waiting for the remaining destinations once one of them has answered. */
    private volatile long slowResponseAwaitMillis = 3_000;
    private final Cache<CacheKey, TimeExpiredItem<List<MetricKey>>> cached;
    /** In-flight loads, used to avoid issuing the same load twice at the same time. */
    private final ConcurrentMap<CacheKey, CompletableFuture<List<MetricKey>>> activeFind = new ConcurrentHashMap<>();

    /**
     * @param client client used to list destinations and to execute paged find requests
     */
    public PagedMetaLoader(MetricsClient client) {
        this.client = client;
        this.cached = CacheBuilder.newBuilder()
                .expireAfterAccess(REMOVE_AFTER_ACCESS_MILLIS, TimeUnit.MILLISECONDS)
                .build();
    }

    /**
     * Finds all metric keys matching {@code selectors} across all destinations.
     *
     * <p>If at least one destination already has a result (a cached one, or a load that has
     * already finished successfully) the merged available results are returned immediately.
     * Otherwise the call waits according to {@link #anyOrAllOf(List, long)}.
     *
     * @param selectors selectors to search by
     * @param deadline deadline passed to every find request started by this call
     * @param producer request producer passed to every find request started by this call
     * @return future with metric keys merged across destinations
     */
    public CompletableFuture<List<MetricKey>> find(Selectors selectors, Instant deadline, RequestProducer producer) {
        var futures = client.getDestinations().stream()
                .map(dest -> {
                    var key = new CacheKey(dest, selectors);
                    var cachedResponse = cached.getIfPresent(key);
                    if (cachedResponse == null) {
                        return dcLoad(key, deadline, producer);
                    }

                    if (cachedResponse.isExpired(System.currentTimeMillis())) {
                        // avoid wait, warm cache only
                        dcLoad(key, deadline, producer);
                    }

                    return completedFuture(cachedResponse.getPayload());
                })
                .collect(Collectors.toList());

        var ready = futures.stream()
                .filter(future -> future.isDone() && !future.isCompletedExceptionally())
                .map(future -> future.getNow(null))
                .collect(Collectors.toList());

        // don't wait second replica if it's not ready yet
        if (!ready.isEmpty()) {
            return completedFuture(mergeCrossDcMetrics(ready));
        }

        return anyOrAllOf(futures, slowResponseAwaitMillis)
                .thenApply(this::mergeCrossDcMetrics);
    }

    /**
     * Starts loading metric keys for one destination, or joins a load that is already running
     * for the same key. A successful result is put into the cache, and the in-flight entry is
     * removed once the load finishes either way.
     */
    private CompletableFuture<List<MetricKey>> dcLoad(CacheKey key, Instant deadline, RequestProducer producer) {
        // cheap read first: skip computeIfAbsent when the load is already running
        var prev = activeFind.get(key);
        if (prev != null) {
            return prev;
        }

        var future = activeFind.computeIfAbsent(key, k -> new DcLoader(k.selectors, k.dest, deadline, producer).load());
        return future.whenComplete((response, e) -> {
            if (response != null) {
                cached.put(key, new TimeExpiredItem<>(response, EXPIRE_TIME_MILLIS, System.currentTimeMillis()));
            }

            // two-argument remove: only drop the mapping if it still points at this load
            activeFind.remove(key, future);
        });
    }

    /**
     * Waits for the given futures with a grace period for slow ones.
     *
     * <ul>
     *   <li>As soon as one future completes successfully while others are still pending, the
     *       rest get at most {@code maxAwaitTimeMillis} to complete; after that whatever has
     *       succeeded so far is returned.</li>
     *   <li>A failed future does not start the grace period: the remaining futures are awaited,
     *       because one of them may still succeed.</li>
     *   <li>If every future failed, the returned future fails with the failure of the first
     *       future in the list.</li>
     * </ul>
     *
     * @param futures futures to wait for
     * @param maxAwaitTimeMillis grace period for pending futures after the first success
     * @return future with the results of all futures that completed successfully in time
     */
    static CompletableFuture<List<List<MetricKey>>> anyOrAllOf(List<CompletableFuture<List<MetricKey>>> futures, long maxAwaitTimeMillis) {
        if (futures.isEmpty()) {
            return CompletableFuture.completedFuture(List.of());
        }

        if (futures.size() == 1) {
            return futures.get(0).thenApply(keys -> List.of(keys));
        }

        CompletableFuture<Void> doneFuture = new CompletableFuture<>();
        AtomicInteger done = new AtomicInteger(futures.size());
        for (var future : futures) {
            future.whenComplete((ignore, e) -> {
                if (done.decrementAndGet() != 0) {
                    // Others are still pending. Only a successful response is worth returning
                    // early, so the grace timer starts on success; after a failure keep
                    // waiting for the remaining futures instead.
                    if (e == null) {
                        doneFuture.completeOnTimeout(null, maxAwaitTimeMillis, TimeUnit.MILLISECONDS);
                    }
                } else {
                    // the last future finished, nothing left to wait for
                    doneFuture.complete(null);
                }
            });
        }
        return doneFuture.thenCompose(ignore -> {
            List<List<MetricKey>> result = new ArrayList<>();
            boolean anySuccess = false;
            for (var future : futures) {
                if (future.isDone() && !future.isCompletedExceptionally()) {
                    anySuccess = true;
                    result.add(future.getNow(List.of()));
                }
            }

            if (anySuccess) {
                return CompletableFuture.completedFuture(result);
            }

            // nothing succeeded: propagate the failure of the first future
            return futures.get(0).thenApply(keys -> List.of(keys));
        });
    }


    /**
     * Merges per-destination responses into a single list. Keys that share the same name and
     * labels are combined via {@link MetricKey#combine}; the order of the merged list is not
     * defined because it is taken from a {@link HashMap}.
     *
     * <p>A single response is returned as is, without copying.
     */
    private List<MetricKey> mergeCrossDcMetrics(List<List<MetricKey>> responses) {
        if (responses.isEmpty()) {
            return List.of();
        }

        if (responses.size() == 1) {
            return responses.get(0);
        }

        int approximateSize = responses.get(0).size();
        Map<MetricId, MetricKey> result = new HashMap<>(approximateSize);

        for (var response : responses) {
            for (MetricKey metricKey : response) {
                MetricId metricId = new MetricId(metricKey.getName(), metricKey.getLabels());
                result.merge(metricId, metricKey, (prev, next) -> prev.combine(next));
            }
        }

        return List.copyOf(result.values());
    }

    /**
     * Loads all pages of a find request for one destination, accumulating the metrics and
     * completing {@link #doneFuture} once the last page arrives or any page fails.
     */
    private class DcLoader {
        private final Selectors selectors;
        private final String dest;
        private final Instant deadline;
        private final RequestProducer producer;
        private final CompletableFuture<List<MetricKey>> doneFuture = new CompletableFuture<>();
        private final List<MetricKey> result = new ArrayList<>();
        /** Offset of the next page to request. */
        private int offset;

        private DcLoader(Selectors selectors, String dest, Instant deadline, RequestProducer producer) {
            this.selectors = selectors;
            this.dest = dest;
            this.deadline = deadline;
            this.producer = producer;
        }

        /**
         * Requests the first page and returns the future that is completed when the whole
         * load is finished.
         */
        public CompletableFuture<List<MetricKey>> load() {
            loadNextPage();
            return doneFuture;
        }

        /** Requests the page starting at the current {@link #offset}. */
        private void loadNextPage() {
            FindRequest req = FindRequest.newBuilder()
                    .setSelectors(selectors)
                    .setUseNewFormat(true)
                    .setDestination(dest)
                    .setLimit(METRICS_PAGE_LIMIT)
                    .setOffset(offset)
                    .setDeadline(deadline)
                    .setProducer(producer)
                    .build();

            client.find(req).whenComplete(this::onPageLoad);
        }

        /**
         * Handles one page: fails the load on a transport error or non-OK status, otherwise
         * accumulates the metrics and either requests the next page or completes the load.
         */
        private void onPageLoad(FindResponse resp, @Nullable Throwable e) {
            if (e != null) {
                doneFuture.completeExceptionally(e);
                return;
            }

            try {
                if (resp.getStatus() != MetabaseStatus.OK) {
                    doneFuture.completeExceptionally(new MetabaseClientException(resp.getStatus()));
                    return;
                }

                result.addAll(resp.getMetrics());
                offset += METRICS_PAGE_LIMIT;
                if (resp.isTruncated() && offset < resp.getTotalCount()) {
                    loadNextPage();
                } else {
                    doneFuture.complete(result);
                }
            } catch (Throwable t) {
                // This method runs as a whenComplete callback, so anything thrown here would
                // end up in a future nobody observes, leaving doneFuture (and the activeFind
                // entry that points at it) pending forever. Fail the load explicitly instead.
                doneFuture.completeExceptionally(t);
            }
        }
    }

    /** Identifies a cached or in-flight load: one destination plus the selectors searched. */
    private record CacheKey(String dest, Selectors selectors) {
    }
}
