package ru.yandex.direct.oneshot.oneshots.bsexport.autobudget.restart.import

import org.slf4j.LoggerFactory.getLogger
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Component
import ru.yandex.direct.autobudget.restart.service.Reason
import ru.yandex.direct.core.entity.campaign.repository.CampaignRepository
import ru.yandex.direct.dbschema.ppc.Tables.CAMP_AUTOBUDGET_RESTART
import ru.yandex.direct.dbutil.sharding.ShardHelper
import ru.yandex.direct.dbutil.sharding.ShardKey
import ru.yandex.direct.dbutil.wrapper.DslContextProvider
import ru.yandex.direct.jooqmapper.JooqMapperUtils.makeCaseStatement
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.SimpleOneshot
import ru.yandex.direct.validation.builder.Constraint
import ru.yandex.direct.validation.constraint.CollectionConstraints
import ru.yandex.direct.validation.constraint.CommonConstraints
import ru.yandex.direct.validation.constraint.NumberConstraints
import ru.yandex.direct.validation.constraint.StringConstraints
import ru.yandex.direct.validation.defect.CommonDefects
import ru.yandex.direct.validation.util.listProperty
import ru.yandex.direct.validation.util.property
import ru.yandex.direct.validation.util.validateObject
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtTable
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneId

/**
 * Input parameters of the autobudget restart import oneshot.
 *
 * Orders are selected either by an explicit [orderIdList] or by a range of `OrderID % 1000000`
 * values ([orderIdMicroPercentFrom]..[orderIdMicroPercentTo]); validation requires the range
 * fields to be null when the list is given and non-null otherwise.
 *
 * @property ytCluster YT cluster the source table is read from; required by validation.
 * @property tablePath path of the source table on [ytCluster]; required by validation, which
 * also checks its format and that the table exists.
 * @property orderIdList explicit order ids to import; when non-null it must be non-empty and
 * contain only positive ids.
 * @property orderIdMicroPercentFrom lower bound (inclusive) for `OrderID % 1000000`,
 * validated to be in `0..999999`.
 * @property orderIdMicroPercentTo upper bound (inclusive) for `OrderID % 1000000`,
 * validated to be in `0..999999`.
 * @property borderTime only records whose current `RESTART_TIME` is strictly earlier than this
 * moment are overwritten.
 */
data class Param(
    val ytCluster: YtCluster?,
    val tablePath: String?,

    val orderIdList: List<Long>?,
    val orderIdMicroPercentFrom: Long?,
    val orderIdMicroPercentTo: Long?,

    val borderTime: LocalDateTime,
)

/**
 * Progress marker passed between iterations of the oneshot.
 *
 * @property lastOrderId `OrderID` of the last row read by the previous iteration; the next
 * iteration selects only rows with a strictly greater `OrderID`.
 */
data class State(
    val lastOrderId: Long
)

/**
 * One row read from the source YT table.
 *
 * @property orderId value of the `OrderID` column.
 * @property restartTime `StartTime` column (epoch seconds) converted to a local date-time.
 * @property softRestartTime `LastUpdateTime` column (epoch seconds) converted to a local date-time.
 */
data class TableRow(
    val orderId: Long,
    val restartTime: LocalDateTime,
    val softRestartTime: LocalDateTime,
)

/**
 * Treats the receiver as a number of seconds since the Unix epoch and converts it
 * to a [LocalDateTime] in the JVM's default time zone.
 */
private fun Long.epochToLocalDateTime(): LocalDateTime {
    val moment = Instant.ofEpochSecond(this)
    return LocalDateTime.ofInstant(moment, ZoneId.systemDefault())
}


/**
 * Oneshot that copies autobudget restart times from a YT table into `CAMP_AUTOBUDGET_RESTART`.
 *
 * Every [execute] call reads up to [READ_CHUNK_SIZE] rows ordered by `OrderID` (continuing after
 * the order id stored in the previous [State]), resolves order ids to campaign ids per shard and
 * overwrites `RESTART_TIME`, `SOFT_RESTART_TIME` and `RESTART_REASON` for records whose current
 * `RESTART_TIME` is earlier than [Param.borderTime].
 */
