package ru.yandex.travel.spring.tx;

import java.util.HashMap;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.springframework.dao.support.PersistenceExceptionTranslator;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionException;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.TransactionSystemException;

/**
 * Wraps a {@link PlatformTransactionManager} so that commits can be globally paused.
 * <p>
 * While commits are paused, {@link #commitTransaction} marks the transaction rollback-only before handing it to
 * the underlying manager's {@code commit}, which then rolls it back instead of committing it.
 * <p>
 * Transactions are grouped into numbered epochs. Each transaction is attached to the epoch that was current when
 * it was opened via {@link #getTransaction}. {@link #pauseCommits()} flags the current epoch as rollback-only.
 * {@link #resumeCommits()} opens a fresh, committable epoch. Transactions opened in a paused epoch therefore keep
 * being rolled back even after commits have been resumed.
 * <p>
 * The wrapper starts with commits paused: epoch 0 is created rollback-only, so nothing is committed until
 * {@link #resumeCommits()} has been called at least once.
 * <p>
 * Epoch bookkeeping is guarded by an internal monitor. Calls into the underlying transaction manager are made
 * outside of that monitor.
 */
@Slf4j
public class ForcedRollbackTxManagerWrapper {

    private final PlatformTransactionManager transactionManager;
    @Getter
    private final PersistenceExceptionTranslator exceptionTranslator;

    // Guards epochs and currentEpoch.
    private final Object monitor = new Object();
    // Epoch number -> bookkeeping for that epoch. Only the current epoch and older epochs that still have open
    // transactions are kept.
    private final HashMap<Integer, EpochState> epochs;
    private int currentEpoch = 0;

    /**
     * Creates a wrapper whose initial epoch is rollback-only, i.e. commits start out paused.
     *
     * @param transactionManager  manager that actually opens, commits and rolls back transactions
     * @param exceptionTranslator translator exposed to callers through {@code getExceptionTranslator()}
     */
    public ForcedRollbackTxManagerWrapper(PlatformTransactionManager transactionManager,
                                          PersistenceExceptionTranslator exceptionTranslator) {
        epochs = new HashMap<>();
        epochs.put(currentEpoch, new EpochState(0, true));
        this.transactionManager = transactionManager;
        this.exceptionTranslator = exceptionTranslator;
    }

    /**
     * Opens a transaction through the underlying manager and registers it in the current epoch.
     *
     * @param transactionDefinition definition passed unchanged to the underlying manager
     * @return the underlying status paired with the epoch it was registered in. It is expected to be completed
     *         exactly once, via {@link #commitTransaction} or {@link #rollbackOnException}.
     */
    public TransactionStatusWrapper getTransaction(TransactionDefinition transactionDefinition) {
        TransactionStatus status = this.transactionManager.getTransaction(transactionDefinition);
        synchronized (monitor) {
            EpochState state = epochs.get(currentEpoch);
            state.txCount += 1;
            return new TransactionStatusWrapper(status, currentEpoch);
        }
    }

    /**
     * Unregisters the transaction from its epoch and rolls it back after an application exception.
     * <p>
     * If the rollback itself fails, the rollback exception is rethrown. When it is a
     * {@link TransactionSystemException}, the original application exception is first attached to it.
     *
     * @param statusWrapper wrapper previously returned by {@link #getTransaction}
     * @param ex            application exception that triggered the rollback
     * @throws TransactionException if the underlying rollback fails
     */
    public void rollbackOnException(TransactionStatusWrapper statusWrapper, Throwable ex) throws TransactionException {
        synchronized (monitor) {
            decrementTxCount(statusWrapper.getEpoch());
        }
        log.debug("Initiating transaction rollback on application exception", ex);
        try {
            this.transactionManager.rollback(statusWrapper.getTransactionStatus());
        } catch (TransactionSystemException ex2) {
            log.error("Application exception overridden by rollback exception", ex);
            ex2.initApplicationException(ex);
            throw ex2;
        } catch (RuntimeException | Error ex2) {
            log.error("Application exception overridden by rollback exception", ex);
            throw ex2;
        }
    }

    /**
     * Unregisters the transaction from its epoch and hands it to the underlying manager's {@code commit}.
     * <p>
     * If the transaction's epoch is paused, the status is first marked rollback-only.
     * <p>
     * The pause check runs under the monitor, but the commit itself runs outside it. A
     * {@link #pauseCommits()} call that happens after the check does not affect this commit.
     *
     * @param statusWrapper wrapper previously returned by {@link #getTransaction}
     */
    public void commitTransaction(TransactionStatusWrapper statusWrapper) {
        synchronized (monitor) {
            EpochState state = decrementTxCount(statusWrapper.getEpoch());
            if (state.isRollbackOnly()) {
                log.debug("Marking current transaction as rollbackOnly");
                statusWrapper.getTransactionStatus().setRollbackOnly();
            }
        }
        this.transactionManager.commit(statusWrapper.getTransactionStatus());
    }


    /**
     * Flags the current epoch as rollback-only.
     * <p>
     * From then on, every transaction registered in it is marked rollback-only when
     * {@link #commitTransaction} is called for it.
     */
    public void pauseCommits() {
        synchronized (monitor) {
            EpochState epochState = epochs.get(currentEpoch);
            epochState.setRollbackOnly(true);
        }
    }

    /**
     * Starts a new, committable epoch if the current one is paused; otherwise does nothing.
     * <p>
     * Transactions already registered in the paused epoch stay there and will still be marked rollback-only.
     */
    public void resumeCommits() {
        synchronized (monitor) {
            EpochState currentEpochState = epochs.get(currentEpoch);
            if (currentEpochState.isRollbackOnly()) {
                int pausedEpoch = currentEpoch;
                currentEpoch = currentEpoch + 1;
                epochs.put(currentEpoch, new EpochState(0, false));
                // decrementTxCount only evicts an old epoch when one of its transactions completes. An epoch
                // with no open transactions would never be evicted, so drop it here. Otherwise every
                // pause/resume cycle without in-flight transactions would leak one map entry.
                if (currentEpochState.txCount == 0) {
                    epochs.remove(pausedEpoch);
                }
            }
            log.debug("Resuming commits. Current epoch {} is new {}", currentEpoch, currentEpochState.isRollbackOnly());
        }
    }

    /**
     * Decrements the open-transaction count of the given epoch.
     * <p>
     * When the epoch is no longer current and its count drops to zero, the epoch is evicted. Must be called
     * while holding {@code monitor}.
     *
     * @param epoch epoch the completed transaction was registered in
     * @return the epoch's state; still valid even if the epoch was just evicted
     * @throws IllegalStateException if the epoch is unknown, e.g. because the transaction was already completed
     */
    private EpochState decrementTxCount(int epoch) {
        EpochState state = epochs.get(epoch);
        if (state == null) {
            throw new IllegalStateException("Unknown transaction epoch " + epoch
                    + "; the transaction has probably already been committed or rolled back");
        }
        state.txCount -= 1;
        if (state.txCount == 0 && epoch < currentEpoch) {
            epochs.remove(epoch);
        }
        return state;
    }

    /** Mutable bookkeeping for one epoch. Accessed only under {@code monitor}. */
    @Data
    @AllArgsConstructor
    private static class EpochState {
        // Number of transactions registered in this epoch that have not yet been committed or rolled back.
        private int txCount;
        // When true, transactions of this epoch are marked rollback-only on commit.
        private boolean rollbackOnly;
    }

    /** Immutable pair of an underlying transaction status and the epoch it was registered in. */
    @Value
    public static class TransactionStatusWrapper {
        private TransactionStatus transactionStatus;
        private int epoch;
    }

}
