package ru.yandex.intranet.d.util.bucket

import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min

/**
 * Mutable token-bucket state for every [Bandwidth] of a [BucketConfiguration].
 *
 * The state lives in a flat [LongArray] with three slots per bandwidth, in configuration order:
 * - `index * 3 + 0`: time of the last refill, in nanoseconds;
 * - `index * 3 + 1`: current number of tokens. Nothing here prevents it from going negative after [consume].
 * - `index * 3 + 2`: rounding error carried between refills. It is a remainder in `tokens * nanos` units
 *   and is kept below the refill period.
 *
 * This class performs no synchronization of its own.
 */
class BucketState private constructor(
    private var configuration: BucketConfiguration,
) {

    /**
     * Raw per-bandwidth state; see the class documentation for the slot layout.
     * Writing to it directly bypasses the bookkeeping done by this class.
     */
    lateinit var state: LongArray

    /**
     * Creates the initial state for [bucketConfiguration] as of [currentTime].
     *
     * Each bandwidth starts with its (possibly adaptive) initial tokens and a zero rounding error. Its
     * last-refill time is one refill period before the configured first refill, or [currentTime] when no
     * first refill is configured.
     */
    constructor(bucketConfiguration: BucketConfiguration, currentTime: Nanos) : this(bucketConfiguration) {
        state = LongArray(SLOTS_PER_BANDWIDTH * bucketConfiguration.bandwidths.size)
        bucketConfiguration.bandwidths.forEachIndexed { index, bandwidth ->
            writeCurrentSize(index, getInitialTokens(bandwidth, currentTime))
            writeLastRefillTime(index, getLastRefillTime(bandwidth, currentTime))
        }
    }

    /** Copy constructor: shares [other]'s configuration but clones its state array. */
    private constructor(other: BucketState) : this(other.configuration) {
        state = other.state.clone()
    }

    /** Returns an independent copy of this state (same configuration instance, separate array). */
    fun copy(): BucketState = BucketState(this)

    /**
     * Overwrites this state with [source]'s.
     *
     * If both share the same configuration instance, the existing array is reused. Otherwise [source]'s
     * configuration is adopted and its array is cloned.
     */
    fun copyFrom(source: BucketState) {
        if (source.configuration === configuration) {
            source.state.copyInto(state)
        } else {
            configuration = source.configuration
            state = source.state.clone()
        }
    }

    /** Refills every bandwidth up to [currentTime]. */
    fun refillAll(currentTime: Nanos) {
        configuration.bandwidths.forEachIndexed { index, bandwidth -> refill(index, bandwidth, currentTime) }
    }

    /** Returns the minimum current token count across all bandwidths (expects at least one bandwidth). */
    fun getAvailableTokens(): Tokens {
        var availableTokens = readCurrentSize(0)
        for (index in 1 until configuration.bandwidths.size) {
            availableTokens = min(availableTokens, readCurrentSize(index))
        }
        return availableTokens
    }

    /** Subtracts [tokens] from every bandwidth. No availability check is performed here. */
    fun consume(tokens: Tokens) {
        for (index in configuration.bandwidths.indices) {
            writeCurrentSize(index, readCurrentSize(index) - tokens)
        }
    }

    /**
     * Returns how long to wait until [tokens] can be consumed from every bandwidth: the maximum of the
     * per-bandwidth delays.
     *
     * Returns `null` if any bandwidth can never satisfy the request. That happens when [tokens] exceeds a
     * bandwidth's capacity while [tokensLimitedByCapacity] is set, or when the computation overflows.
     * Interval bandwidths measure from the stored last-refill time, so this is presumably called after
     * [refillAll] for the same [currentTime].
     */
    fun getDelayToConsumption(tokens: Tokens, currentTime: Nanos, tokensLimitedByCapacity: Boolean): Nanos? {
        val bandwidths = configuration.bandwidths
        var maxDelay = getDelayToConsumption(0, bandwidths[0], tokens, currentTime, tokensLimitedByCapacity)
            ?: return null
        for (index in 1 until bandwidths.size) {
            val delay = getDelayToConsumption(index, bandwidths[index], tokens, currentTime, tokensLimitedByCapacity)
                ?: return null
            maxDelay = max(maxDelay, delay)
        }
        return maxDelay
    }

    /** Fills every bandwidth to capacity and clears rounding errors. Last-refill times are left unchanged. */
    fun reset() {
        configuration.bandwidths.forEachIndexed { index, bandwidth -> resetBandwidth(index, bandwidth.capacity) }
    }

    /** Adds [tokens] to every bandwidth, never exceeding each bandwidth's capacity. */
    fun addTokens(tokens: Tokens) {
        configuration.bandwidths.forEachIndexed { index, bandwidth -> addTokens(index, bandwidth, tokens) }
    }

    /** Adds [tokens] to every bandwidth, ignoring capacity; saturates at [Long.MAX_VALUE] on overflow. */
    fun forceAddTokens(tokens: Tokens) {
        configuration.bandwidths.indices.forEach { index -> forceAddTokens(index, tokens) }
    }

    private fun addTokens(index: Int, bandwidth: Bandwidth, tokens: Tokens) {
        val currentSize = readCurrentSize(index)
        val newSize = currentSize + tokens
        // newSize < currentSize means the Long addition wrapped around; treat it as "full".
        if (newSize >= bandwidth.capacity || newSize < currentSize) {
            resetBandwidth(index, bandwidth.capacity)
        } else {
            writeCurrentSize(index, newSize)
        }
    }

    private fun forceAddTokens(index: Int, tokens: Tokens) {
        val currentSize = readCurrentSize(index)
        val newSize = currentSize + tokens
        if (newSize < currentSize) {
            // The addition wrapped around: saturate instead of going negative.
            resetBandwidth(index, Long.MAX_VALUE)
        } else {
            writeCurrentSize(index, newSize)
        }
    }

    /**
     * Initial token count for [bandwidth].
     *
     * Adaptive initial tokens apply only when enabled, a first-refill time (in millis) is configured, and
     * [currentTime] precedes it. In that case the result is `max(0, capacity - refill.tokens)` plus what the
     * refill rate would produce over the time remaining until the first refill, capped at capacity.
     * Otherwise the configured `initialTokens` is returned.
     */
    private fun getInitialTokens(bandwidth: Bandwidth, currentTime: Nanos): Tokens {
        if (!bandwidth.refill.adaptiveInitialTokens) {
            return bandwidth.initialTokens
        }
        val timeOfFirstRefillMillis = bandwidth.refill.timeOfFirstRefill ?: return bandwidth.initialTokens
        val timeOfFirstRefillNanos = timeOfFirstRefillMillis * NANOS_PER_MILLI
        if (currentTime >= timeOfFirstRefillNanos) {
            return bandwidth.initialTokens
        }
        val guaranteedBase = max(0, bandwidth.capacity - bandwidth.refill.tokens)
        val nanosBeforeFirstRefill = timeOfFirstRefillNanos - currentTime
        val exactProduct = multiplyExact(nanosBeforeFirstRefill, bandwidth.refill.tokens)
        val refilledBeforeFirstRefill = if (exactProduct != null) {
            exactProduct / bandwidth.refill.period
        } else {
            // Integer precision is pointless at this magnitude; fall back to floating point.
            (nanosBeforeFirstRefill.toDouble() * bandwidth.refill.tokens.toDouble()
                / bandwidth.refill.period.toDouble()).toLong()
        }
        return min(bandwidth.capacity, guaranteedBase + refilledBeforeFirstRefill)
    }

    /** Initial last-refill time: one period before the configured first refill, or [currentTime] if none. */
    private fun getLastRefillTime(bandwidth: Bandwidth, currentTime: Nanos): Nanos {
        val timeOfFirstRefillMillis = bandwidth.refill.timeOfFirstRefill ?: return currentTime
        return timeOfFirstRefillMillis * NANOS_PER_MILLI - bandwidth.refill.period
    }

    /**
     * Credits bandwidth [index] with the tokens produced between its last refill and [currentTime].
     *
     * Interval bandwidths are credited only for whole elapsed periods. Greedy bandwidths are credited
     * proportionally, carrying the integer-division remainder in the rounding-error slot.
     */
    private fun refill(index: Int, bandwidth: Bandwidth, currentTime: Nanos) {
        val previousRefillNanos = readLastRefillTime(index)
        if (currentTime <= previousRefillNanos) {
            return
        }
        val refillPeriodNanos = bandwidth.refill.period
        val correctedCurrentNanos = if (bandwidth.refill.refillIntervally) {
            // Drop the incomplete tail of the elapsed time: only whole periods count.
            currentTime - (currentTime - previousRefillNanos) % refillPeriodNanos
        } else {
            currentTime
        }
        if (correctedCurrentNanos <= previousRefillNanos) {
            return
        }
        writeLastRefillTime(index, correctedCurrentNanos)

        val capacity = bandwidth.capacity
        val refillTokens = bandwidth.refill.tokens
        val currentSize = readCurrentSize(index)
        if (currentSize >= capacity) {
            // Already full (possibly above capacity after forceAddTokens).
            return
        }
        var durationSinceLastRefillNanos = correctedCurrentNanos - previousRefillNanos
        var newSize = currentSize
        if (durationSinceLastRefillNanos > refillPeriodNanos) {
            // Credit whole elapsed periods first, in exact integer arithmetic.
            val elapsedPeriods = durationSinceLastRefillNanos / refillPeriodNanos
            val calculatedRefill = multiplyExact(elapsedPeriods, refillTokens)
            if (calculatedRefill == null) {
                // The refill does not even fit in a Long (refill tokens presumably positive),
                // so it certainly exceeds capacity.
                resetBandwidth(index, capacity)
                return
            }
            newSize += calculatedRefill
            // newSize < currentSize means the addition wrapped around.
            if (newSize > capacity || newSize < currentSize) {
                resetBandwidth(index, capacity)
                return
            }
            durationSinceLastRefillNanos %= refillPeriodNanos
        }
        val previousRoundingError = readRoundingError(index)
        val dividedWithoutError = multiplyExact(refillTokens, durationSinceLastRefillNanos) ?: Long.MAX_VALUE
        val divided = dividedWithoutError + previousRoundingError
        val roundingError = if (divided < 0 || dividedWithoutError == Long.MAX_VALUE) {
            // Overflow: integer precision is pointless at this magnitude; fall back to floating point.
            newSize += (durationSinceLastRefillNanos.toDouble() / refillPeriodNanos.toDouble()
                * refillTokens.toDouble()).toLong()
            0L
        } else {
            val calculatedRefill = divided / refillPeriodNanos
            if (calculatedRefill == 0L) {
                divided
            } else {
                newSize += calculatedRefill
                divided % refillPeriodNanos
            }
        }
        // newSize < currentSize means the addition wrapped around.
        if (newSize >= capacity || newSize < currentSize) {
            resetBandwidth(index, capacity)
            return
        }
        writeCurrentSize(index, newSize)
        writeRoundingError(index, roundingError)
    }

    /**
     * Delay until bandwidth [index] holds [tokens].
     *
     * Returns 0 when the tokens are already there. Returns `null` when the request is unsatisfiable
     * (over capacity with [tokensLimitedByCapacity] set, or arithmetic overflow).
     */
    private fun getDelayToConsumption(
        index: Int,
        bandwidth: Bandwidth,
        tokens: Tokens,
        currentTime: Nanos,
        tokensLimitedByCapacity: Boolean,
    ): Nanos? {
        if (tokensLimitedByCapacity && tokens > bandwidth.capacity) {
            return null
        }
        val currentSize = readCurrentSize(index)
        if (tokens <= currentSize) {
            return 0
        }
        val deficit = tokens - currentSize
        if (deficit <= 0) {
            // tokens - currentSize wrapped around: the deficit is too large to represent.
            return null
        }
        return if (bandwidth.refill.refillIntervally) {
            getDelayToConsumptionForIntervalBandwidth(index, bandwidth, deficit, currentTime)
        } else {
            getDelayToConsumptionForGreedyBandwidth(index, bandwidth, deficit)
        }
    }

    /**
     * Time for a greedy bandwidth to produce [deficit] tokens.
     *
     * The rounding error already accumulated is credited, and integer division rounds the result down.
     */
    private fun getDelayToConsumptionForGreedyBandwidth(index: Int, bandwidth: Bandwidth, deficit: Tokens): Nanos {
        val periodNanos = bandwidth.refill.period
        val refillTokens = bandwidth.refill.tokens
        val multiplied = multiplyExact(periodNanos, deficit)
            // Overflow: integer precision is pointless at this magnitude; fall back to floating point.
            ?: return (deficit.toDouble() / refillTokens.toDouble() * periodNanos.toDouble()).toLong()
        val correctionForPartiallyRefilledToken = readRoundingError(index)
        return (multiplied - correctionForPartiallyRefilledToken) / refillTokens
    }

    /**
     * Time for an interval bandwidth to produce [deficit] tokens.
     *
     * Tokens arrive only at period boundaries. The result is the wait until the next boundary plus enough
     * whole periods to cover the rest of the deficit, or `null` on overflow.
     */
    private fun getDelayToConsumptionForIntervalBandwidth(
        index: Int,
        bandwidth: Bandwidth,
        deficit: Tokens,
        currentTime: Nanos,
    ): Nanos? {
        val periodNanos = bandwidth.refill.period
        val refillTokens = bandwidth.refill.tokens
        val timeOfNextRefillNanos = readLastRefillTime(index) + periodNanos
        val waitForNextRefillNanos = timeOfNextRefillNanos - currentTime
        if (deficit <= refillTokens) {
            return waitForNextRefillNanos
        }
        val correctedDeficit = deficit - refillTokens
        if (correctedDeficit < refillTokens) {
            return waitForNextRefillNanos + periodNanos
        }
        // Ceiling division: a partially needed period still has to be waited out in full.
        val deficitPeriods = if (correctedDeficit % refillTokens == 0L) {
            correctedDeficit / refillTokens
        } else {
            correctedDeficit / refillTokens + 1
        }
        val deficitNanos = multiplyExact(deficitPeriods, periodNanos) ?: return null
        val correctedDeficitNanos = deficitNanos + waitForNextRefillNanos
        // A negative sum means the addition overflowed.
        return if (correctedDeficitNanos < 0) null else correctedDeficitNanos
    }

    /**
     * Returns `left * right`, or `null` if the product overflows [Long].
     *
     * Mirrors the overflow check used by `java.lang.Math.multiplyExact`, without throwing.
     */
    private fun multiplyExact(left: Long, right: Long): Long? {
        val result = left * right
        val absLeft = abs(left)
        val absRight = abs(right)
        // Only operands with bits above 2^31 can overflow; verify those via division.
        if ((absLeft or absRight) ushr 31 != 0L) {
            if (((right != 0L) && (result / right != left)) || (left == Long.MIN_VALUE && right == -1L)) {
                return null
            }
        }
        return result
    }

    private fun readLastRefillTime(index: Int): Long =
        state[index * SLOTS_PER_BANDWIDTH + LAST_REFILL_TIME_OFFSET]

    private fun writeLastRefillTime(index: Int, value: Long) {
        state[index * SLOTS_PER_BANDWIDTH + LAST_REFILL_TIME_OFFSET] = value
    }

    private fun readCurrentSize(index: Int): Long =
        state[index * SLOTS_PER_BANDWIDTH + CURRENT_SIZE_OFFSET]

    private fun writeCurrentSize(index: Int, value: Long) {
        state[index * SLOTS_PER_BANDWIDTH + CURRENT_SIZE_OFFSET] = value
    }

    private fun readRoundingError(index: Int): Long =
        state[index * SLOTS_PER_BANDWIDTH + ROUNDING_ERROR_OFFSET]

    private fun writeRoundingError(index: Int, value: Long) {
        state[index * SLOTS_PER_BANDWIDTH + ROUNDING_ERROR_OFFSET] = value
    }

    /** Sets bandwidth [index] to [size] tokens and clears its rounding error. */
    private fun resetBandwidth(index: Int, size: Long) {
        writeCurrentSize(index, size)
        writeRoundingError(index, 0L)
    }

    private companion object {
        /** Number of [state] slots used by each bandwidth. */
        const val SLOTS_PER_BANDWIDTH = 3
        const val LAST_REFILL_TIME_OFFSET = 0
        const val CURRENT_SIZE_OFFSET = 1
        const val ROUNDING_ERROR_OFFSET = 2
        const val NANOS_PER_MILLI = 1_000_000L
    }

}
