package ru.yandex.travel.commons.retry;

import java.io.Closeable;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

import io.grpc.Context;
import io.opentracing.Scope;
import io.opentracing.Span;
import io.opentracing.Tracer;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;

import ru.yandex.misc.ExceptionUtils;
import ru.yandex.travel.commons.concurrent.ExecutorUtils;
import ru.yandex.travel.commons.logging.NestedMdc;

/**
 * Executes asynchronous operations with retries.
 *
 * <p>Two modes are supported:
 * <ul>
 *     <li>{@code withRetry}: sequential attempts; the next attempt starts only after the previous one needed a retry
 *     and the strategy's wait duration has passed;</li>
 *     <li>{@code withSpeculativeRetry}: overlapping attempts; a new attempt starts after a fixed delay while no attempt
 *     has produced an outcome yet, and the first valid result wins.</li>
 * </ul>
 *
 * <p>Waits and timeouts are scheduled on a private single-thread scheduler that is shut down by {@link #close()}.
 */
@Slf4j
public class Retry implements Closeable {
    private final ScheduledExecutorService executor;
    private final Tracer tracer;

    /**
     * @param tracer tracer whose active span (captured when a {@code withRetry} call starts) is re-activated around
     *               every attempt of that call
     */
    public Retry(Tracer tracer) {
        this.tracer = tracer;
        var scheduler = new ScheduledThreadPoolExecutor(1);
        scheduler.setMaximumPoolSize(1);
        // Timers that are no longer needed get cancelled (see await); remove them from the queue immediately instead
        // of keeping them until their (possibly very distant, e.g. 1 day) trigger time.
        scheduler.setRemoveOnCancelPolicy(true);
        this.executor = scheduler;
    }

    /**
     * Wraps {@code futureSupplier} so that every invocation runs with the tracing span and the gRPC {@link Context}
     * that were current when this method was called, even if the invocation happens later on another thread.
     * Exceptions thrown by the supplier are propagated through {@code ExceptionUtils.throwException}.
     */
    public <I, T> Function<I, CompletableFuture<T>> wrapWithCtx(Function<I, CompletableFuture<T>> futureSupplier) {
        Span activeSpan = tracer.activeSpan();
        Context ctx = Context.current();

        return (arg) -> {
            try (Scope scope = tracer.activateSpan(activeSpan)) {
                return ctx.call(() -> futureSupplier.apply(arg));
            } catch (Exception e) {
                throw ExceptionUtils.throwException(e);
            }
        };
    }

    /**
     * Runs {@code futureSupplier.apply(input)} with retries driven by a new {@link DefaultRetryStrategy} and no rate
     * limiting. See the private {@code withRetry} for the completion contract.
     */
    public <T, I> CompletableFuture<T> withRetry(String name, Function<I, CompletableFuture<T>> futureSupplier,
                                                 I input) {

        DefaultRetryStrategy<T> strategy = new DefaultRetryStrategy<>();
        return withRetry(name, wrapWithCtx(futureSupplier), input, strategy, 0, strategy.getNumRetries(),
                new ArrayList<>(), null);
    }

    /**
     * Runs {@code futureSupplier.apply(input)} with retries driven by {@code strategy} and no rate limiting.
     */
    public <T, I> CompletableFuture<T> withRetry(String name, Function<I, CompletableFuture<T>> futureSupplier,
                                                 I input, RetryStrategy<T> strategy) {
        return withRetry(name, wrapWithCtx(futureSupplier), input, strategy, 0, strategy.getNumRetries(),
                new ArrayList<>(), null);
    }

    /**
     * Runs {@code futureSupplier.apply(input)} with retries driven by {@code strategy}; extra attempts also require
     * permission from {@code rateLimiter} (may be {@code null} for no limiting).
     */
    public <T, I> CompletableFuture<T> withRetry(String name, Function<I, CompletableFuture<T>> futureSupplier,
                                                 I input, RetryStrategy<T> strategy, RetryRateLimiter rateLimiter) {
        return withRetry(name, wrapWithCtx(futureSupplier), input, strategy, 0, strategy.getNumRetries(),
                new ArrayList<>(), rateLimiter);
    }

    /**
     * Same as the function-based overload, for operations that take no input.
     */
    public <T> CompletableFuture<T> withRetry(String name, Supplier<CompletableFuture<T>> futureSupplier,
                                              RetryStrategy<T> strategy) {
        return withRetry(name, wrapWithCtx(ignored -> futureSupplier.get()), null, strategy, 0,
                strategy.getNumRetries(), new ArrayList<>(), null);
    }

