package ru.yandex.chemodan.util.jdbc;

import org.springframework.context.ApplicationContext;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import ru.yandex.commune.db.shard2.Shard2;
import ru.yandex.misc.lang.Check;

/**
 * @author tolmalev
 */
/**
 * Test execution listener that starts a transaction before every test method and rolls it
 * back after the method, using the transaction manager of the {@code defaultShard} bean.
 *
 * <p>The active {@link TransactionStatus} is stored as an attribute of the per-test
 * {@link TestContext} instead of a field of this listener. This prevents state from one
 * test method from leaking into the next. It also keeps tests that share this listener
 * instance from overwriting each other's status.
 *
 * @author tolmalev
 */
public class TransactionForEachTestMethodCreator extends AbstractTestExecutionListener {

    /** Key under which the current test method's transaction status is kept in the test context. */
    private static final String TRANSACTION_STATUS_ATTRIBUTE =
            TransactionForEachTestMethodCreator.class.getName() + ".TRANSACTION_STATUS";

    /**
     * Starts a new transaction with default settings and records its status.
     * The recorded status is read by {@link #afterTestMethod(TestContext)}.
     *
     * @param testContext context of the test method about to run
     * @throws Exception if the transaction cannot be started
     */
    @Override
    public void beforeTestMethod(TestContext testContext) throws Exception {
        TransactionStatus status = getTransactionManager(testContext)
                .getTransaction(new DefaultTransactionDefinition());
        Check.notNull(status, "TransactionStatus");
        // Only record the status once the transaction was actually obtained, so
        // afterTestMethod can tell whether there is anything to roll back.
        testContext.setAttribute(TRANSACTION_STATUS_ATTRIBUTE, status);
    }

    /**
     * Rolls back the transaction started in {@link #beforeTestMethod(TestContext)}.
     *
     * <p>If no transaction was recorded for this test method, the method does nothing.
     * That happens, for example, when {@code beforeTestMethod} failed. Skipping the
     * rollback keeps an unrelated error from masking the original failure.
     *
     * @param testContext context of the test method that just ran
     * @throws Exception if the rollback fails
     */
    @Override
    public void afterTestMethod(TestContext testContext) throws Exception {
        // Remove before rolling back so a stale status can never be reused,
        // even if the rollback itself throws.
        TransactionStatus status =
                (TransactionStatus) testContext.removeAttribute(TRANSACTION_STATUS_ATTRIBUTE);
        if (status == null) {
            return;
        }
        getTransactionManager(testContext).rollback(status);
    }

    /**
     * Looks up the {@code defaultShard} bean ({@link Shard2}) in the test's application
     * context and returns its transaction manager.
     *
     * @param testContext context providing the application context
     * @return the transaction manager of the {@code defaultShard} bean
     */
    protected final PlatformTransactionManager getTransactionManager(TestContext testContext) {
        ApplicationContext applicationContext = testContext.getApplicationContext();

        Shard2 shard = applicationContext.getBean("defaultShard", Shard2.class);

        return shard.getTransactionManager();
    }
}
