package ru.yandex.chemodan.ydb.dao;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.table.Session;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.transaction.Transaction;
import com.yandex.ydb.table.transaction.TransactionMode;
import com.yandex.ydb.table.transaction.TxControl;
import lombok.AllArgsConstructor;
import lombok.Data;

import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.function.Function;
import ru.yandex.bolts.function.Function2;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;

/**
 * Runs YDB operations either inside a transaction bound to the current thread or in a
 * temporary session obtained through a {@link SessionRetryContext}.
 *
 * <p>{@link #executeInTx} begins a transaction, stores it in a {@link ThreadLocal} while the
 * supplied operation runs, commits it afterwards and rolls it back on failure. Code executed by
 * that operation on the same thread can call {@link #executeInTransactionOrTmpSession} to join
 * the current transaction; without a current transaction a temporary session is used instead.
 *
 * @author tolmalev
 */
public class ThreadLocalYdbTransactionManager {
    private static final Logger logger = LoggerFactory.getLogger(ThreadLocalYdbTransactionManager.class);

    private final TableClient tableClient;
    private final ThreadLocal<TransactionData> current = new ThreadLocal<>();
    private final YdbTimeoutSettings timeoutSettings;
    private final SessionRetryContext sessionRetryContext;

    /**
     * Creates a manager with an explicit retry configuration.
     *
     * @param tableClient                client used by the retry context to obtain sessions
     * @param timeoutSettings            timeouts; {@code getTxTimeout()} is also used as the session supply timeout
     * @param retriesCount               maximum number of retries performed by the session retry context
     * @param sessionRetryContextThreads size of the fixed thread pool handed to the session retry context
     */
    public ThreadLocalYdbTransactionManager(TableClient tableClient, YdbTimeoutSettings timeoutSettings,
                                            int retriesCount, int sessionRetryContextThreads) {
        this.tableClient = tableClient;
        this.timeoutSettings = timeoutSettings;
        // NOTE(review): this pool is never shut down by this class; confirm the manager is meant to
        // live for the whole application lifetime.
        this.sessionRetryContext = SessionRetryContext.create(tableClient)
                .executor(Executors.newFixedThreadPool(sessionRetryContextThreads))
                .maxRetries(retriesCount)
                .sessionSupplyTimeout(timeoutSettings.getTxTimeout())
                .build();
    }

    /**
     * Creates a manager with 5 retries and a retry-context thread pool of 10 threads.
     */
    public ThreadLocalYdbTransactionManager(TableClient tableClient, YdbTimeoutSettings timeoutSettings) {
        this(tableClient, timeoutSettings, 5, 10);
    }

    public YdbTimeoutSettings getTimeoutSettings() {
        return timeoutSettings;
    }

    /**
     * Runs {@code op} in the transaction bound to the current thread, if any, without committing it;
     * otherwise runs it in a temporary session with a serializable read-write transaction that is
     * committed by the query itself.
     *
     * @return the successful result value
     * @throws RuntimeException if the returned {@link Result} is not successful
     */
    public <T> T executeInTransactionOrTmpSession(Function2<Session, TxControl, CompletableFuture<Result<T>>> op) {
        Option<TransactionData> transactionData = getCurrent();
        if (transactionData.isPresent()) {
            TransactionData t = transactionData.get();
            // The enclosing executeInTx commits; the query must not commit on its own.
            return getResult(op.apply(t.session, TxControl.id(t.transaction).setCommitTx(false)));
        } else {
            return executeInTmpSession(op);
        }
    }