    /**
     * Same as the function-based overload with a rate limiter, for operations that take no input.
     */
    public <T> CompletableFuture<T> withRetry(String name, Supplier<CompletableFuture<T>> futureSupplier,
                                              RetryStrategy<T> strategy, RetryRateLimiter rateLimiter) {
        return withRetry(name, wrapWithCtx(ignored -> futureSupplier.get()), null, strategy, 0,
                strategy.getNumRetries(), new ArrayList<>(), rateLimiter);
    }

    /**
     * Runs one attempt and, if needed, schedules the next one recursively.
     *
     * <p>An attempt needs a retry when it fails with an exception accepted by
     * {@code strategy.shouldRetryOnException}, or when it returns a non-null result rejected by
     * {@code strategy.validateResult}. The returned future:
     * <ul>
     *     <li>completes with the result if no retry is needed;</li>
     *     <li>fails with the attempt's exception if the strategy does not retry it;</li>
     *     <li>fails with a {@link RetryException} holding every retried failure once a retry is needed but
     *     {@code remainingRetries} is exhausted or the rate limiter refuses it.</li>
     * </ul>
     *
     * <p>NOTE(review): a synchronous throw from {@code futureSupplier} is not treated as a failed attempt: on the
     * first attempt it propagates to the caller, on later attempts it fails the returned future without retrying.
     */
    private <T, I> CompletableFuture<T> withRetry(String name,
                                                  Function<I, CompletableFuture<T>> futureSupplier,
                                                  I input, RetryStrategy<T> strategy,
                                                  int iteration,
                                                  int remainingRetries,
                                                  List<Exception> attempts,
                                                  RetryRateLimiter rateLimiter) {
        final long startedAt = System.currentTimeMillis();
        // Copy of the caller's MDC; the next attempt runs on the scheduler thread, so it is restored there.
        Map<String, String> mdc = MDC.getCopyOfContextMap();
        CompletableFuture<T> future = futureSupplier.apply(input);
        return future.handle((r, e) -> wrap(name, r, e, iteration, System.currentTimeMillis() - startedAt)).thenCompose(wrapped -> {
            boolean retryIsNeeded = false;
            Exception exception = wrapped.getException();
            Exception attemptException = exception;
            T result = wrapped.getResult();
            if (exception != null) {
                retryIsNeeded = strategy.shouldRetryOnException(exception);
                log.warn("{}: Iteration {} failed because of exception", name, iteration, exception);
            } else if (result != null) {
                try {
                    strategy.validateResult(result);
                } catch (Exception exc) {
                    retryIsNeeded = true;
                    attemptException = new UnacceptedResultException(result);
                    log.warn("{}: Iteration {} failed because of bad result: {}", name, iteration, exc.getMessage());
                }
            }
            if (retryIsNeeded) {
                if (rateLimiter != null) {
                    rateLimiter.onFailure();
                }
                attempts.add(attemptException);
                boolean retryIsPossible = remainingRetries > 0;
                if (retryIsPossible) {
                    if (rateLimiter != null) {
                        retryIsPossible = rateLimiter.shouldRetry();
                        if (!retryIsPossible) {
                            log.warn("{}: Iteration {} needs a retry, but rate limit is reached", name, iteration);
                        }
                    }
                } else {
                    log.warn("{}: Iteration {} needs a retry, but no more attempts left", name, iteration);
                }
                if (retryIsPossible) {
                    return await(strategy.getWaitDuration(iteration, exception, result)).thenCompose(ignored -> {
                        // Save the scheduler thread's MDC, install the caller's copy for the next attempt, restore on exit.
                        try (var ignoredEmptyMdc = NestedMdc.empty()) {
                            if (mdc != null) {
                                MDC.setContextMap(mdc);
                            }
                            return withRetry(name, futureSupplier, input, strategy, iteration + 1,
                                    remainingRetries - 1, attempts, rateLimiter);
                        }
                    });
                } else {
                    CompletableFuture<T> futureToReturn = new CompletableFuture<>();
                    futureToReturn.completeExceptionally(new RetryException(attempts));
                    return futureToReturn;
                }
            } else {
                // No retry needed (success, or an exception the strategy does not retry): report success to the limiter.
                if (rateLimiter != null) {
                    rateLimiter.onSuccess();
                }
                CompletableFuture<T> futureToReturn = new CompletableFuture<>();
                if (exception != null) {
                    futureToReturn.completeExceptionally(exception);
                } else {
                    futureToReturn.complete(wrapped.getResult());
                }
                return futureToReturn;
            }
        });
    }

