package ru.yandex.travel.tx.utils;

import java.sql.SQLException;
import java.sql.Savepoint;
import java.util.function.Predicate;
import java.util.function.Supplier;

import javax.persistence.EntityManager;
import javax.sql.DataSource;

import com.google.common.base.Preconditions;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.datasource.ConnectionHolder;
import org.springframework.orm.jpa.JpaTransactionManager;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionSynchronizationManager;

@Slf4j
public class SavepointHelper {
    private final JpaTransactionManager transactionManager;
    private final EntityManager entityManager;

    public SavepointHelper(JpaTransactionManager transactionManager, EntityManager entityManager) {
        this.transactionManager = transactionManager;
        this.entityManager = entityManager;
    }

    @Transactional(propagation = Propagation.MANDATORY)
    public <T> T execute(Supplier<T> operation, Predicate<T> rollbackPredicate) {
        // we're using raw access to underlying connectionHolder to use savepoints as they're not
        // supported by jpa transaction manager
        DataSource underlyingDatatsource = ((JpaTransactionManager) transactionManager).getDataSource();
        Preconditions.checkNotNull(underlyingDatatsource, "Transaction manager datasource must be present");
        ConnectionHolder connectionHolder =
                (ConnectionHolder) TransactionSynchronizationManager.getResource(underlyingDatatsource);

        Preconditions.checkNotNull(connectionHolder, "Connection holder must be bound to current transaction");

        Savepoint savepoint;
        try {
            savepoint = connectionHolder.getConnection().setSavepoint("PROCESS_EVENT");
        } catch (SQLException e) {
            log.error("Error setting savepoint", e);
            throw new RuntimeException("Error setting savepoint", e);
        }

        Preconditions.checkNotNull(savepoint, "Savepoint must be initialized");

        T result = operation.get();

        if (rollbackPredicate.test(result)) {
            try {
                connectionHolder.getConnection().rollback(savepoint);
                entityManager.clear();
            } catch (SQLException e) {
                log.error("Error rolling back to savepoint", e);
                throw new RuntimeException("Error rolling back to savepoint", e);
            }
        }
        return result;
    }
}
