package ru.yandex.direct.oneshot.oneshots.fill_bids_phraseid_associate

import java.math.BigInteger
import java.time.LocalDateTime
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Component
import ru.yandex.direct.common.db.PpcPropertiesSupport
import ru.yandex.direct.common.db.PpcPropertyNames
import ru.yandex.direct.dbutil.sharding.ShardHelper
import ru.yandex.direct.dbutil.sharding.ShardKey
import ru.yandex.direct.oneshot.oneshots.fill_bids_phraseid_associate.repository.OneshotFillBidsPhraseIdAssociateRepository
import ru.yandex.direct.oneshot.worker.def.Approvers
import ru.yandex.direct.oneshot.worker.def.Multilaunch
import ru.yandex.direct.oneshot.worker.def.PausedStatusOnFail
import ru.yandex.direct.oneshot.worker.def.Retries
import ru.yandex.direct.oneshot.worker.def.SimpleOneshot
import ru.yandex.direct.validation.builder.Constraint
import ru.yandex.direct.validation.builder.When
import ru.yandex.direct.validation.constraint.CommonConstraints
import ru.yandex.direct.validation.constraint.NumberConstraints
import ru.yandex.direct.validation.defect.CommonDefects
import ru.yandex.direct.validation.util.listProperty
import ru.yandex.direct.validation.util.property
import ru.yandex.direct.validation.util.validateObject
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtTable

/**
 * Input parameters of the oneshot.
 *
 * @property ytCluster YT cluster on which the source table is looked up and read
 * @property tablePath path of the source YT table on [ytCluster]
 * @property shards shards to insert into; `null` or an empty list means "all shards"
 */
data class InputData(
    val ytCluster: YtCluster,
    val tablePath: String,

    // Insert only into the shards from this collection, or into all shards if it is null or empty
    val shards: List<Int>?,
)

/**
 * Progress carried between iterations of the oneshot.
 *
 * @property lastRow exclusive upper bound of the row range handled by the previous iteration;
 * the next iteration starts reading from this row
 * @property insertCountByShard running total of inserted rows, keyed by shard number
 */
data class State(
    var lastRow: Long,
    var insertCountByShard: MutableMap<Int, Int>,
)

/**
 * A single record read from the source YT table and inserted into bids_phraseid_associate.
 *
 * NOTE(review): the properties presumably map to the YT columns cid, pid, bids_id, PhraseID
 * and log_time respectively — the mapping lives in the repository, confirm there.
 *
 * @property campaignId campaign id; used as the sharding key when rows are grouped by shard
 */
data class BidsPhraseIdAssociate(
    val campaignId: Long,
    val adGroupId: Long,
    val keywordId: Long,
    val bsPhraseID: BigInteger,
    val logTime: LocalDateTime,
)

/**
 * Oneshot that inserts data from the given YT table into the bids_phraseid_associate table.
 *
 * The YT table must contain the following columns: cid, pid, bids_id, PhraseID, log_time
 *
 * The table is processed in chunks of [chunkSize] rows, one chunk per [execute] call;
 * progress between calls is carried in [State].
 */
@Component
@Multilaunch
@Approvers("mspirit", "dimitrovsd", "khuzinazat", "a-dubov", "ppalex", "maxlog", "gerdler")
@Retries(5)
@PausedStatusOnFail
class FillBidsPhraseIdAssociateOneshot : SimpleOneshot<InputData, State?> {
    companion object {
        private val logger = LoggerFactory.getLogger(FillBidsPhraseIdAssociateOneshot::class.java)

        // Number of YT rows requested per iteration unless overridden through the test constructor
        private const val DEFAULT_CHUNK_SIZE = 2_000L

        // Pause between iterations used when the relax-time property yields no value
        private const val DEFAULT_RELAX_TIME_BETWEEN_ITERATIONS_IN_SEC = 5L
    }

    private val ytProvider: YtProvider
    private val oneshotFillBidsPhraseIdAssociateRepository: OneshotFillBidsPhraseIdAssociateRepository
    private val ppcPropertiesSupport: PpcPropertiesSupport
    private val shardHelper: ShardHelper
    private val chunkSize: Long

    /**
     * Production constructor: uses [DEFAULT_CHUNK_SIZE] as the chunk size.
     */
    @Autowired
    constructor(
        ytProvider: YtProvider,
        oneshotFillBidsPhraseIdAssociateRepository: OneshotFillBidsPhraseIdAssociateRepository,
        ppcPropertiesSupport: PpcPropertiesSupport,
        shardHelper: ShardHelper,
    ) : this(
        ytProvider,
        oneshotFillBidsPhraseIdAssociateRepository,
        ppcPropertiesSupport,
        shardHelper,
        DEFAULT_CHUNK_SIZE,
    )

    /**
     * Used in tests: allows overriding the chunk size.
     */
    constructor(
        ytProvider: YtProvider,
        oneshotFillBidsPhraseIdAssociateRepository: OneshotFillBidsPhraseIdAssociateRepository,
        ppcPropertiesSupport: PpcPropertiesSupport,
        shardHelper: ShardHelper,
        chunkSize: Long,
    ) {
        this.ytProvider = ytProvider
        this.oneshotFillBidsPhraseIdAssociateRepository = oneshotFillBidsPhraseIdAssociateRepository
        this.ppcPropertiesSupport = ppcPropertiesSupport
        this.shardHelper = shardHelper
        this.chunkSize = chunkSize
    }

