package ru.yandex.direct.grid.core.entity.strategy.repository

import org.jooq.Condition
import org.jooq.Field
import org.jooq.impl.DSL
import org.jooq.impl.DSL.row
import org.springframework.stereotype.Service
import ru.yandex.direct.core.entity.container.LocalDateRange
import ru.yandex.direct.grid.core.entity.model.GdiEntityStats
import ru.yandex.direct.grid.core.entity.model.GdiGoalConversion
import ru.yandex.direct.grid.core.util.stats.GridStatNew
import ru.yandex.direct.grid.core.util.stats.completestat.DirectGridStatData
import ru.yandex.direct.grid.core.util.stats.completestat.GridStatTableData
import ru.yandex.direct.grid.core.util.stats.goalstat.DirectGoalGridStatData
import ru.yandex.direct.grid.core.util.stats.goalstat.GridGoalStatTableData
import ru.yandex.direct.grid.core.util.stats.objstat.GridObjectStatTableData
import ru.yandex.direct.grid.core.util.yt.YtDynamicSupport
import ru.yandex.direct.utils.DateTimeUtils.MSK
import ru.yandex.direct.ytwrapper.dynamic.dsl.YtDSL
import ru.yandex.direct.ytwrapper.dynamic.dsl.YtQueryUtil.DECIMAL_MULT
import java.math.BigDecimal
import java.time.LocalDate

/**
 * Reads package-strategy statistics from YT dynamic tables.
 *
 * Queries are built with [YtDSL] and executed through [YtDynamicSupport.selectRows]. The general stat table
 * ([DirectGridStatData]) is used for entity stats and for looking up filtering tuples; the goal stat table
 * ([DirectGoalGridStatData]) is used for goal conversions.
 */
@Service
class GridPackageStrategyYtRepository(val ytSupport: YtDynamicSupport) {
    private val goalsStatTable = DirectGoalGridStatData()
    private val statTable = DirectGridStatData()

    private val gridStat = GridStatNew(DirectGridStatData.INSTANCE)

    /**
     * Row filter shared by the stat queries of this repository.
     *
     * @property filteringTuples (campaign id, update time, effective strategy id) combinations a stat row must match.
     * @property goalIdsByStrategyId when non-null, goal-stat rows are additionally restricted to the goal ids listed
     *   for the tuple's effective strategy; tuples whose strategy has no entry match nothing. Only consulted for the
     *   goal stat table.
     */
    data class Filter(
        val filteringTuples: List<FilteringTuple>,
        val goalIdsByStrategyId: Map<Long, List<Long>>? = null
    ) {
        /**
         * Builds the WHERE condition for [statTable].
         *
         * @throws IllegalArgumentException if [statTable] is neither a goal stat table nor a general stat table.
         */
        fun <T : GridObjectStatTableData> apply(statTable: T): Condition {
            // `when` takes the first matching branch, so the goal-table check must stay ahead of the general one.
            return when (statTable) {
                is GridGoalStatTableData<*> -> applyToGridGoalStatTableData(statTable)
                is GridStatTableData<*, *> -> applyToGridStatTableData(statTable)
                else ->
                    throw IllegalArgumentException("Unexpected stat table")
            }
        }

        /** Matches rows whose (campaign id, update time, effective strategy id) is one of [filteringTuples]. */
        private fun applyToGridStatTableData(statTable: GridStatTableData<*, *>): Condition {
            val rowIn = filteringTuples.map {
                row(it.cid, it.updateTime, it.effecitveStrategyId)
            }
            return row(
                statTable.campaignId(),
                statTable.updateTime(),
                statTable.effectiveAutobudgetStrategyId()
            )
                .`in`(rowIn)
        }

        /**
         * Goal-table counterpart of [applyToGridStatTableData]. When [goalIdsByStrategyId] is set, every tuple is
         * expanded into one (campaign id, update time, effective strategy id, goal id) row per goal of its strategy.
         */
        private fun applyToGridGoalStatTableData(statTable: GridGoalStatTableData<*>): Condition {
            // Non-null means the caller requested filtering by goals.
            val goalIdsByStrategy = goalIdsByStrategyId
            return if (goalIdsByStrategy != null) {
                // Strategies without an entry contribute no rows, so their tuples are filtered out entirely.
                val rowIn = filteringTuples.flatMap { tuple ->
                    goalIdsByStrategy[tuple.effecitveStrategyId].orEmpty().map { goalId ->
                        row(tuple.cid, tuple.updateTime, tuple.effecitveStrategyId, goalId)
                    }
                }

                row(
                    statTable.campaignId(),
                    statTable.updateTime(),
                    statTable.effectiveAutobudgetStrategyId(),
                    statTable.goalId()
                )
                    .`in`(rowIn)
            } else {
                val rowIn = filteringTuples.map {
                    row(it.cid, it.updateTime, it.effecitveStrategyId)
                }
                row(
                    statTable.campaignId(),
                    statTable.updateTime(),
                    statTable.effectiveAutobudgetStrategyId()
                )
                    .`in`(rowIn)
            }
        }
    }

