package ru.yandex.market.graphouse.stockpile.proxy;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.model.protobuf.MetricType;
import ru.yandex.solomon.model.timeseries.AggrGraphDataIterable;
import ru.yandex.solomon.model.timeseries.MergingAggrGraphDataIterable;
import ru.yandex.solomon.selfmon.AvailabilityStatus;
import ru.yandex.solomon.staffOnly.annotations.ManagerMethod;
import ru.yandex.solomon.staffOnly.annotations.ManagerMethodArgument;
import ru.yandex.stockpile.api.MetricMeta;
import ru.yandex.stockpile.client.shard.StockpileLocalId;
import ru.yandex.stockpile.client.shard.StockpileMetricId;
import ru.yandex.stockpile.client.writeRequest.StockpileShardWriteRequest;

import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;

/**
 * @author Vladimir Gordiychuk
 */
public class CrossDcStockpileClient implements GraphiteStockpileClient {
    private GraphiteStockpileClient master;
    private GraphiteStockpileClient slave;

    private volatile boolean crossDcRead = true;
    private volatile long slowResponseAwaitMillis = 3_000;

    public CrossDcStockpileClient(GraphiteStockpileClient master, GraphiteStockpileClient slave) {
        this.master = master;
        this.slave = slave;
    }

    @Override
    public StockpileMetricId generateMetricId() {
        int shardId = ThreadLocalRandom.current().nextInt(1, getTotalShardsCount() + 1);
        long localId = StockpileLocalId.random();
        return new StockpileMetricId(shardId, localId);
    }

    @Override
    public int getTotalShardsCount() {
        return Math.min(master.getTotalShardsCount(), slave.getTotalShardsCount());
    }

    @Override
    public boolean isFullyReady() {
        return master.isFullyReady();
    }

    @Override
    public CompletableFuture<ReadResponse> readOne(ReadRequest request) {
        if (!crossDcRead) {
            return master.readOne(request);
        }

        var key = request.getKey();
        return Stream.of(master, slave)
            .map(client -> client.readOne(request))
            .collect(Collectors.collectingAndThen(toList(), list -> anyOrAllOf(list, slowResponseAwaitMillis)))
            .thenApply(responses -> {
                if (responses.size() == 1) {
                    return responses.get(0);
                }

                if (responses.stream().noneMatch(ReadResponse::isOk)) {
                    return responses.get(0);
                }

                AggrGraphDataIterable merged = responses.stream()
                    .filter(ReadResponse::isOk)
                    .map(ReadResponse::getSource)
                    .collect(collectingAndThen(toList(), MergingAggrGraphDataIterable::of));

                MetricType dataType = responses.stream()
                    .map(ReadResponse::getDataType)
                    .filter(k -> k != MetricType.METRIC_TYPE_UNSPECIFIED)
                    .findFirst()
                    .orElse(MetricType.METRIC_TYPE_UNSPECIFIED);

                return new ReadResponse(key, dataType, merged);
            });
    }

    @Override
    public AvailabilityStatus getStatus() {
        return master.getStatus();
    }

    @Override
    public CompletableFuture<Void> writeData(int shardId, StockpileShardWriteRequest logEntry) {
        return master.writeData(shardId, logEntry);
    }

    @Override
    public CompletableFuture<List<MetricMeta>> readMetricsMeta(int shardId, long[] localIds) {
        return master.readMetricsMeta(shardId, localIds);
    }

    static CompletableFuture<List<ReadResponse>> anyOrAllOf(
        List<? extends CompletableFuture<ReadResponse>> futures,
        long maxAwaitTimeMillis)
    {
        if (futures.isEmpty()) {
            return CompletableFuture.completedFuture(List.of());
        }

        if (futures.size() == 1) {
            return futures.get(0).thenApply(List::of);
        }

        @SuppressWarnings("unchecked")
        CompletableFuture<ReadResponse>[] futuresArray = futures.toArray(new CompletableFuture[0]);
        CompletableFuture<List<ReadResponse>> allOnFuture = CompletableFutures.allOf(futures).thenApply(list -> list);
        return CompletableFuture.anyOf(futuresArray)
            .thenCompose(ignore -> {
                List<ReadResponse> completed = new ArrayList<>(futuresArray.length);
                boolean anySuccess = false;
                for (CompletableFuture<ReadResponse> future : futuresArray) {
                    if (future.isDone() && !future.isCompletedExceptionally()) {
                        ReadResponse response = future.getNow(null);
                        completed.add(response);
                        anySuccess |= response.isOk();
                    }
                }

                if (anySuccess) {
                    // same cluster should return response with the same time, 1 second should be
                    // enough to smooth gc, network overhead etc. Don't care at this place about
                    // remaining time for requests, because each request send with own deadline.
                    return allOnFuture.completeOnTimeout(completed, maxAwaitTimeMillis, TimeUnit.MILLISECONDS);
                } else {
                    return allOnFuture;
                }
            });
    }

    @ManagerMethod
    public void setSlowResponseAwaitMillis(@ManagerMethodArgument(name = "timeMillis") long timeMillis) {
        this.slowResponseAwaitMillis = timeMillis;
    }

    @ManagerMethod
    public void setCrossDcRead(boolean enable) {
        this.crossDcRead = enable;
    }
}