    /**
     * Validates the oneshot input.
     *
     * - `ytCluster` and `tablePath` must not be null;
     * - the table at `tablePath` must exist on `ytCluster`, otherwise an "object not found"
     *   defect is reported (the check is attached with `When.isValid()`);
     * - every element of `shards` must not be null and must satisfy
     *   `NumberConstraints.inRange(1, <number of db shards>)`.
     */
    override fun validate(inputData: InputData) = validateObject(inputData) {
        property(inputData::ytCluster) {
            check(CommonConstraints.notNull())
        }
        property(inputData::tablePath) {
            check(CommonConstraints.notNull())
            check(
                Constraint.fromPredicate(
                    { path -> ytProvider.getOperator(inputData.ytCluster).exists(YtTable(path)) },
                    CommonDefects.objectNotFound()
                ),
                When.isValid()
            )
        }
        listProperty(inputData::shards) {
            checkEach(CommonConstraints.notNull())
            checkEach(NumberConstraints.inRange(1, shardHelper.dbShards().size), When.isValid())
        }
    }

    /**
     * Processes one chunk of the YT table.
     *
     * Reads rows `[prevState.lastRow, prevState.lastRow + chunkSize)` (starting from row 0 when
     * [prevState] is null), groups them by shard, inserts them and accumulates the per-shard
     * insert counters.
     *
     * @return `null` when the chunk contained fewer than [chunkSize] rows (the work is done),
     * otherwise the state for the next iteration, returned after sleeping for the relax time
     */
    override fun execute(
        inputData: InputData,
        prevState: State?
    ): State? {
        val state = prevState ?: State(0, mutableMapOf())

        val startRow = state.lastRow
        val lastRow = startRow + chunkSize
        logger.info("Iteration starts with rows [$startRow, $lastRow)")

        // Resolved before any insert so that a failure here cannot happen after rows were written.
        // Clamped to zero: Thread.sleep throws on a negative value, which would fail the iteration
        // after the chunk had already been inserted but before the new state was returned.
        val relaxTimeBetweenIterations = ppcPropertiesSupport
            .get(PpcPropertyNames.INSERT_TO_BIDS_PHRASEID_ASSOCIATE_RELAX_TIME)
            .getOrDefault(DEFAULT_RELAX_TIME_BETWEEN_ITERATIONS_IN_SEC)
            .coerceAtLeast(0L)

        val bidsPhraseIdAssociatesChunk = oneshotFillBidsPhraseIdAssociateRepository
            .getBidsPhraseIdAssociatesFromYtTable(inputData.ytCluster, inputData.tablePath, startRow, lastRow)

        if (bidsPhraseIdAssociatesChunk.isNotEmpty()) {
            val insertCountByShard = state.insertCountByShard
            getBidsPhraseIdAssociateByShard(inputData.shards?.toSet(), bidsPhraseIdAssociatesChunk)
                .forEach { (shard, bidsPhraseIdAssociates) ->
                    val amountOfInserted = insertBidsPhraseIdAssociates(shard, bidsPhraseIdAssociates)
                    insertCountByShard[shard] = (insertCountByShard[shard] ?: 0) + amountOfInserted
                }
        }

        // A short chunk means the end of the table has been reached
        if (bidsPhraseIdAssociatesChunk.size < chunkSize) {
            logger.info("Work completed! Total inserts ${state.insertCountByShard}")
            return null
        }

        logger.info("Iteration finished, sleep for $relaxTimeBetweenIterations seconds")
        Thread.sleep(relaxTimeBetweenIterations * 1000)
        return State(lastRow, state.insertCountByShard)
    }

    /**
     * Inserts the given rows into one shard and logs how many were received and inserted.
     *
     * @return the number of inserted rows as reported by the repository
     */
    private fun insertBidsPhraseIdAssociates(
        shard: Int,
        bidsPhraseIdAssociates: List<BidsPhraseIdAssociate>,
    ): Int {
        val amountOfInserted = oneshotFillBidsPhraseIdAssociateRepository
            .insertBidsPhraseIdAssociates(shard, bidsPhraseIdAssociates)
        logger.info("shard=$shard: got ${bidsPhraseIdAssociates.size} rows, $amountOfInserted were inserted")
        return amountOfInserted
    }

    /**
     * Groups rows by the shard of their campaign id.
     *
     * @param shards shards to keep in the result; `null` or empty keeps every shard
     * @param bidsPhraseIdAssociates rows to distribute
     * @return rows keyed by shard number, restricted to [shards] when it is not empty
     */
    private fun getBidsPhraseIdAssociateByShard(
        shards: Set<Int>?,
        bidsPhraseIdAssociates: Collection<BidsPhraseIdAssociate>,
    ): Map<Int, List<BidsPhraseIdAssociate>> {
        val bidsPhraseIdAssociateByCid = bidsPhraseIdAssociates.groupBy { it.campaignId }

        return shardHelper.groupByShard(bidsPhraseIdAssociateByCid.keys, ShardKey.CID)
            .shardedDataMap
            .filter { shards.isNullOrEmpty() || shards.contains(it.key) }
            .mapValues { entry ->
                // Every cid comes from the keys of bidsPhraseIdAssociateByCid, so getValue cannot miss
                entry.value.flatMap { cid -> bidsPhraseIdAssociateByCid.getValue(cid) }
            }
    }
}