    /**
     * Aggregates goal conversions per effective strategy id and goal id.
     *
     * @param filter restricts the goal stat rows; see [Filter].
     * @param onlyStrategyGoals when true, totals prefer rows of the strategy goal type and fall back to meaningful
     *   goals (see [goalsStatTotal]); otherwise every row passing the goal restriction is summed.
     * @param availableGoals if non-null, only rows with these goal ids add to the sums; rows for other goals are
     *   still grouped, but their totals default to 0.
     * @return conversions keyed by effective strategy id; `null` results of the row extractor are dropped.
     */
    fun strategyGoalsConversions(
        filter: Filter,
        onlyStrategyGoals: Boolean,
        availableGoals: Set<Long>? = null
    ): Map<Long, List<GdiGoalConversion>> {
        val goalId = goalsStatTable.goalId().`as`(goalsStatTable.goalId().name)

        val goalsNumTotal =
            goalsStatTotal(goalsStatTable.goalsNum(), onlyStrategyGoals, availableGoals)
                .`as`(goalsStatTable.goalsNum().name)

        // priceCur is stored scaled by DECIMAL_MULT; divide before summing.
        val priceCurTotal =
            goalsStatTotal(goalsStatTable.priceCur().divide(DECIMAL_MULT), onlyStrategyGoals, availableGoals)
                .`as`(goalsStatTable.priceCur().name)

        val query = YtDSL.ytContext()
            .select(
                goalsStatTable.effectiveAutobudgetStrategyId(),
                goalId,
                goalsNumTotal,
                priceCurTotal
            )
            .from(goalsStatTable.table())
            .where(filter.apply(goalsStatTable))
            .groupBy(goalsStatTable.effectiveAutobudgetStrategyId(), goalId)
        return ytSupport.selectRows(query)
            .yTreeRows
            .groupBy({ it.getOrThrow(goalsStatTable.effectiveAutobudgetStrategyId().name).longValue() }) {
                gridStat.extractGoalConversionWithRevenue(it)
            }.mapValues { it.value.filterNotNull() }
    }

    /**
     * Computes the aggregated stat fields ([GridStatNew.statSelectFields]) per effective strategy id.
     *
     * @param filter restricts the stat rows; see [Filter].
     * @return one stats entry per effective strategy id found by the query.
     */
    fun strategyEntityStats(
        filter: Filter
    ): Map<Long, GdiEntityStats> {
        val effectiveStrategyId = gridStat.tableData.effectiveAutobudgetStrategyId()

        val query = YtDSL.ytContext()
            .select(effectiveStrategyId)
            .select(gridStat.statSelectFields)
            .from(gridStat.tableData.table())
            .where(filter.apply(statTable))
            .groupBy(effectiveStrategyId)

        return ytSupport.selectRows(query)
            .yTreeRows
            .associateBy({ it.getOrThrow(effectiveStrategyId.name).longValue() }) {
                gridStat.extractStatsEntry(it)
            }
    }

    /**
     * One (campaign id, update time, effective strategy id) combination, as read from the stat table, used by
     * [Filter]. The misspelled `effecitveStrategyId` name is part of the public API and is kept for compatibility.
     */
    data class FilteringTuple(val cid: Long, val updateTime: Long, val effecitveStrategyId: Long)