    /**
     * Begins a transaction in {@code mode}, binds it to the thread running the retry-context callback,
     * runs {@code op} on that same thread and commits. On any recoverable failure the transaction is
     * rolled back and the original exception is rethrown; a rollback failure is attached to it as a
     * suppressed exception instead of replacing it.
     *
     * @throws IllegalStateException if a transaction is already bound to the executing thread
     */
    public <T> T executeInTx(Supplier<T> op, TransactionMode mode) {
        return sessionRetryContext.supplyResult(session -> {
            // Check before beginning so a nested call cannot leave a started transaction dangling.
            checkNoCurrent();
            Transaction transaction = session.beginTransaction(mode, timeoutSettings.getBeginTxSettingsWithTimeout())
                    .join().expect("Can't begin transaction session");
            setCurrent(session, transaction);
            try {
                logger.debug("Starting operation in tx");
                T result = op.get();
                logger.debug("Operation completed, trying commit");
                commit();
                return CompletableFuture.completedFuture(Result.success(result));
            } catch (Throwable t) {
                ExceptionUtils.throwIfUnrecoverable(t);
                logger.error("Failed to execute transaction in session {} due to {}", session, t);
                rollbackAfterFailure(t);
                throw t;
            } finally {
                clearCurrent();
            }
        }).join().expect("Can't execute transactional query");
    }

    /**
     * Runs {@code op} through the session retry context and returns its final status as-is
     * (a non-success status is not turned into an exception here).
     */
    public Status executeInTmpSession(Function<Session, CompletableFuture<Status>> op) {
        return sessionRetryContext.supplyStatus(op).join();
    }

    /**
     * Runs {@code op} through the session retry context with a serializable read-write transaction
     * that the query commits itself.
     *
     * @throws RuntimeException if the returned {@link Result} is not successful
     */
    public <T> T executeInTmpSession(Function2<Session, TxControl, CompletableFuture<Result<T>>> op) {
        return getResult(sessionRetryContext
                .supplyResult(session -> op.apply(session, TxControl.serializableRw().setCommitTx(true))));
    }

    private <T> T getResult(CompletableFuture<Result<T>> cf) {
        return cf.join().expect("Cannot execute data query");
    }

    /**
     * Commits the transaction bound to the current thread, blocking until YDB answers.
     *
     * @throws RuntimeException (translated) if there is no current transaction, the commit status is
     *                          not successful, or the wait fails; the interrupt flag is restored if
     *                          the wait was interrupted
     */
    public void commit() {
        try {
            Status result = getCurrent()
                    .orElseThrow(() -> new IllegalStateException("No current transaction"))
                    .transaction.commit(timeoutSettings.getCommitTxSettingsWithTimeout())
                    .get();
            result.expect("Can't commit transaction");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw ExceptionUtils.translate(e);
        } catch (Exception e) {
            throw ExceptionUtils.translate(e);
        }
    }

    /**
     * Rolls back the transaction bound to the current thread, blocking until YDB answers.
     *
     * @throws RuntimeException (translated) if there is no current transaction, the rollback status is
     *                          not successful, or the wait fails; the interrupt flag is restored if
     *                          the wait was interrupted
     */
    public void rollback() {
        try {
            getCurrent()
                    .orElseThrow(() -> new IllegalStateException("No current transaction"))
                    .transaction.rollback()
                    .get()
                    .expect("Can't rollback transaction");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw ExceptionUtils.translate(e);
        } catch (Exception e) {
            throw ExceptionUtils.translate(e);
        }
    }

    /** Rolls back after {@code cause}, recording a rollback failure as suppressed so {@code cause} is not masked. */
    private void rollbackAfterFailure(Throwable cause) {
        try {
            rollback();
        } catch (RuntimeException e) {
            // addSuppressed rejects self-suppression.
            if (e != cause) {
                cause.addSuppressed(e);
            }
        }
    }

    /** Throws if a transaction is already bound to the current thread. */
    private void checkNoCurrent() {
        if (getCurrent().isPresent()) {
            throw new IllegalStateException("Transaction is already present");
        }
    }

    private void setCurrent(Session session, Transaction transaction) {
        checkNoCurrent();
        current.set(new TransactionData(session, transaction));
    }

    private Option<TransactionData> getCurrent() {
        return Option.ofNullable(current.get());
    }

    private void clearCurrent() {
        // remove() drops the entry entirely instead of leaving a null-valued one on pooled threads.
        current.remove();
    }

    public TableClient getTableClientForTests() {
        return tableClient;
    }

    /** Session and transaction bound to a thread by {@link #executeInTx}. */
    @AllArgsConstructor
    @Data
    private static class TransactionData {
        private final Session session;
        private final Transaction transaction;
    }
}
