package ru.yandex.solomon.name.resolver.client;

import java.time.Clock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.WillCloseWhenClosed;

import ru.yandex.misc.concurrent.CompletableFutures;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;


/**
 * {@link NameResolverClient} that fans every call out to all destination (DC) clients and
 * combines their answers.
 *
 * <p>Each call waits until at least one destination answers successfully, then gives the
 * remaining destinations a short grace period (see {@code anyOrAllOf}) before combining whatever
 * has succeeded by then. Destinations that fail or lag behind are dropped as long as at least one
 * destination succeeded.
 *
 * @author Alexey Trushkin
 */
@ParametersAreNonnullByDefault
public class CrossDcNameResolverClient implements NameResolverClient {

    /** Grace period for lagging destinations when no soft deadline is given. */
    private static final long SLOW_RESPONSE_AWAIT_MILLIS = 3_000;
    /** Upper bound on the grace period when a soft deadline is given. */
    private static final long SLOW_RESPONSE_AWAIT_MAX_LIMIT_MILLIS = 60_000;

    private final Clock clock;
    @WillCloseWhenClosed
    private final Map<String, NameResolverClient> clientByDestination;

    public CrossDcNameResolverClient(@WillCloseWhenClosed Map<String, NameResolverClient> clientByDestination) {
        this(Clock.systemUTC(), clientByDestination);
    }

    /**
     * @param clock               time source used to evaluate soft deadlines
     * @param clientByDestination per-destination clients; all of them are closed by {@link #close()}
     */
    public CrossDcNameResolverClient(Clock clock, @WillCloseWhenClosed Map<String, NameResolverClient> clientByDestination) {
        this.clock = clock;
        this.clientByDestination = clientByDestination;
    }

    /**
     * Sends the request to every destination and returns the successful response with the most
     * resources. On ties, the response encountered first in the destination map's iteration order wins.
     */
    @Override
    public CompletableFuture<FindResponse> find(FindRequest request) {
        return resolveClient()
                .map(client -> client.find(request))
                .collect(collectingAndThen(toList(), waitSome()))
                .thenApply(responses -> {
                    ensureDestinationResponses(responses);
                    var best = responses.get(0);
                    for (FindResponse candidate : responses) {
                        if (candidate.resources.size() > best.resources.size()) {
                            best = candidate;
                        }
                    }
                    return best;
                });
    }

    /**
     * Sends the request to every destination and merges the successful responses by resource id.
     * For each id, the resource with the greatest {@code updatedAt} is kept; on equal timestamps the
     * first one seen wins. A single successful response is returned unchanged.
     */
    @Override
    public CompletableFuture<ResolveResponse> resolve(ResolveRequest request) {
        return resolveClient()
                .map(client -> client.resolve(request))
                .collect(collectingAndThen(toList(), waitSome()))
                .thenApply(responses -> {
                    ensureDestinationResponses(responses);
                    if (responses.size() == 1) {
                        return responses.get(0);
                    }
                    Map<String, Resource> resources = new HashMap<>(responses.get(0).resources.size());
                    for (var response : responses) {
                        for (Resource resource : response.resources) {
                            // replace only when strictly newer, so ties keep the first seen resource
                            resources.merge(resource.resourceId, resource,
                                    (prev, next) -> prev.updatedAt < next.updatedAt ? next : prev);
                        }
                    }
                    return new ResolveResponse(new ArrayList<>(resources.values()));
                });
    }

    /**
     * Sends the update to every destination. The returned future succeeds once at least one
     * destination has applied it. Failures of other destinations are not reported.
     */
    @Override
    public CompletableFuture<Void> update(UpdateRequest request) {
        return resolveClient()
                .map(client -> client.update(request))
                .collect(collectingAndThen(toList(), waitSome()))
                .thenApply(responses -> {
                    ensureDestinationResponses(responses);
                    return null;
                });
    }

    /** Returns the union of shard ids reported by all destinations that answered successfully. */
    @Override
    public CompletableFuture<ShardsResponse> getShardIds() {
        return resolveClient()
                .map(NameResolverClient::getShardIds)
                .collect(collectingAndThen(toList(), waitSome()))
                .thenApply(responses -> {
                    ensureDestinationResponses(responses);
                    if (responses.size() == 1) {
                        return responses.get(0);
                    }
                    var shardIds = new HashSet<>(responses.get(0).ids());
                    for (int i = 1; i < responses.size(); i++) {
                        shardIds.addAll(responses.get(i).ids());
                    }
                    return new ShardsResponse(shardIds);
                });
    }

    private Stream<NameResolverClient> resolveClient() {
        return clientByDestination.values().stream();
    }