    /**
     * Finds the (campaign id, update time, effective strategy id) combinations present in the stat table.
     *
     * For every campaign and every day of its range (both ends inclusive), rows are matched on the campaign id and
     * on an update time equal to the start of that day in [MSK], as epoch seconds. Only tuples whose effective
     * strategy is in [strategyIds] are returned, without duplicates.
     *
     * Returns an empty list without querying YT when [strategyIds] is empty or there is no (campaign, day) pair to
     * look up.
     */
    fun getFilteringTuples(
        cidToLocalDateRange: Map<Long, LocalDateRange>,
        strategyIds: Set<Long>
    ): List<FilteringTuple> {
        // Every tuple would be dropped by the strategy filter below, so skip the YT round trip.
        if (strategyIds.isEmpty()) {
            return emptyList()
        }

        val cidToLocalDates = cidToLocalDateRange.flatMap { (cid, localDateRange) ->
            getByDays(localDateRange).map { row(cid, it.atStartOfDay(MSK).toEpochSecond()) }
        }
        // Nothing to look up: avoid sending an empty IN list to YT.
        if (cidToLocalDates.isEmpty()) {
            return emptyList()
        }

        // NOTE(review): cid is read from effectiveCampaignId() while the lookup below (and Filter later) match on
        // campaignId(); confirm the two columns coincide for every campaign passed here.
        val campaignIdColumn = gridStat.tableData.effectiveCampaignId()
        val strategyIdColumn = gridStat.tableData.effectiveAutobudgetStrategyId()
        val updateTime = gridStat.tableData.updateTime().`as`(statTable.updateTime().name)

        val query = YtDSL.ytContext()
            .select(campaignIdColumn, updateTime, strategyIdColumn)
            .from(gridStat.tableData.table())
            .where(row(gridStat.tableData.campaignId(), updateTime).`in`(cidToLocalDates))

        return ytSupport.selectRows(query)
            .yTreeRows
            .map {
                FilteringTuple(
                    cid = it.getOrThrow(campaignIdColumn.name).longValue(),
                    updateTime = it.getOrThrow(updateTime.name).longValue(),
                    effecitveStrategyId = it.getOrThrow(strategyIdColumn.name).longValue(),
                )
            }
            .filter { it.effecitveStrategyId in strategyIds }
            // The select has no GROUP BY, so several stat rows may yield the same tuple; duplicates would only
            // inflate the IN lists that Filter builds from these tuples.
            .distinct()
    }

    /** Every date from `fromInclusive` to `toInclusive`, both inclusive; empty when the range is reversed. */
    private fun getByDays(localDateRange: LocalDateRange): List<LocalDate> =
        generateSequence(localDateRange.fromInclusive) { it.plusDays(1) }
            .takeWhile { !it.isAfter(localDateRange.toInclusive) }
            .toList()

    /**
     * Builds a conditional SUM of [field] over the goal stat table, replaced by 0 when the result is NULL.
     *
     * Only rows whose goal id is in [availableGoals] are summed (all rows when it is null). With [onlyStrategyGoals]
     * the expression is `if(sumIf(strategy goals) >= 0, sumIf(strategy goals), sumIf(meaningful goals))`, where the
     * goal kind comes from the campaign goal type column.
     */
    private fun goalsStatTotal(
        field: Field<Long>,
        onlyStrategyGoals: Boolean,
        availableGoals: Set<Long>?
    ): Field<BigDecimal> {
        val goalCondition = if (availableGoals != null) {
            goalsStatTable.goalId().`in`(availableGoals)
        } else {
            DSL.noCondition()
        }
        val strategyGoalCondition =
            goalCondition.and(goalsStatTable.campaignGoalType().eq(GridStatNew.STRATEGY_GOAL_TYPE))
        val meaningfulGoalCondition =
            goalCondition.and(goalsStatTable.campaignGoalType().eq(GridStatNew.MEANINGFUL_GOAL_TYPE))

        val goalsStatSum = if (onlyStrategyGoals) {
            // NOTE(review): if the summed values are never negative, the meaningful-goal branch is only reached when
            // the comparison is not true (presumably a NULL strategy-goal sum); confirm `greaterThan` was not intended.
            field.sumIf(strategyGoalCondition)
                .`if`({ it.greaterOrEqual(BigDecimal.ZERO) }, field.sumIf(meaningfulGoalCondition))
        } else {
            field.sumIf(goalCondition)
        }
        return goalsStatSum.ifNull(BigDecimal.ZERO)
    }

    /** SUM of this field over rows satisfying [condition] (delegates to [YtDSL.sumIf]). */
    private fun <T : Number> Field<T>.sumIf(condition: Condition): Field<BigDecimal> =
        YtDSL.sumIf(this, condition)

    /** This field, or [value] when it is NULL (delegates to [YtDSL.ytIfNull]). */
    private fun <T> Field<T>.ifNull(value: T): Field<T> =
        YtDSL.ytIfNull(this, value)

    /** This field when [cond] built from it holds, otherwise [value] (delegates to [YtDSL.ytIf]). */
    private fun <T> Field<T>.`if`(cond: (Field<T>) -> Condition, value: Field<T>): Field<T> =
        YtDSL.ytIf(cond(this), this, value)
}
