package ru.yandex.direct.oneshot.oneshots.bsexport.ordertype

import org.jooq.util.mysql.MySQLDSL
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import ru.yandex.direct.dbschema.ppc.Tables.CAMP_ORDER_TYPES
import ru.yandex.direct.dbutil.sharding.ShardHelper
import ru.yandex.direct.dbutil.sharding.ShardKey
import ru.yandex.direct.dbutil.wrapper.DslContextProvider
import ru.yandex.direct.jooqmapperhelper.InsertHelper
import ru.yandex.direct.oneshot.base.SimpleYtOneshot
import ru.yandex.direct.oneshot.base.YtInputData
import ru.yandex.direct.oneshot.base.YtState
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.Retries
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtField
import ru.yandex.direct.ytwrapper.model.YtOperator
import ru.yandex.direct.ytwrapper.model.YtTable
import ru.yandex.direct.ytwrapper.model.YtTableRow

/**
 * A single import item read from the input YT table.
 *
 * @property cid value written to `CAMP_ORDER_TYPES.CID`; also used as the shard key
 * @property orderType value written (widened to `Long`) to `CAMP_ORDER_TYPES.ORDER_TYPE`
 */
private data class ImportRequest(val cid: Long, val orderType: Int)

/**
 * YT row schema for the import table: declares the `cid` and `order_type` columns
 * and exposes them as non-null typed properties.
 */
private class ImportRequestRow : YtTableRow(listOf(CID, ORDER_TYPE)) {

    /** Value of the `cid` column; fails with [IllegalStateException] if the column is absent. */
    val cid: Long
        get() = checkNotNull(valueOf(CID)) { "cid must be present" }

    /** Value of the `order_type` column; fails with [IllegalStateException] if the column is absent. */
    val orderType: Int
        get() = checkNotNull(valueOf(ORDER_TYPE)) { "order_type must be present" }

    /** Copies the current row values into an [ImportRequest]. */
    fun toRequest(): ImportRequest = ImportRequest(cid = cid, orderType = orderType)

    companion object {
        private val CID = YtField("cid", Long::class.java)
        private val ORDER_TYPE = YtField("order_type", Int::class.java)
    }
}

/**
 * Oneshot that copies `(cid, order_type)` pairs from a YT table into `CAMP_ORDER_TYPES`.
 *
 * The table is processed iteratively: every [execute] call handles at most
 * [READ_CHUNK_SIZE] rows and returns a [YtState] carrying the offset of the next row to read.
 * Rows are written with `INSERT ... ON DUPLICATE KEY UPDATE`, so existing entries for a `cid`
 * get their `order_type` overwritten.
 */
