package ru.yandex.solomon.metrics.client.cache;

import java.time.Clock;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.config.OptionalSet;
import ru.yandex.solomon.labels.LabelKeys;
import ru.yandex.solomon.labels.query.Selectors;
import ru.yandex.solomon.labels.query.SelectorsBuilder;
import ru.yandex.solomon.labels.query.ShardSelectors;
import ru.yandex.solomon.metrics.client.CrossDcResponseMerger;
import ru.yandex.solomon.metrics.client.CrossShardResponseMerger;
import ru.yandex.solomon.metrics.client.FindRequest;
import ru.yandex.solomon.metrics.client.FindResponse;
import ru.yandex.solomon.metrics.client.MetricsClient;
import ru.yandex.solomon.staffOnly.manager.special.DurationMillis;
import ru.yandex.solomon.util.collection.Nullables;

import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * Caching front-end for metabase {@code find} requests executed through {@link MetricsClient}.
 * <p>
 * A user request is fanned out to every destination reported by the client and, inside each
 * destination, to every shard matched by the shard part of the selectors. Each per-shard request is
 * identified by a {@link FindRequestKey} and its successful response is kept in {@link #cachedResponses}.
 * A cached response that {@link TimeExpiredItem#isExpired} reports as expired is still returned to the
 * caller, while a refresh is started in the background. Concurrent requests for the same key share a
 * single in-flight future (see {@link #activeRequests}).
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public class MetabaseFindCacheImpl implements MetabaseFindCache {
    private final MetricsClient metricsClient;
    private final MetricRegistry metricRegistry;
    /** Value of {@code options.getRefreshInterval()} in milliseconds; handed to every {@link TimeExpiredItem}. */
    @DurationMillis
    private final long refreshIntervalMillis;

    /** Time source used for the staleness checks of cached responses. */
    private final Clock clock;
    /** Optional hook (may be {@code null}) notified when a request started by this cache completes. */
    private final Consumer<FindResponse> onComplete;

    /** In-flight per-shard requests; an entry is removed once its request completes. */
    private final ConcurrentMap<FindRequestKey, CompletableFuture<FindResponse>> activeRequests = new ConcurrentHashMap<>();

    /** Successful per-shard responses; entries expire after {@code options.getExpireTtl()} without access. */
    private final Cache<FindRequestKey, TimeExpiredItem<FindResponse>> cachedResponses;
    /**
     * Shard labels resolved per (destination, shard selector); entries expire
     * {@code options.getExpireTtl()} after they were written.
     * NOTE(review): unlike {@link #cachedResponses} this cache has no size bound - confirm that is intended.
     */
    private final Cache<ShardBySelectorKey, List<Labels>> shardsCache;

    /** Histogram of the total time of {@link #find} calls, in milliseconds. */
    private final Histogram responseTime;

    /** Lazily created per-destination refresh metrics. */
    private final ConcurrentMap<String, RefreshMetrics> refreshMetricsByDc = new ConcurrentHashMap<>();

    /**
     * Metrics of the requests sent to one destination.
     * Kept as an inner (non-static) class because it registers itself in the enclosing {@link #metricRegistry}.
     */
    private class RefreshMetrics {
        /** Duration of a single request to the destination, in milliseconds. */
        final Histogram refreshTime;

        RefreshMetrics(String dc) {
            refreshTime = metricRegistry.histogramRate("metabaseCache.refreshTimeMillis", Labels.of("target", dc),
                Histograms.exponential(12, 2, 16));
        }
    }

    /** Lazily created shard fan-out metrics per (destination, project). */
    private final ConcurrentMap<DcProjectKey, CrossShardMetrics> crossShardMetricsByDcAndProject = new ConcurrentHashMap<>();

    /** Key of {@link #crossShardMetricsByDcAndProject}. */
    private record DcProjectKey(String dc, String project) {
    }

    /**
     * Metrics of the shard fan-out for one (destination, project) pair.
     * Kept as an inner (non-static) class because it registers itself in the enclosing {@link #metricRegistry}.
     */
    private class CrossShardMetrics {
        /** Number of shards a single find request was split into. */
        final Histogram shardCount;

        CrossShardMetrics(DcProjectKey key) {
            shardCount = metricRegistry.histogramRate("metabaseCache.crossShardCount",
                    Labels.of("target", key.dc(), "projectId", key.project()),
                    Histograms.exponential(12, 2, 1));
        }
    }

    /**
     * Creates a cache that reports its metrics into a private, newly created registry.
     *
     * @param metricsClient client used to resolve shards and to execute find requests
     * @param options       cache settings (refresh interval, expire ttl, max size)
     */
    public MetabaseFindCacheImpl(MetricsClient metricsClient, FindCacheOptions options) {
        this(metricsClient, new MetricRegistry(), options);
    }

    /**
     * Creates a cache that uses the system UTC clock, the default ticker and no completion hook.
     *
     * @param metricsClient  client used to resolve shards and to execute find requests
     * @param metricRegistry registry that receives the cache metrics
     * @param options        cache settings (refresh interval, expire ttl, max size)
     */
    public MetabaseFindCacheImpl(MetricsClient metricsClient, MetricRegistry metricRegistry, FindCacheOptions options) {
        this(metricsClient, metricRegistry, options, Clock.systemUTC(), null, null);
    }

    /**
     * Full constructor; the time sources and the completion hook are injectable for tests.
     *
     * @param metricsClient  client used to resolve shards and to execute find requests
     * @param metricRegistry registry that receives the cache metrics
     * @param options        cache settings (refresh interval, expire ttl, max size)
     * @param clock          time source for the staleness checks of cached responses
     * @param ticker         ticker for both Guava caches, {@code null} to use the default one
     * @param onComplete     hook notified when a request started by this cache completes,
     *                       {@code null} to disable; receives {@code null} when the request failed
     */
    @VisibleForTesting
    public MetabaseFindCacheImpl(
            MetricsClient metricsClient,
            MetricRegistry metricRegistry,
            FindCacheOptions options,
            Clock clock,
            @Nullable Ticker ticker,
            @Nullable Consumer<FindResponse> onComplete)
    {
        this.metricsClient = metricsClient;
        this.metricRegistry = metricRegistry;
        this.refreshIntervalMillis = options.getRefreshInterval().toMillis();
        this.clock = clock;
        var cacheBuilder = CacheBuilder.newBuilder()
            .recordStats()
            .expireAfterAccess(options.getExpireTtl().toMillis(), TimeUnit.MILLISECONDS);
        if (ticker != null) {
            cacheBuilder.ticker(ticker);
        }
        // presumably applies the size limit only when it is configured - see OptionalSet
        OptionalSet.setLong(cacheBuilder::maximumSize, options.getMaxSize());
        this.cachedResponses = cacheBuilder.build();
        this.shardsCache = CacheBuilder.newBuilder()
                .expireAfterWrite(options.getExpireTtl().toMillis(), TimeUnit.MILLISECONDS)
                .ticker(Nullables.orDefault(ticker, Ticker.systemTicker()))
                .build();
        this.onComplete = onComplete;

        responseTime = metricRegistry.histogramRate("metabaseCache.findTimeMillis", Histograms.exponential(12, 2, 1));

        // Guava statistics of the response cache plus the current number of in-flight requests
        metricRegistry.lazyRate("metabaseCache.stats.missCount", () -> cachedResponses.stats().missCount());
        metricRegistry.lazyRate("metabaseCache.stats.hitCount", () -> cachedResponses.stats().hitCount());
        metricRegistry.lazyRate("metabaseCache.stats.evictionCount", () -> cachedResponses.stats().evictionCount());
        metricRegistry.lazyGaugeInt64("metabaseCache.stats.size", cachedResponses::size);
        metricRegistry.lazyGaugeInt64("metabaseCache.stats.inflight", activeRequests::size);
    }

    /** Key of {@link #shardsCache}: destination plus the shard part of the selectors. */
    private record ShardBySelectorKey(String dc, Selectors shardSelectors) {
    }

    /** Identity of a single per-shard find request: destination, full selectors and result limit. */
    public record FindRequestKey(String dc, Selectors selectors, int limit) {
    }

    /**
     * Finds metrics matching the selectors in all destinations, serving from the cache when possible.
     * The wall-clock duration of the call is recorded into {@link #responseTime} whatever the outcome.
     *
     * @param selectors          metric selectors, including the shard selecting part
     * @param limit              maximum number of metrics, part of the cache key
     * @param softDeadlineMillis soft deadline forwarded to the underlying find requests
     * @param deadlineMillis     deadline forwarded to the underlying find requests
     * @return future with the response merged across shards and destinations
     */
    public CompletableFuture<FindResponse> find(Selectors selectors, int limit, long softDeadlineMillis, long deadlineMillis) {
        long startMillis = System.currentTimeMillis();
        var future = findCrossDc(selectors, limit, softDeadlineMillis, deadlineMillis);
        return future.whenComplete((ignore, e) -> responseTime.record(System.currentTimeMillis() - startMillis));
    }

    /**
     * Runs the request against every destination and merges the results.
     * <p>
     * When at least one destination has already produced an OK response by the time the futures are
     * inspected (e.g. everything was served from the cache), only those ready responses are merged and
     * returned immediately; the remaining requests keep running and only warm the cache. Otherwise the
     * result is built from the responses of all destinations once they are available.
     */
    private CompletableFuture<FindResponse> findCrossDc(Selectors selectors, int limit, long softDeadlineMillis, long deadlineMillis) {
        List<CompletableFuture<FindResponse>> futures = metricsClient.getDestinations().stream()
                .map(dc -> findCrossShard(dc, selectors, limit, softDeadlineMillis, deadlineMillis))
                .collect(Collectors.toList());

        var readyResponses = futures.stream()
                .filter(future -> future.isDone() && !future.isCompletedExceptionally())
                // safe: the future is known to be completed normally at this point
                .map(future -> future.getNow(null))
                .filter(FindResponse::isOk)
                .collect(Collectors.toList());

        if (readyResponses.isEmpty()) {
            return CompletableFutures.allOf(futures).thenApply(responses -> mergeCrossDc(responses, limit));
        } else {
            return completedFuture(mergeCrossDc(readyResponses, limit));
        }
    }

    /**
     * Runs the request against every shard of one destination and merges the per-shard responses.
     * <p>
     * For each shard: a cache miss starts (or joins) a request and waits for it; a cached response is
     * returned right away, and if it is expired a refresh is started without waiting for it.
     * Any exception thrown synchronously is reported through a failed future instead of being thrown.
     */
    private CompletableFuture<FindResponse> findCrossShard(String dc, Selectors selectors, int limit, long softDeadlineMillis, long deadlineMillis) {
        try {
            var split = ShardSelectors.split(selectors);
            var shardSelector = split.shardSelector();
            var projectSelector = shardSelector.findByKey(LabelKeys.PROJECT);
            // the project is used only as a metric label; non-exact selectors are reported as UNKNOWN
            String project = "UNKNOWN";
            if (projectSelector != null && projectSelector.isExact()) {
                project = projectSelector.getValue();
            }
            var shards = metabaseShards(dc, shardSelector);

            crossShardMetricsByDcAndProject.computeIfAbsent(new DcProjectKey(dc, project), CrossShardMetrics::new)
                    .shardCount.record(shards.size());
            crossShardMetricsByDcAndProject.computeIfAbsent(new DcProjectKey(dc, "total"), CrossShardMetrics::new)
                    .shardCount.record(shards.size());

            return shards.stream()
                    .map(shardKey -> {
                        final FindRequestKey key = new FindRequestKey(dc, addLabels(split.metricsSelector(), shardKey), limit);
                        TimeExpiredItem<FindResponse> cachedResponse = cachedResponses.getIfPresent(key);
                        if (cachedResponse == null) {
                            return startFind(key, softDeadlineMillis, deadlineMillis);
                        }

                        if (cachedResponse.isExpired(clock.millis())) {
                            // avoid wait, warm cache only
                            startFind(key, softDeadlineMillis, deadlineMillis);
                        }

                        return completedFuture(cachedResponse.getPayload());
                    }).collect(Collectors.collectingAndThen(Collectors.toList(), CompletableFutures::allOf))
                    .thenApply(response -> CrossShardResponseMerger.mergeFindResponses(response, limit));
        } catch (Throwable e) {
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * Returns the shard labels matching the shard selector in the destination, resolving them through
     * the client on a cache miss. Concurrent misses may resolve the same key more than once; the last
     * result written wins.
     */
    private List<Labels> metabaseShards(String destination, Selectors shardSelector) {
        final var key = new ShardBySelectorKey(destination, shardSelector);
        var shards = shardsCache.getIfPresent(key);
        if (shards == null) {
            shards = metricsClient.metabaseShards(destination, shardSelector).collect(Collectors.toList());
            shardsCache.put(key, shards);
        }
        return shards;
    }

    /**
     * Builds selectors consisting of the given labels (added first, by key and value) followed by the
     * given selectors. Returns the selectors unchanged when there are no labels to add.
     */
    private Selectors addLabels(Selectors selectors, Labels labels) {
        if (labels.isEmpty()) {
            return selectors;
        }

        SelectorsBuilder builder = Selectors.builder(selectors.size() + labels.size());
        labels.forEach(l -> builder.add(l.getKey(), l.getValue()));
        builder.addAll(selectors);
        return builder.build();
    }

    /**
     * Starts a find request for the key, or joins the one that is already in flight.
     * <p>
     * Only the call that actually created the request registers the completion callback, so the entry
     * is removed from {@link #activeRequests} and {@link #onComplete} is notified exactly once per
     * request, even when several threads race to start the same key.
     *
     * @return future of the response; shared with every other caller of the same in-flight request
     */
    private CompletableFuture<FindResponse> startFind(FindRequestKey findRequestKey, long softDeadline, long deadline) {
        var prev = activeRequests.get(findRequestKey);
        if (prev != null) {
            return prev;
        }

        // Flipped by the mapping function, which computeIfAbsent invokes synchronously on this thread
        // and at most once per absent key; tells the creator apart from a caller that lost the race.
        final boolean[] created = {false};
        final var future = activeRequests.computeIfAbsent(findRequestKey, keyCreate -> {
            created[0] = true;
            return makeFindRequest(keyCreate, softDeadline, deadline);
        });
        if (!created[0]) {
            // another thread registered the request between get() and computeIfAbsent(): just join it
            return future;
        }

        return future.whenComplete((response, e) -> {
            // two-argument remove: never drops a newer request registered for the same key
            activeRequests.remove(findRequestKey, future);
            if (onComplete != null) {
                // response is null when the request completed exceptionally
                onComplete.accept(response);
            }
        });
    }

    /**
     * Sends the find request described by the key to the client.
     * The request duration is recorded into the per-destination refresh histogram whatever the outcome;
     * an OK response is stored in {@link #cachedResponses}, other responses are returned without caching.
     */
    private CompletableFuture<FindResponse> makeFindRequest(FindRequestKey key, long softDeadline, long deadlineMillis) {
        long startMillis = System.currentTimeMillis();
        return metricsClient.find(FindRequest.newBuilder()
                .setDestination(key.dc())
                .setSelectors(key.selectors())
                .setLimit(key.limit())
                .setSoftDeadline(softDeadline)
                .setDeadline(deadlineMillis)
                .build()
            )
            .whenComplete((ignore, e) -> {
                RefreshMetrics refreshMetrics = refreshMetricsByDc.computeIfAbsent(key.dc(), RefreshMetrics::new);
                refreshMetrics.refreshTime.record(System.currentTimeMillis() - startMillis);
            })
            .thenApply(response -> {
                if (response.isOk()) {
                    cachedResponses.put(key, new TimeExpiredItem<>(response, refreshIntervalMillis, clock.millis()));
                }
                return response;
            });
    }

    /**
     * Merges per-destination responses into one. The merger is given a synthetic request that carries
     * only the limit (with empty selectors) and the total number of destinations known to the client.
     */
    private FindResponse mergeCrossDc(List<FindResponse> responses, int limit) {
        FindRequest request = FindRequest.newBuilder()
            .setSelectors(Selectors.of())
            .setLimit(limit)
            .build();
        return CrossDcResponseMerger.mergeFindResponses(request, responses, metricsClient.getDestinations().size());
    }

    /** Returns the number of requests that are currently in flight. */
    @VisibleForTesting
    public long activeCount() {
        return activeRequests.size();
    }

    /** Exposes the underlying response cache for inspection in tests. */
    @VisibleForTesting
    public Cache<FindRequestKey, TimeExpiredItem<FindResponse>> getCache() {
        return cachedResponses;
    }
}
