package ru.yandex.direct.oneshot.oneshots.package_strategy_migration

import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import ru.yandex.direct.common.db.PpcPropertiesSupport
import ru.yandex.direct.common.db.PpcPropertyNames
import ru.yandex.direct.core.entity.campaign.model.CampaignWithPackageStrategy
import ru.yandex.direct.core.entity.campaign.repository.CampaignTypedRepository
import ru.yandex.direct.dbutil.wrapper.DslContextProvider
import ru.yandex.direct.oneshot.base.YtState
import ru.yandex.direct.oneshot.oneshots.package_strategy_migration.PackageStrategyAutobudgetRestartMigrationOneshot.Companion.InputData
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.Retries
import ru.yandex.direct.oneshot.worker.def.ShardedOneshot
import ru.yandex.direct.validation.result.Defect
import ru.yandex.direct.validation.result.ValidationResult
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtField
import ru.yandex.direct.ytwrapper.model.YtOperator
import ru.yandex.direct.ytwrapper.model.YtTable
import ru.yandex.direct.ytwrapper.model.YtTableRow
import kotlin.math.min

/**
 * Oneshot that reads campaign ids from the `cid` column of a YT table and, shard by shard, passes the campaigns
 * that have a non-zero package strategy id to [StrategyAutobudgetMigrationService.migrate].
 *
 * The table is consumed in row ranges. The size of each range comes from the
 * `PACKAGE_STRATEGY_AUTOBUDGET_RESTART_MIGRATION_BATCH_SIZE` ppc property, defaulting to [DEFAULT_CHUNK_SIZE].
 * Progress between launches is carried in [YtState], which holds the next row to read and the total row count.
 * After every chunk that actually triggered a migration, the oneshot sleeps for
 * `PACKAGE_STRATEGY_AUTOBUDGET_RESTART_MIGRATION_RELAX_TIME` seconds (default 5).
 */
