package ru.yandex.mail.cerberus.dao.tx;

import io.micronaut.configuration.jdbc.hikari.DatasourceConfiguration;
import io.micronaut.context.BeanLocator;
import io.micronaut.context.annotation.EachBean;
import io.micronaut.context.annotation.Parameter;
import io.micronaut.inject.qualifiers.Qualifiers;
import io.micronaut.retry.annotation.Retryable;
import lombok.SneakyThrows;
import lombok.val;
import org.springframework.dao.TransientDataAccessException;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.annotation.Isolation;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.support.TransactionTemplate;
import ru.yandex.mail.cerberus.dao.ExtendedDatasourceConfiguration;
import ru.yandex.mail.cerberus.exception.NoDatabaseConnectionException;

import javax.inject.Inject;
import javax.inject.Singleton;
import javax.sql.DataSource;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNullElse;
import static ru.yandex.mail.cerberus.dao.DaoConstants.RETRY_POLICY_ATTEMPTS_PROPERTY;
import static ru.yandex.mail.cerberus.dao.DaoConstants.RETRY_POLICY_DELAY_PROPERTY;

/**
 * Unchecked wrapper marking a failure as retriable.
 * <p>
 * NOTE: temporary workaround until Micronaut's {@code @Retryable} supports filtering on a base exception class;
 * the retry advice is restricted to this type and the original failure is recovered via {@link #getCause()}.
 */
class RetryException extends RuntimeException {
    /**
     * @param retriableFailure the failure that should trigger a retry; becomes this exception's cause
     */
    public RetryException(Throwable retriableFailure) {
        super(retriableFailure);
    }
}

/**
 * {@link TxManager} bound to a single {@link DataSource} bean; one instance is created per data source via
 * {@code @EachBean}.
 * <p>
 * Synchronous calls run on the caller's thread inside a Spring {@link TransactionTemplate}. Failures of type
 * {@link TransientDataAccessException} are retried according to the configured retry policy. Asynchronous calls are
 * submitted to the executor registered under the data source name. The returned futures use the forwarding executor
 * as the default executor for {@code *Async} continuations, including those of dependent stages.
 */
@Singleton
@EachBean(DataSource.class)
public class DatasourceTxManager implements TxManager {
    /** Executor registered under the data source name; runs the blocking work submitted by {@code executeAsync}. */
    private final ExecutorService executor;
    /** Default executor of the futures returned by {@code executeAsync} (and of their dependent stages). */
    private final ExecutorService forwardingExecutor;
    private final PlatformTransactionManager transactionManager;
    /** Whether created transactions are marked read-only (taken from the extended data source configuration). */
    private final boolean readOnly;
    /** Data source connection timeout in nanoseconds; also the maximum time an async task may wait in the queue. */
    private final long connectionTimeout;

    /**
     * {@link CompletableFuture} whose default asynchronous executor is a fixed executor.
     * <p>
     * {@link #newIncompleteFuture()} is overridden, as the {@code CompletableFuture} Javadoc recommends for
     * subclasses. Without it, dependent stages (e.g. the result of {@code thenApplyAsync}) would be plain
     * {@code CompletableFuture}s and fall back to the JDK's default async executor. Declared {@code static} so it
     * does not capture the enclosing bean.
     */
    private static final class ForwardingFuture<T> extends CompletableFuture<T> {
        private final Executor defaultExecutor;

        ForwardingFuture(Executor defaultExecutor) {
            this.defaultExecutor = defaultExecutor;
        }

        @Override
        public Executor defaultExecutor() {
            return defaultExecutor;
        }

        @Override
        public <U> CompletableFuture<U> newIncompleteFuture() {
            return new ForwardingFuture<>(defaultExecutor);
        }
    }

    /**
     * Builds a transaction template with this manager's read-only flag and the requested isolation and propagation.
     */
    private TransactionTemplate createTemplate(Isolation isolation, Propagation propagation) {
        val template = new TransactionTemplate(transactionManager);
        template.setReadOnly(readOnly);
        template.setIsolationLevel(isolation.value());
        template.setPropagationBehavior(propagation.value());
        return template;
    }

