package ru.yandex.direct.oneshot.oneshots.minus_geo_to_flags

import org.jooq.Configuration
import org.jooq.DSLContext
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import ru.yandex.direct.common.db.PpcPropertiesSupport
import ru.yandex.direct.common.db.PpcPropertyNames
import ru.yandex.direct.core.entity.banner.model.BannerFlags
import ru.yandex.direct.core.entity.banner.model.FlagProperty
import ru.yandex.direct.core.entity.bs.resync.queue.model.BsResyncItem
import ru.yandex.direct.core.entity.bs.resync.queue.repository.BsResyncQueueRepository
import ru.yandex.direct.core.entity.moderation.service.receiving.operations.banners.UpdateBannerFlagsOp
import ru.yandex.direct.dbschema.ppc.Tables.BANNERS_MINUS_GEO
import ru.yandex.direct.dbschema.ppc.enums.BannersMinusGeoType
import ru.yandex.direct.dbschema.ppc.enums.BannersStatusmoderate
import ru.yandex.direct.dbschema.ppc.tables.Banners.BANNERS
import ru.yandex.direct.dbutil.sharding.ShardHelper
import ru.yandex.direct.dbutil.sharding.ShardKey
import ru.yandex.direct.dbutil.wrapper.DslContextProvider
import ru.yandex.direct.jooqmapper.JooqMapperUtils
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.Retries
import ru.yandex.direct.oneshot.worker.def.SimpleOneshot
import ru.yandex.direct.validation.builder.Constraint
import ru.yandex.direct.validation.constraint.CommonConstraints
import ru.yandex.direct.validation.defect.CommonDefects
import ru.yandex.direct.validation.util.property
import ru.yandex.direct.validation.util.validateObject
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtField
import ru.yandex.direct.ytwrapper.model.YtTable
import ru.yandex.direct.ytwrapper.model.YtTableRow

/**
 * Oneshot input: the YT cluster and the path of the input table that lists the banners to migrate.
 */
data class Param(val ytCluster: YtCluster, val tablePath: String)

/**
 * Progress carried between oneshot iterations: the input-table row the next chunk starts from.
 */
data class State(val lastRow: Long)

/**
 * Row of the input YT table; only the `DirectBannerId` column is declared and read.
 */
class InputTableRow : YtTableRow(listOf(BANNER_ID)) {
    /** Value of the `DirectBannerId` column of the current row (nullable, as returned by `valueOf`). */
    val bannerId: Long?
        get() = valueOf(BANNER_ID)

    companion object {
        private val BANNER_ID = YtField("DirectBannerId", Long::class.java)
    }
}

/**
 * Oneshot that copies minus-geo from `banners_minus_geo` (rows of type `current`) into banner flags
 * and re-sends the changed banners to BS through the lazy resync queue.
 *
 * The input YT table is processed in chunks whose size is taken from
 * [PpcPropertyNames.MINUSGEO_FLAGS_ONESHOT_CHUNK_SIZE]. After every non-empty chunk the oneshot
 * sleeps for [PpcPropertyNames.MINUSGEO_FLAGS_ONESHOT_RELAX_TIME_SEC] seconds.
 */
