package ru.yandex.mail.cerberus.client;

import io.micronaut.aop.InterceptPhase;
import io.micronaut.aop.MethodInterceptor;
import io.micronaut.aop.MethodInvocationContext;
import io.micronaut.http.client.exceptions.HttpClientResponseException;
import io.micronaut.http.client.exceptions.ReadTimeoutException;
import io.micronaut.retry.annotation.Retryable;
import lombok.val;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;

import javax.annotation.Nonnull;
import javax.inject.Singleton;
import java.util.function.Supplier;

import static lombok.Lombok.sneakyThrow;

/**
 * Method interceptor that retries calls failing with a transient error.
 *
 * <p>An error is treated as retryable when it is a {@link ReadTimeoutException} or an
 * {@link HttpClientResponseException} with a 5xx status code. Such errors are temporarily wrapped
 * in the private {@link RetryableException}. This lets the {@link Retryable} advice on
 * {@link #retry} / {@link #retryReactive} (declared with {@code includes = RetryableException.class})
 * retry only them. The wrapper is stripped again in {@link #intercept} before the error reaches the
 * caller.
 *
 * <p>The number of attempts and the delay between them come from
 * {@code micronaut.http.services.cerberus.retries-count} (default {@code 0}) and
 * {@code micronaut.http.services.cerberus.retry-delay} (default {@code 0ms}).
 */
@Singleton
public class RetryInterceptor implements MethodInterceptor<Object, Object> {
    /** Marker wrapper for errors that the {@link Retryable} advice is allowed to retry. */
    private static class RetryableException extends RuntimeException {
        public RetryableException(Throwable cause) {
            super(cause);
        }
    }

    /**
     * Decides whether a failed call should be retried.
     *
     * @param e the error raised by the call
     * @return {@code true} for read timeouts and HTTP responses with a 5xx status code
     */
    private static boolean isNeedRetry(@Nonnull Throwable e) {
        if (e instanceof ReadTimeoutException) {
            return true;
        }
        if (!(e instanceof HttpClientResponseException)) {
            return false;
        }

        val code = ((HttpClientResponseException) e).getStatus().getCode();
        return code >= 500 && code < 600;
    }

    /** Wraps a retryable error into {@link RetryableException}; other errors are returned as-is. */
    private static Throwable markRetryable(Throwable e) {
        return isNeedRetry(e) ? new RetryableException(e) : e;
    }

    /** Removes the {@link RetryableException} wrapper, returning the original error. */
    private static Throwable unmarkRetryable(Throwable e) {
        return e instanceof RetryableException ? e.getCause() : e;
    }

    /**
     * Runs {@code call}, converting retryable errors into {@link RetryableException}.
     *
     * <p>Errors thrown synchronously by {@code call} are converted immediately. If the result is a
     * {@link Publisher}, its error signal is converted instead.
     *
     * @param call the invocation to run
     * @return the call's result; a publisher result is returned as a {@link Mono} with mapped errors
     */
    @SuppressWarnings("unchecked")
    private static <T> T wrapException(Supplier<T> call) {
        try {
            T result = call.get();
            if (result instanceof Publisher) {
                // NOTE(review): Mono.from keeps only the first element of a multi-element publisher
                // and always yields a Mono - confirm intercepted methods return Mono/Publisher.
                return (T) Mono.from((Publisher<?>) result).onErrorMap(RetryInterceptor::markRetryable);
            }
            return result;
        } catch (Exception e) {
            if (isNeedRetry(e)) {
                throw new RetryableException(e);
            }
            throw e;
        }
    }

    /**
     * Runs {@code call}, restoring the original error from any {@link RetryableException}.
     *
     * <p>A synchronously thrown wrapper is unwrapped and its cause rethrown (sneakily, as it may be
     * checked). For a {@link Publisher} result, the error signal is unwrapped instead.
     *
     * @param call the (retrying) invocation to run
     * @return the call's result; a publisher result is returned as a {@link Mono} with mapped errors
     */
    private static Object unwrapException(Supplier<Object> call) {
        try {
            Object result = call.get();
            if (result instanceof Publisher) {
                return Mono.from((Publisher<?>) result).onErrorMap(RetryInterceptor::unmarkRetryable);
            }
            return result;
        } catch (RetryableException e) {
            throw sneakyThrow(e.getCause());
        }
    }

    /**
     * Retried entry point for methods returning a {@link Publisher}.
     *
     * <p>Kept public and invoked on this bean so that the {@link Retryable} advice applies to it.
     * Each retry calls {@code publisherSupplier} again.
     *
     * @param publisherSupplier produces a fresh publisher for every attempt
     * @return the publisher with retryable errors wrapped in {@link RetryableException}
     */
    @Retryable(
        includes = RetryableException.class,
        attempts = "${micronaut.http.services.cerberus.retries-count:0}",
        delay = "${micronaut.http.services.cerberus.retry-delay:0ms}"
    )
    public Publisher<?> retryReactive(Supplier<Publisher<?>> publisherSupplier) {
        return wrapException(publisherSupplier);
    }

    /**
     * Retried entry point for methods with a non-reactive return type.
     *
     * <p>Kept public and invoked on this bean so that the {@link Retryable} advice applies to it.
     * Each retry calls {@code completionSupplier} again.
     *
     * @param completionSupplier performs one attempt of the call
     * @return the value produced by the successful attempt
     */
    @Retryable(
        includes = RetryableException.class,
        attempts = "${micronaut.http.services.cerberus.retries-count:0}",
        delay = "${micronaut.http.services.cerberus.retry-delay:0ms}"
    )
    public Object retry(Supplier<?> completionSupplier) {
        return wrapException(completionSupplier);
    }

    /**
     * Dispatches to {@link #retryReactive} or {@link #retry} based on the declared return type.
     * Callers then see the original (unwrapped) error.
     */
    @Override
    public Object intercept(MethodInvocationContext<Object, Object> context) {
        val resultType = context.getReturnType().getType();

        // The suppliers below run once per attempt, so they use proceed(this) rather than proceed().
        // proceed() advances the chain's position, so a repeated proceed() would skip the interceptors
        // after this one. proceed(this) restarts the chain right after this interceptor every time.
        return unwrapException(() -> {
            if (Publisher.class.isAssignableFrom(resultType)) {
                return retryReactive(() -> (Publisher<?>) context.proceed(this));
            } else {
                return retry(() -> context.proceed(this));
            }
        });
    }

    /** Places this interceptor in the {@link InterceptPhase#RETRY} phase of the chain. */
    @Override
    public int getOrder() {
        return InterceptPhase.RETRY.getPosition();
    }
}