    /**
     * Closes every destination client. A failure in one client does not prevent the others from
     * being closed. The first failure is rethrown, with any later ones attached as suppressed.
     */
    @Override
    public void close() {
        RuntimeException failure = null;
        for (NameResolverClient client : clientByDestination.values()) {
            try {
                client.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Adapter for {@code collectingAndThen}: waits via {@code anyOrAllOf} without a soft deadline. */
    private <T> Function<List<CompletableFuture<T>>, CompletableFuture<List<T>>> waitSome() {
        return l -> anyOrAllOf(l, 0);
    }

    /**
     * Waits for at least one future to succeed, then gives the rest a bounded grace period.
     *
     * @param futures      per-destination futures
     * @param softDeadline absolute deadline in {@link #clock} millis bounding the grace period
     *                     (capped at {@value #SLOW_RESPONSE_AWAIT_MAX_LIMIT_MILLIS} ms); {@code 0}
     *                     means "no deadline", in which case {@value #SLOW_RESPONSE_AWAIT_MILLIS} ms
     *                     is used
     * @return the values of all futures that succeeded in time. The future is empty only when
     *         {@code futures} is empty, and it fails when every input future failed.
     */
    <T> CompletableFuture<List<T>> anyOrAllOf(
            List<CompletableFuture<T>> futures,
            long softDeadline) {
        if (futures.isEmpty()) {
            return completedFuture(List.of());
        }

        if (futures.size() == 1) {
            // Collections.singletonList accepts null, unlike List.of: update() yields
            // CompletableFuture<Void>, whose value is always null.
            return futures.get(0).thenApply(Collections::singletonList);
        }

        var allOfFuture = CompletableFutures.allOf(futures);

        return anyOk(futures).thenCompose(someIsOk -> {
            if (!someIsOk) {
                // Every future has already completed exceptionally, so allOfFuture is complete too.
                // The identity mapping keeps the result typed as CompletableFuture<List<T>>.
                return allOfFuture.thenApply(list -> list);
            }
            // At least one destination answered: wait a bit more for lagging DCs.
            // Clamped at 0 so an already-passed soft deadline means "do not wait".
            long extraAwaitMillis = softDeadline == 0
                    ? SLOW_RESPONSE_AWAIT_MILLIS
                    : Math.max(0, Math.min(softDeadline - clock.millis(), SLOW_RESPONSE_AWAIT_MAX_LIMIT_MILLIS));
            return allOfFuture
                    // a failed or timed-out allOf is mapped to null and handled below
                    .exceptionally(t -> null)
                    .completeOnTimeout(null, extraAwaitMillis, TimeUnit.MILLISECONDS)
                    .thenApply(list -> list != null ? list : collectSuccessful(futures));
        });
    }

    /** Returns the values of the futures that have already completed normally, in input order. */
    private static <T> List<T> collectSuccessful(List<CompletableFuture<T>> futures) {
        List<T> completed = new ArrayList<>(futures.size());
        for (var future : futures) {
            if (future.isDone() && !future.isCompletedExceptionally()) {
                completed.add(future.getNow(null));
            }
        }
        return completed;
    }

    /**
     * Completes with {@code true} as soon as any future completes normally, or with {@code false}
     * once all of them have completed exceptionally (or immediately if {@code futures} is empty).
     */
    <T> CompletableFuture<Boolean> anyOk(List<CompletableFuture<T>> futures) {
        if (futures.isEmpty()) {
            return completedFuture(false);
        }

        // generic array creation is impossible, so the raw array is cast; elements are all CompletableFuture<T>
        @SuppressWarnings("unchecked")
        CompletableFuture<T>[] futureArray = futures.toArray(CompletableFuture[]::new);

        return CompletableFuture.anyOf(futureArray)
                .exceptionally(t -> null)
                .thenCompose(ignore -> {
                    // Snapshot completion state first so a future finishing mid-loop cannot be
                    // classified inconsistently.
                    Boolean[] isDone = futures.stream()
                            .map(CompletableFuture::isDone)
                            .toArray(Boolean[]::new);
                    List<CompletableFuture<T>> incomplete = new ArrayList<>(futureArray.length);
                    for (int i = 0; i < futureArray.length; i++) {
                        if (isDone[i]) {
                            if (!futureArray[i].isCompletedExceptionally()) {
                                return completedFuture(true);
                            }
                        } else {
                            incomplete.add(futureArray[i]);
                        }
                    }
                    // Only failed futures are done so far: keep waiting on the pending ones.
                    return anyOk(incomplete);
                });
    }

    /** Guards against an empty result list, e.g. when no destination clients are configured. */
    private static <T> void ensureDestinationResponses(List<T> responses) {
        if (responses.isEmpty()) {
            throw new IllegalArgumentException("Response is not ok");
        }
    }
}