    /**
     * Runs {@code futureSupplier.apply(input)} speculatively.
     *
     * <p>While no attempt has produced an outcome, a new attempt is started every {@code strategy.getRetryDelay()},
     * up to {@code strategy.getNumRetries()} attempts in total; attempts already in flight keep running. The first
     * attempt that yields a result accepted by {@code strategy.validateResult} wins. An exception rejected by
     * {@code strategy.shouldRetryOnException} fails the whole call immediately. If every started attempt fails, or
     * {@code strategy.getTimeout()} (1 day when {@code null}) elapses first, the returned future fails with a
     * {@link RuntimeException} whose cause is the first failed attempt's exception, when there is one.
     *
     * @param rateLimiter consulted before every extra attempt; {@code null} means no limiting
     */
    public <T, I> CompletableFuture<T> withSpeculativeRetry(String name,
                                                            Function<I, CompletableFuture<T>> futureSupplier,
                                                            I input, SpeculativeRetryStrategy<T> strategy,
                                                            RetryRateLimiter rateLimiter) {
        var retryDelay = strategy.getRetryDelay();
        final long startedAt = System.currentTimeMillis();
        var timeout = Objects.requireNonNullElseGet(strategy.getTimeout(), () -> Duration.ofDays(1)); // 1 days stands for "no timeout"
        // Completed on timeout and cancelled once the call has finished; either way, done means "stop launching".
        CompletableFuture<Void> deadline = await(timeout);

        CompletableFuture<T> resultFuture = new CompletableFuture<>();
        var futures = Collections.synchronizedList(new ArrayList<CompletableFuture<T>>());
        CompletableFuture<Boolean> currFuture = CompletableFuture.completedFuture(true);
        for (int i = 0; i < strategy.getNumRetries(); i++) {
            var iteration = i + 1;
            currFuture = currFuture.thenCompose(shouldContinue -> {
                if (!shouldContinue || deadline.isDone()) {
                    return CompletableFuture.completedFuture(false);
                }
                futures.add(startSpeculativeAttempt(name, futureSupplier, input, strategy, resultFuture, iteration,
                        startedAt));
                return CompletableFuture
                        .anyOf(resultFuture, await(retryDelay))
                        .thenApply(ignored -> {
                            if (resultFuture.isDone()) {
                                if (rateLimiter != null) {
                                    rateLimiter.onSuccess();
                                }
                                return false;
                            }
                            if (rateLimiter == null) {
                                return true;
                            }
                            rateLimiter.onFailure();
                            if (rateLimiter.shouldRetry()) {
                                return true;
                            }
                            log.warn("{}: Not retrying because of rate limit", name);
                            return false;
                        });
            });
        }

        // Once the chain has stopped launching attempts (normally or not) and every launched attempt has finished,
        // nothing can complete resultFuture any more: fail fast instead of waiting for the timeout.
        CompletableFuture<Void> exhausted = currFuture
                .handle((ignored, chainError) -> (Void) null)
                .thenCompose(ignored -> CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])))
                .exceptionally(ignored -> null);

        return CompletableFuture.anyOf(resultFuture, exhausted, deadline).handle((ignored, ignoredError) -> {
            deadline.cancel(false);
            // The copy goes through the synchronized list's toArray, so it is safe against concurrent additions.
            List<CompletableFuture<T>> attempts = new ArrayList<>(futures);
            // Must run before cancelling: cancelled futures also report isCompletedExceptionally().
            Exception firstFailure = findFirstFailure(attempts);
            // NOTE(review): these are the futures returned by whenComplete; cancelling them does not reach the
            // futures produced by futureSupplier.
            attempts.forEach(x -> x.cancel(true));
            if (resultFuture.isDone()) {
                if (resultFuture.isCompletedExceptionally()) {
                    log.debug("{}: Finished with ignored exception after {} ms", name,
                            System.currentTimeMillis() - startedAt);
                } else {
                    log.debug("{}: Finished successfully after {} ms", name, System.currentTimeMillis() - startedAt);
                }
                return resultFuture.join();
            }
            long reportedMillis = exhausted.isDone() ? System.currentTimeMillis() - startedAt : timeout.toMillis();
            var errorMsg = String.format("%s: No successful attempt after %s ms", name, reportedMillis);
            if (firstFailure != null) {
                throw new RuntimeException(errorMsg, firstFailure);
            }
            throw new RuntimeException(errorMsg);
        });
    }

    /**
     * Starts one speculative attempt. The returned future completes after the attempt's outcome has been recorded:
     * an accepted result completes {@code resultFuture}, an exception the strategy does not retry completes it
     * exceptionally, and anything else is only logged.
     */
    private <T, I> CompletableFuture<T> startSpeculativeAttempt(String name,
                                                              Function<I, CompletableFuture<T>> futureSupplier,
                                                              I input, SpeculativeRetryStrategy<T> strategy,
                                                              CompletableFuture<T> resultFuture,
                                                              int iteration, long startedAt) {
        var attemptStartedAt = System.currentTimeMillis();
        CompletableFuture<T> attempt;
        try {
            attempt = futureSupplier.apply(input);
        } catch (Exception e) {
            // Treat a synchronous failure like an asynchronous one instead of silently ending the launch chain.
            attempt = CompletableFuture.failedFuture(e);
        }
        return attempt.whenComplete((result, exception) -> {
            if (exception != null) {
                Exception actual = asUnwrappedException(exception);
                log.warn("{}: Execution attempt {} completed in {} ms ({} ms total) with an exception '{}'",
                        name, iteration, System.currentTimeMillis() - attemptStartedAt,
                        System.currentTimeMillis() - startedAt, actual.getMessage());
                if (!strategy.shouldRetryOnException.apply(actual)) {
                    resultFuture.completeExceptionally(actual);
                }
            } else {
                log.debug("{}: Execution attempt {} completed in {} ms ({} ms total)", name, iteration,
                        System.currentTimeMillis() - attemptStartedAt,
                        System.currentTimeMillis() - startedAt);
                try {
                    if (strategy.validateResult != null) {
                        strategy.validateResult.accept(result);
                    }
                    resultFuture.complete(result);
                } catch (Exception exc) {
                    log.warn("{}: Iteration {} failed because of bad result: {}", name, iteration,
                            exc.getMessage());
                }
            }
        });
    }

    /**
     * Returns the unwrapped exception of the first future in {@code attempts} that completed exceptionally, or
     * {@code null} if there is none.
     */
    private static <T> Exception findFirstFailure(List<CompletableFuture<T>> attempts) {
        for (CompletableFuture<T> attempt : attempts) {
            if (attempt.isCompletedExceptionally()) {
                try {
                    attempt.join();
                } catch (Exception e) {
                    return unwrapExecutionException(e);
                }
            }
        }
        return null;
    }

    /**
     * Strips nested {@link CompletionException} layers as long as their cause is an {@link Exception}.
     */
    public static Exception unwrapExecutionException(Exception exception) {
        while (exception instanceof CompletionException && exception.getCause() instanceof Exception) {
            exception = (Exception) exception.getCause();
        }
        return exception;
    }

    /**
     * Converts a completion {@link Throwable} into an unwrapped {@link Exception}. Non-Exception throwables (e.g. an
     * {@link Error}) are wrapped in a {@link CompletionException}, the same shape they have after passing through a
     * dependent stage, instead of failing with a {@link ClassCastException}. Returns {@code null} for {@code null}.
     */
    private static Exception asUnwrappedException(Throwable throwable) {
        if (throwable == null) {
            return null;
        }
        Exception exception = throwable instanceof Exception
                ? (Exception) throwable
                : new CompletionException(throwable);
        return unwrapExecutionException(exception);
    }

    /**
     * Returns a future completed with {@code null} after {@code timeout}. Cancelling the returned future also
     * cancels the underlying scheduled task, so it does not linger in the scheduler queue.
     */
    private CompletableFuture<Void> await(Duration timeout) {
        CompletableFuture<Void> result = new CompletableFuture<>();
        var task = executor.schedule(() -> result.complete(null), timeout.toMillis(), TimeUnit.MILLISECONDS);
        result.whenComplete((ignored, error) -> {
            if (error != null) {
                task.cancel(false);
            }
        });
        return result;
    }

    /**
     * Logs the outcome of one {@code withRetry} attempt and packs it, with the exception unwrapped, into a
     * {@link RetryResultWrapper}.
     */
    private <T> RetryResultWrapper<T> wrap(String name, T result, Throwable exception, int iteration,
                                           long durationMillis) {
        Exception actual = asUnwrappedException(exception);
        if (exception != null) {
            log.warn("{}: Execution attempt {} completed in {} ms with an exception '{}'", name, iteration,
                    durationMillis, exception.getMessage());
        } else {
            log.debug("{}: Execution attempt {} completed in {} ms", name, iteration, durationMillis);
        }
        return new RetryResultWrapper<>(result, actual);
    }

    /**
     * Shuts down the scheduler used for waits and timeouts, waiting up to one second for termination.
     */
    @Override
    public void close() throws IOException {
        ExecutorUtils.shutdownAndAwaitTermination(this.executor, Duration.ofSeconds(1));
    }
}