@Component
@Approvers("mspirit", "ppalex", "pema4")
@Multilaunch
@Retries(5)
@PausedStatusOnFail
class OrderTypeImportOneshot(
    ytProvider: YtProvider,
    private val dslContextProvider: DslContextProvider,
    private val shardHelper: ShardHelper,
) : SimpleYtOneshot<YtInputData, YtState?>(ytProvider) {

    /**
     * Runs one iteration of the import.
     *
     * - When [prevState] is `null` (first call) nothing is imported: only the table row count is
     *   read and a state pointing at row 0 is returned.
     * - Otherwise up to [READ_CHUNK_SIZE] rows starting at `prevState.nextRow` are read and written
     *   to the database.
     *
     * @param inputData YT cluster name and table path to import from
     * @param prevState state returned by the previous iteration, or `null` on the first one
     * @return state for the next iteration, or `null` once a chunk shorter than [READ_CHUNK_SIZE]
     *   was read (presumably this is how the oneshot framework is told the work is done)
     */
    override fun execute(
        inputData: YtInputData,
        prevState: YtState?,
    ): YtState? {
        val ytCluster = YtCluster.parse(inputData.ytCluster)
        val ytTable = YtTable(inputData.tablePath)
        val ytOperator = ytProvider.getOperator(ytCluster)

        if (prevState == null) {
            logger.info("First iteration started")
            return createYtState(nextRow = 0, totalRowCount = ytOperator.readTableRowCount(ytTable))
        }

        val rows = readRows(ytOperator, ytTable, from = prevState.nextRow)
        doImport(rows)

        // Offset reached after this chunk; used both for progress logging and as the next start row.
        val nextRow = prevState.nextRow + rows.size

        // A short (or empty) chunk means the end of the table has been reached.
        return if (rows.size != READ_CHUNK_SIZE) {
            logger.info("Last iteration finished")
            null
        } else {
            // Report the offset after the processed chunk, not the offset the chunk started at,
            // otherwise the logged progress lags one chunk behind the real one.
            logger.info("Iteration finished, rows processed so far: $nextRow, total rows: ${prevState.totalRowCount}")
            createYtState(
                nextRow = nextRow,
                totalRowCount = prevState.totalRowCount,
            )
        }
    }

    /**
     * Reads one chunk of rows from [ytTable] starting at row [from].
     *
     * The requested range is `from` to `from + READ_CHUNK_SIZE`
     * (NOTE(review): the upper bound is assumed to be exclusive — confirm against
     * `YtOperator.readTableByRowRange`).
     *
     * @return the rows read, converted to [ImportRequest]s
     */
    private fun readRows(
        ytOperator: YtOperator,
        ytTable: YtTable,
        from: Long,
    ): List<ImportRequest> {
        val requests = mutableListOf<ImportRequest>()
        val to = from + READ_CHUNK_SIZE
        ytOperator.readTableByRowRange(ytTable, { requests += it.toRequest() }, ImportRequestRow(), from, to)

        logger.info("Got ${requests.size} requests from $ytTable")
        return requests
    }

    /**
     * Writes [rows] to `CAMP_ORDER_TYPES`.
     *
     * Rows are grouped by shard (resolved from `cid`) and split into batches of
     * [UPDATE_CHUNK_SIZE]; each batch is sent as a single
     * `INSERT ... ON DUPLICATE KEY UPDATE order_type = VALUES(order_type)` statement.
     */
    private fun doImport(rows: List<ImportRequest>) {
        shardHelper
            .groupByShard(rows, ShardKey.CID, ImportRequest::cid)
            .chunkedBy(UPDATE_CHUNK_SIZE)
            .forEach { shard, rowsChunk ->
                InsertHelper(dslContextProvider.ppc(shard), CAMP_ORDER_TYPES)
                    .apply {
                        for (row in rowsChunk) {
                            set(CAMP_ORDER_TYPES.CID, row.cid)
                            set(CAMP_ORDER_TYPES.ORDER_TYPE, row.orderType.toLong())
                            newRecord()
                        }
                    }
                    .onDuplicateKeyUpdate()
                    // On key collision, overwrite the stored order type with the incoming one.
                    .set(CAMP_ORDER_TYPES.ORDER_TYPE, MySQLDSL.values(CAMP_ORDER_TYPES.ORDER_TYPE))
                    .execute()
                    .also { logger.info("Updated $it rows") }
            }
    }

    companion object {
        private val logger = LoggerFactory.getLogger(OrderTypeImportOneshot::class.java)

        /** Maximum number of YT rows read (and imported) per [execute] iteration. */
        private const val READ_CHUNK_SIZE = 100_000

        /** Maximum number of rows written by a single insert statement. */
        private const val UPDATE_CHUNK_SIZE = 1_000
    }
}

/**
 * Builds a fresh [YtState] holding the given progress values.
 *
 * @param nextRow offset of the first row the next iteration should read
 * @param totalRowCount total number of rows in the input table
 */
private fun createYtState(nextRow: Long, totalRowCount: Long): YtState {
    val stateWithOffset = YtState().withNextRowFromYtTable(nextRow)
    return stateWithOffset.withTotalRowCount(totalRowCount)
}
