package ru.yandex.direct.oneshot.oneshots.mysql2grut

import org.slf4j.LoggerFactory
import ru.yandex.direct.validation.builder.Constraint
import ru.yandex.direct.validation.builder.When
import ru.yandex.direct.validation.constraint.CommonConstraints
import ru.yandex.direct.validation.defect.CommonDefects
import ru.yandex.direct.validation.util.property
import ru.yandex.direct.validation.util.validateObject
import ru.yandex.direct.ytwrapper.client.YtProvider
import ru.yandex.direct.ytwrapper.model.YtCluster
import ru.yandex.direct.ytwrapper.model.YtField
import ru.yandex.direct.ytwrapper.model.YtTable
import ru.yandex.direct.ytwrapper.model.YtTableRow
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Location of an input YT table.
 *
 * @property ytCluster cluster whose operator is used to access the table
 * @property tablePath path of the table on that cluster
 */
data class Param(
    val ytCluster: YtCluster,
    val tablePath: String,
)

/**
 * Location of an input YT table together with an optional shard column.
 *
 * @property shardColumnName name of a table column; may be `null`, in which case
 *   validation skips the shard-column check
 * @property ytCluster cluster whose operator is used to access the table
 * @property tablePath path of the table on that cluster
 */
data class ShardedParam(
    val shardColumnName: String?,
    val ytCluster: YtCluster,
    val tablePath: String,
)

/**
 * Validation helpers for oneshot parameters that point at an input YT table.
 *
 * Every check goes through [ytProvider], so validating a parameter issues calls to YT.
 */
class Mysql2GrutYtUtils(private val ytProvider: YtProvider) {

    /**
     * Validates a sharded input parameter.
     *
     * The table path is checked exactly as in the [Param] overload (table exists, schema
     * contains at least one of the [InputTableRow.ALL_COLUMNS] names). In addition, when
     * [ShardedParam.shardColumnName] is not `null`, the shard column is probed with
     * [checkShardColumnName].
     *
     * The table checks are applied on this object's own `tablePath` property so that their
     * defects end up in the returned result; calling the [Param] overload from inside the
     * builder would silently drop its result.
     *
     * @return validation result for [inputData]
     */
    fun validate(inputData: ShardedParam) =
        validateObject(inputData) {
            property(inputData::tablePath) {
                check(CommonConstraints.notNull())
                check(
                    Constraint.fromPredicate(
                        { checkIfInputTableExists(inputData.ytCluster, it) },
                        CommonDefects.objectNotFound()
                    )
                )
                // No point asking for the schema of a table that has just been reported missing.
                check(
                    Constraint.fromPredicate(
                        { tableHasValidFields(inputData.ytCluster, it) },
                        CommonDefects.absentRequiredField()
                    ),
                    When.isValid()
                )
            }
            property(inputData::shardColumnName) {
                check(
                    Constraint.fromPredicate({ checkShardColumnName(inputData) }, CommonDefects.inconsistentState()),
                    When.isValidAnd(When.notNull())
                )
            }
        }

    /**
     * Validates an input parameter: the table must exist on the given cluster and its schema
     * must contain at least one of the [InputTableRow.ALL_COLUMNS] names.
     *
     * @return validation result for [inputData]
     */
    fun validate(inputData: Param) =
        validateObject(inputData) {
            property(inputData::tablePath) {
                check(CommonConstraints.notNull())
                check(
                    Constraint.fromPredicate(
                        { checkIfInputTableExists(inputData.ytCluster, it) },
                        CommonDefects.objectNotFound()
                    )
                )
                // No point asking for the schema of a table that has just been reported missing.
                check(
                    Constraint.fromPredicate(
                        { tableHasValidFields(inputData.ytCluster, it) },
                        CommonDefects.absentRequiredField()
                    ),
                    When.isValid()
                )
            }
        }

    /**
     * Returns `true` when the schema of [tablePath] contains at least one column named like
     * one of [InputTableRow.ALL_COLUMNS]; logs an error and returns `false` otherwise.
     */
    private fun tableHasValidFields(ytCluster: YtCluster, tablePath: String): Boolean {
        val tableColumns = ytProvider.getOperator(ytCluster).getSchema(YtTable(tablePath))
            .map { it.getString("name") }
            .toSet()

        val validFields = InputTableRow.ALL_COLUMNS.map { it.name }
        val hasValidField = tableColumns.intersect(validFields).isNotEmpty()

        if (!hasValidField) {
            logger.error("Input table $tablePath must have one of fields: $validFields")
        }
        return hasValidField
    }

    /** Returns whether [tablePath] exists on [ytCluster]. */
    private fun checkIfInputTableExists(ytCluster: YtCluster, tablePath: String): Boolean =
        ytProvider.getOperator(ytCluster).exists(YtTable(tablePath))

    /**
     * Probes the shard column of the table described by [param].
     *
     * Reads the row range `0, 1` of the table (presumably only the first row — TODO confirm
     * how `readTableByRowRange` treats the upper bound) requesting just the shard column, and
     * reports success when the value seen last is an integer node greater than zero.
     *
     * Returns `false` when the column name is `null`, when no row was delivered, or when
     * reading fails; a read failure is logged rather than propagated.
     */
    private fun checkShardColumnName(param: ShardedParam): Boolean {
        val shardColumnName = param.shardColumnName ?: return false
        val ytTableRow = YtTableRow(listOf(YtField(shardColumnName, Long::class.java)))
        val success = AtomicBoolean(false)
        try {
            ytProvider.getOperator(param.ytCluster).readTableByRowRange(YtTable(param.tablePath), { row ->
                val column = row.data.get(shardColumnName)
                val isPositiveInteger = column.map { it.isIntegerNode && it.intValue() > 0 }.orElse(false)
                success.set(isPositiveInteger)
            }, ytTableRow, 0, 1)
        } catch (e: Exception) {
            // Best effort: any read problem is turned into a failed check instead of an exception.
            logger.error("Failed to read shardColumn {}:", shardColumnName, e)
        }
        return success.get()
    }

    companion object {
        private val logger = LoggerFactory.getLogger(Mysql2GrutYtUtils::class.java)
    }
}