@Component
@Multilaunch
@PausedStatusOnFail
@Retries(5)
@Approvers("mspirit", "elwood", "volodskikh", "ppalex")
class MinusGeoToFlagsMigrationOneshot(
    private val ytProvider: YtProvider,
    private val shardHelper: ShardHelper,
    private val dslContextProvider: DslContextProvider,
    private val bsResyncQueueRepository: BsResyncQueueRepository,
    ppcPropertiesSupport: PpcPropertiesSupport,
) : SimpleOneshot<Param, State?> {

    private val chunkSizeProperty =
        ppcPropertiesSupport.get(PpcPropertyNames.MINUSGEO_FLAGS_ONESHOT_CHUNK_SIZE)

    private val relaxTimeProperty =
        ppcPropertiesSupport.get(PpcPropertyNames.MINUSGEO_FLAGS_ONESHOT_RELAX_TIME_SEC)

    companion object {
        private val logger = LoggerFactory.getLogger(MinusGeoToFlagsMigrationOneshot::class.java)

        private const val DEFAULT_CHUNK_SIZE = 10_000L
        private const val DEFAULT_RELAX_TIME_SEC = 60L

        // Number of banners updated by a single `UPDATE banners SET flags = CASE ...` statement
        private const val UPDATE_CHUNK_SIZE = 200

        // Banners in these statuses are re-sent to BS through the lazy queue.
        // Drafts and rejected banners are not re-sent.
        private val STATUSES_TO_RESYNC = setOf(
            BannersStatusmoderate.Yes, BannersStatusmoderate.Ready, BannersStatusmoderate.Sending,
            BannersStatusmoderate.Sent
        )
    }

    /**
     * Checks that the cluster and table path are set and that the table exists on that cluster.
     */
    override fun validate(inputData: Param) = validateObject(inputData) {
        property(inputData::ytCluster) {
            check(CommonConstraints.notNull())
        }
        property(inputData::tablePath) {
            check(CommonConstraints.notNull())
            check(
                Constraint.fromPredicate(
                    { checkIfInputTableExists(inputData.ytCluster, it) },
                    CommonDefects.objectNotFound()
                )
            )
        }
    }

    /**
     * Migrates one chunk of the input table: rows `[startRow, startRow + chunkSize)`.
     * `startRow` comes from [prevState] and is 0 on the first run.
     *
     * @return the state for the next chunk, or `null` when this chunk held fewer than `chunkSize` ids
     *   (the table is treated as exhausted)
     */
    override fun execute(inputData: Param, prevState: State?): State? {
        logger.info("Start from state=$prevState")
        val startRow = prevState?.lastRow ?: 0L
        val chunkSize = chunkSizeProperty.getOrDefault(DEFAULT_CHUNK_SIZE)
        val relaxTimeSec = relaxTimeProperty.getOrDefault(DEFAULT_RELAX_TIME_SEC)
        logger.info("Using chunkSize = $chunkSize, relaxTimeSec = $relaxTimeSec")
        val lastRow = startRow + chunkSize
        val bannerIdsChunk = readBannerIdsFromYtTable(inputData.ytCluster, inputData.tablePath, startRow, lastRow)
        val groupByShard = shardHelper.groupByShard(bannerIdsChunk, ShardKey.BID)
        groupByShard.forEach { shard, bannerIds ->
            logger.info("Processing shard $shard")
            migrateShard(shard, bannerIds)
        }

        if (bannerIdsChunk.isNotEmpty()) {
            logger.info("Iteration finished, sleep for $relaxTimeSec seconds")
            Thread.sleep(relaxTimeSec * 1000)
        } else {
            logger.info("No handled objects on iteration")
        }

        return if (bannerIdsChunk.size == chunkSize.toInt()) State(lastRow) else null
    }

    /**
     * Does all the work for [bannerIds] on [shard] in a single transaction.
     *
     * It locks the banner rows, copies the minus-geo onto the flags, and then puts the changed
     * banners that are in a [STATUSES_TO_RESYNC] status into the lazy BS resync queue.
     */
    private fun migrateShard(shard: Int, bannerIds: List<Long>) {
        dslContextProvider.ppc(shard).transaction { config: Configuration ->
            val dslContext = config.dsl()

            // SELECT ... FOR UPDATE locks the banners for the rest of the transaction.
            // The value type is nullable: presumably BannerFlags.fromSource returns null for a NULL
            // `flags` column (TODO confirm).
            val bannersFlagsMap: Map<Long, BannerFlags?> = dslContext.select(
                BANNERS.BID, BANNERS.FLAGS,
            ).from(BANNERS)
                .where(BANNERS.BID.`in`(bannerIds))
                .forUpdate()
                .fetchMap(BANNERS.BID) { r ->
                    BannerFlags.fromSource(r.get(BANNERS.FLAGS))
                }

            val bannersMinusGeoMap = dslContext.select(
                BANNERS_MINUS_GEO.BID, BANNERS_MINUS_GEO.MINUS_GEO
            )
                .from(BANNERS_MINUS_GEO)
                .where(BANNERS_MINUS_GEO.BID.`in`(bannerIds))
                .and(BANNERS_MINUS_GEO.TYPE.eq(BannersMinusGeoType.current))
                .fetchMap(BANNERS_MINUS_GEO.BID, BANNERS_MINUS_GEO.MINUS_GEO)

            val bannersStatusModerateMap = dslContext.select(
                BANNERS.BID, BANNERS.STATUS_MODERATE
            )
                .from(BANNERS)
                .where(BANNERS.BID.`in`(bannerIds))
                .fetchMap(BANNERS.BID, BANNERS.STATUS_MODERATE)

            val bannerFlagsToSaveMap = collectFlagsToSave(bannersFlagsMap, bannersMinusGeoMap)
            updateBannersFlags(dslContext, bannerFlagsToSaveMap)

            // Only banners whose flags were actually rewritten AND which are in a resync status
            val bidsInResyncStatus = bannersStatusModerateMap
                .filterValues { STATUSES_TO_RESYNC.contains(it) }
                .keys
            val bannerIdsToResync = bannerFlagsToSaveMap.keys.filter { it in bidsInResyncStatus }
            addToResyncQueue(dslContext, bannerIdsToResync)
        }
    }

    /**
     * Builds the new flags for every existing banner (present in [bannersFlagsMap]) that has a
     * non-empty minus-geo in [bannersMinusGeoMap].
     *
     * The minus-geo value is a comma-separated list of numeric minus regions. All minus-region flags
     * are removed and then set again from that list; other flags are kept. The [BannerFlags] objects
     * taken from [bannersFlagsMap] are modified in place.
     *
     * @return new flags by banner id. A banner whose flags were null and would stay empty is omitted,
     *   so a NULL `flags` column is not turned into an empty string.
     */
    private fun collectFlagsToSave(
        bannersFlagsMap: Map<Long, BannerFlags?>,
        bannersMinusGeoMap: Map<Long, String?>,
    ): Map<Long, BannerFlags> {
        val bannerFlagsToSaveMap = HashMap<Long, BannerFlags>()

        for ((bid, minusGeo) in bannersMinusGeoMap) {
            if (!bannersFlagsMap.containsKey(bid)) {
                // skip banners that don't exist
                continue
            }
            if (minusGeo.isNullOrEmpty()) {
                continue
            }
            val minusRegions = minusGeo.split(",")
                .map { it.trim() }
                .filter { it.isNotEmpty() }
                .map { it.toLong() }
            val convertedFlags = convertMinusRegionsToFlags(minusRegions)

            val originalFlags = bannersFlagsMap[bid]
            val flags = originalFlags ?: BannerFlags()
            removeAllMinusRegionFlags(flags)
            for (flag in convertedFlags) {
                flags.with(flag, true)
            }
            // Don't change NULL to '' if there is nothing to store.
            // (The previous check `!containsKey(bid)` could never be true at this point.)
            if (flags.flags.isEmpty() && originalFlags == null) {
                continue
            }
            bannerFlagsToSaveMap[bid] = flags
        }
        return bannerFlagsToSaveMap
    }

    /**
     * Writes [flagsToSave] into `banners.flags`. It issues one CASE-based UPDATE per
     * [UPDATE_CHUNK_SIZE] banners.
     */
    private fun updateBannersFlags(dslContext: DSLContext, flagsToSave: Map<Long, BannerFlags>) {
        for (chunk in flagsToSave.entries.chunked(UPDATE_CHUNK_SIZE)) {
            val flagsByBidMap = chunk.associateBy({ it.key }, { BannerFlags.toSource(it.value) })
            dslContext.update(BANNERS)
                .set(
                    BANNERS.FLAGS, JooqMapperUtils.makeCaseStatement(
                        BANNERS.BID,
                        BANNERS.FLAGS,
                        flagsByBidMap
                    )
                )
                .where(BANNERS.BID.`in`(flagsByBidMap.keys))
                .execute()
        }
    }

    /**
     * Puts [bannerIds] into the lazy BS resync queue within the current transaction.
     * If [bannerIds] is empty, it only logs and does not touch the queue.
     */
    private fun addToResyncQueue(dslContext: DSLContext, bannerIds: List<Long>) {
        if (bannerIds.isEmpty()) {
            logger.info("No bids in statuses to resync in this chunk")
            return
        }

        val bannersCidPidMap = dslContext.select(
            BANNERS.BID, BANNERS.PID, BANNERS.CID,
        ).from(BANNERS)
            .where(BANNERS.BID.`in`(bannerIds))
            .fetchMap(BANNERS.BID) { r -> Pair(r.get(BANNERS.CID), r.get(BANNERS.PID)) }

        val bsResyncItems = bannerIds.mapNotNull { bid ->
            val cid = bannersCidPidMap[bid]?.first
            val pid = bannersCidPidMap[bid]?.second
            if (cid == null || pid == null) {
                // The banner is no longer in the DB. That should not happen, because the rows
                // were locked with SELECT ... FOR UPDATE in this transaction; skip it defensively.
                null
            } else {
                // NOTE(review): 89 is presumably the resync priority — confirm against BsResyncItem
                BsResyncItem(89, cid, bid, pid)
            }
        }

        val bidsToResync = bsResyncItems.map { it.bannerId }
        logger.info("Add ${bidsToResync.size} bids to resync queue: $bidsToResync")
        bsResyncQueueRepository.addToResync(dslContext, bsResyncItems)
    }

    private fun checkIfInputTableExists(ytCluster: YtCluster, tablePath: String): Boolean {
        return ytProvider.getOperator(ytCluster).exists(YtTable(tablePath))
    }

    /**
     * Maps minus regions to their flags using [UpdateBannerFlagsOp.MINUS_REGION_FLAGS].
     * Regions missing from that map are logged and skipped.
     */
    private fun convertMinusRegionsToFlags(minusRegions: List<Long>): Set<FlagProperty<Boolean>> =
        minusRegions.mapNotNullTo(HashSet<FlagProperty<Boolean>>()) { minusRegion ->
            val flag = UpdateBannerFlagsOp.MINUS_REGION_FLAGS[minusRegion]
            if (flag == null) {
                logger.warn("Unknown minus region value = {}", minusRegion)
            }
            flag
        }

    /** Removes every minus-region flag (RU, KZ, UA, RB, TR, UZ) from [bannerFlags] in place. */
    private fun removeAllMinusRegionFlags(bannerFlags: BannerFlags) {
        bannerFlags.remove(BannerFlags.MINUS_REGION_RU)
        bannerFlags.remove(BannerFlags.MINUS_REGION_KZ)
        bannerFlags.remove(BannerFlags.MINUS_REGION_UA)
        bannerFlags.remove(BannerFlags.MINUS_REGION_RB)
        bannerFlags.remove(BannerFlags.MINUS_REGION_TR)
        bannerFlags.remove(BannerFlags.MINUS_REGION_UZ)
    }

    /**
     * Reads `DirectBannerId` values of rows [startRow]..[lastRow] of the input table.
     * The exact end semantics are defined by `readTableByRowRange` (presumably exclusive).
     *
     * @throws IllegalArgumentException if a row has no banner id
     */
    private fun readBannerIdsFromYtTable(
        ytCluster: YtCluster,
        tablePath: String,
        startRow: Long,
        lastRow: Long
    ): List<Long> {
        val bannerIds = mutableListOf<Long>()
        ytProvider.getOperator(ytCluster)
            .readTableByRowRange(
                YtTable(tablePath),
                { row ->
                    bannerIds.add(requireNotNull(row.bannerId) { "DirectBannerId is null in table $tablePath" })
                },
                InputTableRow(), startRow, lastRow
            )
        return bannerIds
    }
}