@Component
@Approvers("ssdmitriev", "ruslansd", "ninazhevtyak", "kuvshinov")
@Multilaunch
@Retries(5)
@PausedStatusOnFail
class PackageStrategyAutobudgetRestartMigrationOneshot(
    private val ytProvider: YtProvider,
    private val campaignTypedRepository: CampaignTypedRepository,
    private val dslContextProvider: DslContextProvider,
    private val strategyAutobudgetMigrationService: StrategyAutobudgetMigrationService,
    private val ppcPropertiesSupport: PpcPropertiesSupport
) : ShardedOneshot<InputData, YtState> {
    companion object {
        /** Rows read per iteration when the batch-size property is unset or not positive. */
        private const val DEFAULT_CHUNK_SIZE = 500L

        private val logger = LoggerFactory.getLogger(PackageStrategyAutobudgetRestartMigrationOneshot::class.java)

        /** YT row mapping for the input table: a single `cid` column holding a campaign id. */
        class Row : YtTableRow(listOf(CID)) {
            companion object {
                private val CID = YtField("cid", Long::class.java)
            }

            val cid: Long get() = valueOf(CID).toLong()
        }

        /**
         * Launch parameters of the oneshot.
         *
         * @property ytCluster YT cluster name; lower-cased before being parsed with [YtCluster.parse].
         * @property tablePath path of the input table that contains the `cid` column.
         * @property rewriteOnDuplicate forwarded unchanged to [StrategyAutobudgetMigrationService.migrate].
         *   NOTE(review): the exact duplicate-handling semantics live in that service — confirm there.
         */
        data class InputData(
            override val ytCluster: String,
            override val tablePath: String,
            val rewriteOnDuplicate: Boolean = false
        ) : BaseInputData
    }

    /**
     * Validates [inputData] with [InputDataValidator].
     *
     * The validator is given a check that the table exists on the requested cluster.
     */
    override fun validate(inputData: InputData): ValidationResult<InputData, Defect<*>> {
        fun tableExists(tablePath: String): Boolean {
            // NOTE(review): YtCluster.parse may reject an unknown cluster name. This presumes
            // InputDataValidator validates the cluster before invoking this check — confirm.
            val cluster = YtCluster.parse(inputData.ytCluster.lowercase())
            return ytProvider.getOperator(cluster).exists(YtTable(tablePath))
        }

        val validator = InputDataValidator<InputData> { tableExists(it) }
        return validator.apply(inputData)
    }

    /**
     * Processes the next chunk of the input table for [shard].
     *
     * On the first launch ([prevState] is `null`), the total row count of the table is read and iteration
     * starts from row 0.
     *
     * @return state for the next launch, or `null` once every row of the table has been processed.
     */
    override fun execute(
        inputData: InputData,
        prevState: YtState?,
        shard: Int
    ): YtState? {
        val ytCluster = YtCluster.parse(inputData.ytCluster.lowercase())
        val ytTable = YtTable(inputData.tablePath)
        val ytOperator: YtOperator = ytProvider.getOperator(ytCluster)
        val chunkSize = readChunkSize()
        val currentState = prevState ?: firstIteration(shard, ytOperator.readTableRowCount(ytTable))

        return process(
            state = currentState,
            ytOperator = ytOperator,
            ytTable = ytTable,
            rewriteOnDuplicate = inputData.rewriteOnDuplicate,
            chunkSize = chunkSize,
            shard = shard
        )
    }

    /**
     * Reads the chunk size from ppc properties, falling back to [DEFAULT_CHUNK_SIZE] when it is not positive.
     *
     * A zero size would make [process] return an identical state forever, so the oneshot would never finish.
     * A negative size would move `nextRow` backwards.
     */
    private fun readChunkSize(): Long {
        val configured: Long =
            ppcPropertiesSupport.get(PpcPropertyNames.PACKAGE_STRATEGY_AUTOBUDGET_RESTART_MIGRATION_BATCH_SIZE)
                .getOrDefault(DEFAULT_CHUNK_SIZE)
        if (configured > 0) {
            return configured
        }
        logger.warn("Non-positive chunk size [{}] configured, using default [{}]", configured, DEFAULT_CHUNK_SIZE)
        return DEFAULT_CHUNK_SIZE
    }

    /**
     * Reads rows `[state.nextRow, endRow)` from [ytTable] and migrates the campaigns they reference.
     *
     * @param chunkSize maximum number of rows to read in this iteration; expected to be positive.
     * @return the state pointing at the first unread row, or `null` when nothing is left to read.
     */
    private fun process(
        state: YtState,
        ytOperator: YtOperator,
        ytTable: YtTable,
        rewriteOnDuplicate: Boolean,
        chunkSize: Long,
        shard: Int
    ): YtState? {
        val rowCount: Long = state.totalRowCount
        val startRow: Long = state.nextRow
        if (startRow >= rowCount) {
            logger.info("Last iteration, last processed row: {}, total rows: {}", startRow, rowCount)
            return null
        }
        // rowCount - startRow is positive here. Adding the smaller step can never overflow,
        // unlike startRow + chunkSize for a huge configured chunk size.
        val endRow = startRow + min(chunkSize, rowCount - startRow)

        val campaignIds = mutableListOf<Long>()
        ytOperator.readTableByRowRange(
            ytTable,
            { campaignIds.add(it.cid) },
            Row(),
            startRow,
            endRow
        )
        logger.info("Ready to process {} rows, from [{}] to [{}]", campaignIds.size, startRow, endRow)
        // The table may repeat a cid; the set removes duplicates before the DB lookup.
        processChunk(shard, campaignIds.toSet(), rewriteOnDuplicate)
        logger.info("Processed {} rows, from [{}] to [{}]", campaignIds.size, startRow, endRow)
        return YtState()
            .withNextRowFromYtTable(endRow)
            .withTotalRowCount(rowCount)
    }

    /**
     * Migrates the campaigns among [campaignIds] that have a non-zero strategy id.
     *
     * After a migration, the oneshot pauses for the configured relax time to limit load. The pause is
     * clamped at 0 seconds, because `Thread.sleep` rejects negative values.
     */
    private fun processChunk(
        shard: Int,
        campaignIds: Set<Long>,
        rewriteOnDuplicate: Boolean
    ) {
        if (campaignIds.isEmpty()) {
            return
        }
        val cidToStrategyId = cidToStrategyId(shard, campaignIds)
        if (cidToStrategyId.isEmpty()) {
            return
        }
        val relaxTimeInSeconds =
            ppcPropertiesSupport.get(PpcPropertyNames.PACKAGE_STRATEGY_AUTOBUDGET_RESTART_MIGRATION_RELAX_TIME)
                .getOrDefault(5)
                .coerceAtLeast(0)
        strategyAutobudgetMigrationService.migrate(shard, cidToStrategyId, rewriteOnDuplicate)
        logger.info("Chunk processed")
        logger.info("Sleep for [{}] seconds", relaxTimeInSeconds)
        Thread.sleep(relaxTimeInSeconds * 1000L)
    }

    /**
     * Loads the [CampaignWithPackageStrategy] view of [campaignIds] from [shard].
     *
     * @return a map from campaign id to strategy id. Campaigns whose strategy id is `null` or `0` are left out.
     */
    private fun cidToStrategyId(shard: Int, campaignIds: Set<Long>): Map<Long, Long> =
        campaignTypedRepository.getSafely(
            dslContextProvider.ppc(shard),
            campaignIds,
            CampaignWithPackageStrategy::class.java
        ).mapNotNull { campaign ->
            campaign.strategyId
                ?.takeIf { it != 0L }
                ?.let { strategyId -> campaign.id to strategyId }
        }.toMap()

    /** Builds the initial state for [shard]: start at row 0 of a table with [totalRow] rows. */
    private fun firstIteration(shard: Int, totalRow: Long): YtState {
        logger.info("First iteration for shard [$shard]")
        return YtState()
            .withNextRowFromYtTable(0L)
            .withTotalRowCount(totalRow)
    }
}