@Component
@Approvers("zhur", "hmepas", "lena-san", "mspirit")
@Multilaunch
@PausedStatusOnFail
class AutobudgetRestartImportOneshot @Autowired constructor(
    private val dsl: DslContextProvider,
    private val campaignRepository: CampaignRepository,
    private val ytProvider: YtProvider,
    private val shardHelper: ShardHelper,
) : SimpleOneshot<Param, State?> {

    companion object {
        private val logger = getLogger(AutobudgetRestartImportOneshot::class.java)

        /** Maximum number of rows requested from YT per iteration. */
        private const val READ_CHUNK_SIZE = 100_000

        /** Maximum number of rows handled by a single UPDATE statement. */
        private const val UPDATE_CHUNK_SIZE = 1_000
    }

    /**
     * Validates the input parameters:
     * - `ytCluster` and `tablePath` must be set, the path must match the expected format and
     *   the table must exist on the cluster;
     * - either `orderIdList` is given (non-empty, positive ids, no micro-percent bounds), or both
     *   micro-percent bounds are given and lie in `0..999999`.
     */
    override fun validate(inputData: Param) =
        validateObject(inputData) {
            property(inputData::ytCluster)
                .check(CommonConstraints.notNull())
            property(inputData::tablePath) {
                check(CommonConstraints.notNull())
                check(StringConstraints.matchPattern("^//[a-zA-Z0-9_-]+(/[a-zA-Z0-9_-]+)+$"))
                // Existence can only be checked when we know which cluster to ask.
                if (inputData.ytCluster != null) {
                    check(
                        Constraint.fromPredicate(
                            { ytProvider.getOperator(inputData.ytCluster).exists(YtTable(it)) },
                            CommonDefects.objectNotFound()
                        )
                    )
                }
            }
            property(inputData::borderTime)
                .check(CommonConstraints.notNull())
            if (inputData.orderIdList != null) {
                // An explicit id list and a micro-percent range are mutually exclusive.
                property(inputData::orderIdMicroPercentFrom)
                    .check(CommonConstraints.isNull())
                property(inputData::orderIdMicroPercentTo)
                    .check(CommonConstraints.isNull())
                listProperty(inputData::orderIdList)
                    .check(CollectionConstraints.notEmptyCollection())
                    .checkEach(NumberConstraints.greaterThan(0L))
            } else {
                // Bounds are compared against `OrderID % 1000000`, hence the 0..999999 range.
                property(inputData::orderIdMicroPercentFrom)
                    .check(CommonConstraints.notNull())
                    .check(NumberConstraints.notLessThan(0L))
                    .check(NumberConstraints.lessThan(1_000_000L))
                property(inputData::orderIdMicroPercentTo)
                    .check(CommonConstraints.notNull())
                    .check(NumberConstraints.notLessThan(0L))
                    .check(NumberConstraints.lessThan(1_000_000L))
            }
        }

    /**
     * Processes one chunk of rows.
     *
     * @param inputData validated oneshot parameters
     * @param prevState state returned by the previous iteration, or `null` on the first one
     * @return `null` when fewer than [READ_CHUNK_SIZE] rows were read (nothing more to fetch),
     * otherwise a [State] holding the `OrderID` of the last processed row
     */
    override fun execute(inputData: Param, prevState: State?): State? {
        val lastOrderId = prevState?.lastOrderId ?: 0L
        logger.info("Start from orderId=$lastOrderId")

        val rows = getRowsFromYt(inputData, lastOrderId)

        updateRestartTimes(rows, inputData)

        // A full chunk means there may be more rows; `rows` is non-empty in that branch.
        return if (rows.size < READ_CHUNK_SIZE) null else State(rows.last().orderId)
    }


    /**
     * Writes restart times from [rows] into `CAMP_AUTOBUDGET_RESTART`, shard by shard, in chunks
     * of at most [UPDATE_CHUNK_SIZE] rows. Rows whose order id cannot be resolved to a campaign
     * id are ignored; records with `RESTART_TIME` not earlier than [Param.borderTime] are left
     * untouched.
     */
    private fun updateRestartTimes(
        rows: List<TableRow>,
        inputData: Param
    ) {
        shardHelper.groupByShard(rows, ShardKey.ORDER_ID, TableRow::orderId)
            .chunkedBy(UPDATE_CHUNK_SIZE)
            .forEach { shard, chunk ->
                logger.info("Start process chunk with ${chunk.size} rows for shard ${shard}")
                val oid2cid = campaignRepository.getCidsForOrderIds(shard, chunk.map { it.orderId })

                // Only rows with a known campaign can be applied.
                val todo = chunk.filter { it.orderId in oid2cid }
                if (todo.isEmpty()) {
                    // Nothing to update: do not build an UPDATE with empty CASE mappings and an
                    // empty IN list, which would be a wasted round trip at best.
                    logger.info("No campaigns found for chunk on shard ${shard}, skip update")
                } else {
                    val t = CAMP_AUTOBUDGET_RESTART.`as`("abr")
                    dsl.ppc(shard).update(t)
                        .set(
                            t.RESTART_TIME,
                            makeCaseStatement(
                                t.CID,
                                t.RESTART_TIME,
                                todo.associate { oid2cid[it.orderId] to it.restartTime }
                            )
                        )
                        .set(
                            t.SOFT_RESTART_TIME,
                            makeCaseStatement(
                                t.CID,
                                t.SOFT_RESTART_TIME,
                                todo.associate { oid2cid[it.orderId] to it.softRestartTime }
                            )
                        )
                        .set(
                            t.RESTART_REASON,
                            Reason.BS_RESTART.name
                        )
                        .where(t.CID.`in`(todo.map { oid2cid[it.orderId] }))
                        .and(t.RESTART_TIME.lessThan(inputData.borderTime))
                        .execute()
                        .also { logger.info("affected $it rows") }
                }
            }
    }


    /**
     * Reads the next chunk of rows from the YT table.
     *
     * Selects rows with `OrderID` greater than [lastOrderId], positive `StartTime` and
     * `LastUpdateTime`, matching the order filter from [inputData], ordered by `OrderID` and
     * limited to [READ_CHUNK_SIZE].
     */
    private fun getRowsFromYt(inputData: Param, lastOrderId: Long): List<TableRow> {
        val orderIdConditionSql = if (inputData.orderIdList != null) {
            "OrderID in (${inputData.orderIdList.joinToString(",")})"
        } else {
            // Validation guarantees both bounds are set when no explicit id list is given.
            """OrderID % 1000000 >= ${inputData.orderIdMicroPercentFrom!!}
                | AND OrderID % 1000000 <= ${inputData.orderIdMicroPercentTo!!}""".trimMargin()
        }
        // `ytCluster` and `tablePath` are checked for null in `validate`.
        return ytProvider.getDynamicOperator(inputData.ytCluster!!)
            .selectRows(
                """
                            OrderID, StartTime, LastUpdateTime
                            FROM [${inputData.tablePath!!.replace("]", "\\]")}]
                            WHERE OrderID > ${lastOrderId}
                            AND StartTime > 0
                            AND LastUpdateTime > 0
                            AND ${orderIdConditionSql}
                            ORDER BY OrderID
                            LIMIT ${READ_CHUNK_SIZE}
                            """
            )
            .yTreeRows
            .map {
                TableRow(
                    orderId = it.getLong("OrderID"),
                    restartTime = it.getLong("StartTime").epochToLocalDateTime(),
                    softRestartTime = it.getLong("LastUpdateTime").epochToLocalDateTime(),
                )
            }
    }
}