    /**
     * Resolves every collaborator from the bean context by the data source {@code name}.
     *
     * @param name               data source name, used as the qualifier for all looked-up beans
     * @param beanLocator        bean context
     * @param forwardingExecutor explicit forwarding executor; when empty, the executor named by
     *                           {@code ExtendedDatasourceConfiguration#getForwardingExecutor()} is looked up
     */
    private DatasourceTxManager(String name, BeanLocator beanLocator, Optional<ExecutorService> forwardingExecutor) {
        val dataSourceConfig = beanLocator.getBean(DatasourceConfiguration.class, Qualifiers.byName(name));
        val dataSourceConfigExt = beanLocator.getBean(ExtendedDatasourceConfiguration.class, Qualifiers.byName(name));

        this.executor = beanLocator.getBean(ExecutorService.class, Qualifiers.byName(name));
        this.forwardingExecutor = forwardingExecutor
            .orElseGet(() -> {
                val forwardingExecutorName = dataSourceConfigExt.getForwardingExecutor();
                return beanLocator.getBean(ExecutorService.class, Qualifiers.byName(forwardingExecutorName));
            });
        this.transactionManager = beanLocator.getBean(PlatformTransactionManager.class, Qualifiers.byName(name));
        this.readOnly = dataSourceConfigExt.isReadOnly();
        // The configured connection timeout is interpreted as milliseconds and stored as nanoseconds,
        // to be compared with System.nanoTime() deltas.
        this.connectionTimeout = TimeUnit.NANOSECONDS.convert(dataSourceConfig.getConnectionTimeout(), TimeUnit.MILLISECONDS);
    }

    /**
     * Invokes {@code call}, translating {@link TransientDataAccessException} into {@link RetryException}.
     * Only {@link RetryException} is retried by the {@link Retryable} advice.
     * <p>
     * NOTE(review): this method is invoked from within the class, so the retry only happens if Micronaut's AOP
     * intercepts self-invocation (presumably through the generated subclass overriding it). Keep the method
     * non-private and non-final.
     */
    @Retryable(includes = RetryException.class, attempts = RETRY_POLICY_ATTEMPTS_PROPERTY, delay = RETRY_POLICY_DELAY_PROPERTY)
    protected <T> T retry(Supplier<T> call) {
        try {
            return call.get();
        } catch (TransientDataAccessException e) {
            throw new RetryException(e);
        }
    }

    @Inject
    public DatasourceTxManager(@Parameter String name, BeanLocator beanLocator) {
        this(name, beanLocator, Optional.empty());
    }

    DatasourceTxManager(String name, BeanLocator beanLocator, ExecutorService forwardingExecutor) {
        this(name, beanLocator, Optional.of(forwardingExecutor));
    }

    @Override
    public void run(Runnable runnable, Isolation isolation, Propagation propagation) {
        execute(() -> {
            runnable.run();
            return null;
        }, isolation, propagation);
    }

    @Override
    public CompletableFuture<Void> runAsync(Runnable runnable, Isolation isolation, Propagation propagation) {
        return executeAsync(() -> {
            runnable.run();
            return null;
        }, isolation, propagation);
    }

    /**
     * Runs {@code call} on the current thread inside a transaction, retrying the whole transaction on
     * {@link TransientDataAccessException}.
     * <p>
     * If retries are exhausted, the last {@link TransientDataAccessException} is rethrown unwrapped
     * (via {@code @SneakyThrows}).
     */
    @Override
    @SneakyThrows
    public <T> T execute(Supplier<T> call, Isolation isolation, Propagation propagation) {
        // The template only holds configuration, so one instance serves every retry attempt.
        val template = createTemplate(isolation, propagation);
        try {
            return retry(() -> template.execute(status -> call.get()));
        } catch (RetryException e) {
            throw e.getCause();
        }
    }

    /**
     * Submits {@link #execute} to the data source executor and returns a future completed with its outcome.
     * <p>
     * The task fails with {@link NoDatabaseConnectionException} if it waited in the queue longer than the connection
     * timeout. If the executor rejects the task, the future is completed with the rejection's cause when present,
     * otherwise with the {@link RejectedExecutionException} itself.
     */
    @Override
    public <T> CompletableFuture<T> executeAsync(Supplier<T> call, Isolation isolation, Propagation propagation) {
        val future = new ForwardingFuture<T>(forwardingExecutor);
        val submittedAt = System.nanoTime();

        try {
            executor.execute(() -> {
                try {
                    // Fail fast if the task sat in the queue longer than a connection acquisition may take.
                    if (System.nanoTime() - submittedAt > connectionTimeout) {
                        throw new NoDatabaseConnectionException();
                    }

                    future.complete(execute(call, isolation, propagation));
                } catch (Throwable e) {
                    future.completeExceptionally(e);
                }
            });
        } catch (RejectedExecutionException e) {
            future.completeExceptionally(requireNonNullElse(e.getCause(), e));
        }

        return future;
    }
}
