package ru.yandex.direct.oneshot.oneshots

import com.google.gson.JsonDeserializationContext
import com.google.gson.JsonDeserializer
import com.google.gson.JsonElement
import org.slf4j.LoggerFactory
import ru.yandex.direct.core.entity.campaign.model.BaseCampaign
import ru.yandex.direct.core.entity.campaign.repository.CampaignTypedRepository
import ru.yandex.direct.oneshot.oneshots.CampaignMigrationBaseOneshot.Companion.State
import ru.yandex.direct.oneshot.util.GsonUtils
import ru.yandex.direct.oneshot.worker.def.ShardedOneshot
import ru.yandex.direct.validation.result.Defect
import ru.yandex.direct.validation.result.ValidationResult
import java.lang.reflect.Type

/**
 * Base for sharded oneshots that process campaigns batch by batch.
 *
 * Each [execute] call handles at most one batch: it asks [campaignTypedRepository] for campaigns of
 * the classes returned by [classes] whose id is greater than the last remembered id, hands them to
 * [process], and returns a new [State] describing where to continue.
 *
 * NOTE(review): presumably the framework feeds the returned state back as `prevState` of the next
 * call and stops once `null` is returned — confirm against [ShardedOneshot].
 */
abstract class CampaignMigrationBaseOneshot<Input>(
    val campaignTypedRepository: CampaignTypedRepository
) : ShardedOneshot<Input, State> {

    private val log = LoggerFactory.getLogger(javaClass)

    /** Accepts any input: always returns a successful validation result wrapping [inputData]. */
    override fun validate(inputData: Input): ValidationResult<Input, Defect<Any>> {
        return ValidationResult.success(inputData)
    }

    /**
     * Performs one step of the migration on [shard].
     *
     * - no previous state: nothing is processed, the start position (campaign id 0) is returned;
     * - [Finished]: logs completion and returns `null`;
     * - [LastPosition]: runs one batch starting after the stored position and returns the new state.
     */
    override fun execute(inputData: Input, prevState: State?, shard: Int): State? {
        // Very first call: only record the starting point.
        val state = prevState ?: return LastPosition(Position(0))

        return when (state) {
            Finished -> {
                log.info("Finish last iteration on shard $shard")
                null
            }
            is LastPosition -> {
                log.info("Start iteration at position ${state.position} on shard $shard")
                iteration(state.position, shard)
            }
        }
    }

    /**
     * Fetches and processes a single batch of campaigns with id greater than [position]'s id.
     *
     * @return [LastPosition] pointing at the largest campaign id of the batch,
     * or [Finished] when the repository returned no campaigns.
     */
    private fun iteration(position: Position, shard: Int): State {
        // batchLimit() is resolved before classes(), same as the argument list order requires below.
        val limit = batchLimit()
        val batch = campaignTypedRepository.getSafelyCampaignsForClassesAndIdGreaterThan(
            shard, position.lastCampaignId, classes(), limit
        )

        log.info("Selected [${batch.size}] campaigns to process on shard [$shard]")
        process(shard, batch)
        log.info("There are [${batch.size}] campaigns processed on shard [$shard]")

        // An empty batch means there is nothing left past the current position.
        val campaignWithMaxId = batch.maxByOrNull { campaign -> campaign.id } ?: return Finished
        return LastPosition(Position(campaignWithMaxId.id))
    }

    /** Campaign model classes requested from the repository on every iteration. */
    protected abstract fun classes(): Set<Class<out BaseCampaign>>

    /** Handles one fetched batch of [campaigns] on the given [shard]. */
    protected abstract fun process(shard: Int, campaigns: List<BaseCampaign>)

    /** Limit passed to the repository for a single batch; defaults to [DEFAULT_BATCH_LIMIT]. */
    protected open fun batchLimit(): Int = DEFAULT_BATCH_LIMIT

    companion object {

        /** Id of the last campaign seen by the previous iteration. */
        data class Position(val lastCampaignId: Long)

        /** Progress marker exchanged between consecutive [execute] calls. */
        sealed interface State

        /** There may be more campaigns after [position]. */
        data class LastPosition(val position: Position) : State

        /** The last fetched batch was empty; the next [execute] call ends the oneshot. */
        object Finished : State

        // TODO: this is a somewhat hacky quick fix for now; it should be done properly
        //  as part of https://st.yandex-team.ru/DIRECT-154490
        /**
         * Gson deserializer for [State]: a JSON object with a `position` member is read as
         * [LastPosition], anything else is treated as [Finished].
         */
        class GsonStateDeserializer : JsonDeserializer<State> {
            override fun deserialize(json: JsonElement, typeOfT: Type, context: JsonDeserializationContext): State {
                val hasPosition = json.isJsonObject && json.asJsonObject.has("position")
                return if (hasPosition) {
                    GsonUtils.GSON.fromJson(json, LastPosition::class.java)
                } else {
                    Finished
                }
            }
        }

        private const val DEFAULT_BATCH_LIMIT = 10_000
    }
}
