package ru.yandex.direct.core.entity.uac.grut

import org.slf4j.LoggerFactory
import org.springframework.context.annotation.Lazy
import org.springframework.stereotype.Component
import ru.yandex.grut.client.GrutClient
import ru.yandex.grut.client.RetriableException
import ru.yandex.grut.proto.transaction_context.TransactionContext.TTransactionContext
import ru.yandex.yt.ytclient.rpc.RpcError

/**
 * Runs code inside GrUT transactions, optionally retrying the whole transaction on [RetriableException].
 *
 * While a transaction is active it is published through [GrutContext] (via `setTransaction`) so the code being
 * run can pick it up; the context is reset once the transaction finishes. Transactions must not be nested:
 * every entry point that opens a transaction throws [IllegalStateException] if [GrutContext.transactional]
 * is already true.
 */
@Component
@Lazy
class GrutTransactionProvider(
    private val grutContext: GrutContext,
    private val grutClient: GrutClient
) {

    /**
     * Runs [codeToRun] in a new transaction when [needTransaction] is true, otherwise runs it directly
     * (inside whatever context the caller already has).
     */
    fun <T> runInTransactionIfNeeded(needTransaction: Boolean, codeToRun: () -> T): T =
        if (needTransaction) runInTransaction(codeForTransaction = codeToRun) else codeToRun()

    /**
     * Runs [codeToRun] in a retryable transaction (see [runInRetryableTransaction]) when [needTransaction]
     * is true, otherwise runs it directly exactly once.
     *
     * @param retries total number of attempts, including the first one.
     */
    fun <T> runInRetryableTransactionIfNeeded(needTransaction: Boolean, retries: Int, codeToRun: () -> T): T =
        if (needTransaction) {
            runInRetryableTransaction(retries = retries, codeForTransaction = codeToRun)
        } else {
            codeToRun()
        }

    /**
     * Starts a transaction, runs [codeForTransaction] with the transaction published in [GrutContext],
     * and commits it with [transactionContext].
     *
     * On any failure (including a failed commit) the transaction is aborted; an abort failure is logged and
     * attached to the original exception as suppressed, and the original exception is rethrown.
     * The context is always reset afterwards.
     *
     * @throws IllegalStateException if a transaction is already active in [GrutContext].
     */
    fun <T> runInTransaction(
        transactionContext: TTransactionContext? = null,
        codeForTransaction: () -> T,
    ): T {
        checkNotNested()
        val grutTransaction = grutClient.startTransaction()
        try {
            grutContext.setTransaction(grutTransaction)
            val result = codeForTransaction()
            grutTransaction.commit(transactionContext)
            return result
        } catch (e: Throwable) {
            logger.warn("Operation failed", e)
            try {
                grutTransaction.abort()
            } catch (abortEx: Throwable) {
                logger.warn("Failed to abort transaction", abortEx)
                // Keep the abort failure visible to whoever handles the primary exception.
                e.addSuppressed(abortEx)
            }
            throw e
        } finally {
            grutContext.reset()
        }
    }

    /**
     * Same as the full [runInRetryableTransaction] overload with no pause between attempts and no revert action.
     *
     * @param retries total number of attempts, including the first one.
     */
    fun <T> runInRetryableTransaction(
        retries: Int,
        transactionContext: TTransactionContext? = null,
        codeForTransaction: () -> T
    ): T = runInRetryableTransaction(
        retries = retries,
        sleepBetweenTriesMs = 0,
        transactionContext = transactionContext,
        codeForRevert = DO_NOTHING,
        codeForTransaction = codeForTransaction,
    )

    /**
     * Runs [codeForTransaction] in a fresh transaction per attempt (see [runInTransaction]), retrying
     * the whole transaction when it fails with [RetriableException].
     *
     * @param retries total number of attempts, including the first one; values <= 1 mean a single attempt.
     * @param sleepBetweenTriesMs pause before each retry; values <= 0 disable the pause.
     * @param codeForRevert invoked once when giving up (non-retriable failure, attempts exhausted, or the
     *   thread was interrupted while waiting); its own failures are logged and ignored.
     * @throws IllegalStateException if a transaction is already active in [GrutContext] (checked before the
     *   first attempt, so [codeForRevert] is not run in that case).
     */
    fun <T> runInRetryableTransaction(
        retries: Int,
        sleepBetweenTriesMs: Long,
        transactionContext: TTransactionContext? = null,
        codeForRevert: () -> Unit,
        codeForTransaction: () -> T,
    ): T {
        checkNotNested()
        return runRetryable(retries, sleepBetweenTriesMs, codeForRevert) {
            runInTransaction(transactionContext = transactionContext, codeForTransaction = codeForTransaction)
        }
    }

    /**
     * Same as the full [runRetryable] overload with no pause between attempts and no revert action.
     *
     * @param retries total number of attempts, including the first one.
     */
    fun <T> runRetryable(retries: Int, codeForRun: () -> T): T =
        runRetryable(retries = retries, sleepBetweenTriesMs = 0, codeForRevert = DO_NOTHING, codeForRun = codeForRun)

    /**
     * Runs [codeForRun], retrying it while it fails with [RetriableException] and attempts remain.
     *
     * Any other exception, the last retriable failure, or an interruption during the pause between attempts
     * ends the loop: [codeForRevert] is invoked (its failures are only logged) and the causing exception is
     * rethrown.
     *
     * The [GrutContext] is reset after every attempt only when no transaction was active on entry; if this is
     * called from inside an existing transaction, that transaction belongs to the caller and is left intact.
     *
     * @param retries total number of attempts, including the first one; values <= 1 mean a single attempt.
     * @param sleepBetweenTriesMs pause before each retry; values <= 0 disable the pause.
     */
    fun <T> runRetryable(
        retries: Int,
        sleepBetweenTriesMs: Long,
        codeForRevert: () -> Unit,
        codeForRun: () -> T,
    ): T {
        // Resetting an enclosing caller's transaction would make the rest of its code silently
        // non-transactional, so only reset a context this call did not inherit.
        val ownsContext = !grutContext.transactional()
        var remainingRetries = retries
        while (true) {
            try {
                return codeForRun()
            } catch (e: Throwable) {
                remainingRetries--
                val retriable = e is RetriableException
                if (retriable && remainingRetries > 0) {
                    logger.warn("Catch retryable exception, there are $remainingRetries retries", e)
                    if (sleepAfterTry(sleepBetweenTriesMs)) {
                        continue
                    }
                    // The interrupt flag is restored by sleepAfterTry; honour the cancellation request.
                    logger.warn("Interrupted while waiting before the next try, giving up", e)
                } else if (retriable) {
                    logger.error("Last retry failed", e)
                }

                try {
                    codeForRevert()
                } catch (revertEx: Throwable) {
                    logger.error("Revert failed", revertEx)
                }
                throw e
            } finally {
                if (ownsContext) {
                    grutContext.reset()
                }
            }
        }
    }

    /** Throws [IllegalStateException] if a transaction is already active in [GrutContext]. */
    private fun checkNotNested() {
        check(!grutContext.transactional()) { "Grut transactions must not be nested" }
    }

    /**
     * Pauses for [timeMs] milliseconds before the next attempt (no-op when [timeMs] <= 0).
     *
     * @return `false` if the thread was interrupted while sleeping (the interrupt flag is restored),
     *   `true` otherwise.
     */
    private fun sleepAfterTry(timeMs: Long): Boolean {
        if (timeMs <= 0) {
            return true
        }

        logger.info("Sleeping for $timeMs after failed try")
        return try {
            Thread.sleep(timeMs)
            true
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
            false
        }
    }

    companion object {
        private val logger = LoggerFactory.getLogger(GrutTransactionProvider::class.java)

        /** Default revert action: does nothing. */
        private val DO_NOTHING: () -> Unit = {}
    }
}
