package ru.yandex.intranet.d.util.bucket

import java.util.concurrent.atomic.AtomicReference
import kotlin.math.min

/**
 * Token bucket whose state is kept in a single [AtomicReference] and updated with optimistic
 * compare-and-set loops instead of locks.
 *
 * Every mutating operation works on a private working copy of the current [BucketState]:
 * 1. The copy is refilled up to the current time.
 * 2. The copy is mutated.
 * 3. The copy is published with `compareAndSet`.
 *
 * If another thread published a newer state in between, the copy is reloaded from that state and
 * the operation is retried.
 * NOTE(review): this relies on [BucketState.copy] producing a fully independent copy — confirm in
 * BucketState.
 *
 * @param configuration configuration used to build the initial [BucketState].
 * @param nanosSupplier clock passed to [BucketState.refillAll]. It is read once per operation, not
 * once per retry. Presumably monotonic (e.g. `System::nanoTime`) — confirm with callers.
 */
class Bucket(
    configuration: BucketConfiguration,
    private val nanosSupplier: () -> Long) {

    private val stateRef: AtomicReference<BucketState> =
        AtomicReference(BucketState(configuration, nanosSupplier()))

    /**
     * Consumes [tokens] only if at least that many are available after refilling.
     *
     * @return `true` if the tokens were consumed; `false` if not enough were available. In the
     * `false` case the stored state is left untouched.
     * @throws IllegalArgumentException if [tokens] is not positive.
     */
    fun tryConsume(tokens: Tokens): Boolean {
        require(tokens > 0L) { "Tokens must be positive" }
        return updateState { state, _ ->
            if (tokens > state.getAvailableTokens()) {
                // Non-local return: abort without publishing anything.
                return false
            }
            state.consume(tokens)
            true
        }
    }

    /**
     * Consumes [tokens] without checking whether that many are currently available.
     *
     * @return the delay computed by `BucketState.getDelayToConsumption` before consuming
     * (presumably the nanoseconds needed to close the resulting deficit). Returns `null`, with
     * nothing consumed, when that call returns `null`.
     * @throws IllegalArgumentException if [tokens] is not positive.
     */
    fun forceConsume(tokens: Tokens): Nanos? {
        require(tokens > 0L) { "Tokens must be positive" }
        return updateState { state, currentTime ->
            // NOTE(review): meaning of the boolean flag is defined in BucketState — confirm.
            val nanosToCloseDeficit = state.getDelayToConsumption(tokens, currentTime, false)
                ?: return null
            state.consume(tokens)
            nanosToCloseDeficit
        }
    }

    /**
     * Returns the number of tokens that would be available right now.
     *
     * The refill is computed on a snapshot copy and is not published, so the stored state is not
     * modified.
     */
    fun getAvailableTokens(): Tokens {
        val currentTime = nanosSupplier()
        val snapshot = stateRef.get().copy()
        snapshot.refillAll(currentTime)
        return snapshot.getAvailableTokens()
    }

    /**
     * Refills the state up to the current time, then applies [BucketState.reset] and publishes
     * the result.
     */
    fun reset() {
        updateState { state, _ -> state.reset() }
    }

    /** Consumes every currently available token, with no upper limit. See the overload taking a limit. */
    fun consumeAsMuchAsPossible(): Tokens {
        return consumeAsMuchAsPossible(Long.MAX_VALUE)
    }

    /**
     * Consumes as many currently available tokens as possible, but no more than [limit].
     *
     * @return the number of tokens actually consumed; never negative. Returns 0, leaving the
     * stored state untouched, when nothing is available or [limit] is not positive.
     */
    fun consumeAsMuchAsPossible(limit: Tokens): Tokens {
        return updateState { state, _ ->
            val toConsume = min(limit, state.getAvailableTokens())
            // toConsume can be negative: limit is unchecked, and the state may hold a deficit
            // because forceConsume consumes without checking availability. Consuming a negative
            // amount would corrupt the bucket, so treat it like "nothing to consume".
            if (toConsume <= 0L) {
                return 0L
            }
            state.consume(toConsume)
            toConsume
        }
    }

    /**
     * Refills the state, then adds [tokens] through [BucketState.addTokens].
     *
     * NOTE(review): whether the result is capped at capacity is decided by BucketState — confirm.
     *
     * @throws IllegalArgumentException if [tokens] is not positive.
     */
    fun addTokens(tokens: Tokens) {
        require(tokens > 0L) { "Tokens must be positive" }
        updateState { state, _ -> state.addTokens(tokens) }
    }

    /**
     * Refills the state, then adds [tokens] through [BucketState.forceAddTokens].
     *
     * NOTE(review): presumably this variant may exceed capacity, unlike [addTokens] — confirm in
     * BucketState.
     *
     * @throws IllegalArgumentException if [tokens] is not positive.
     */
    fun forceAddTokens(tokens: Tokens) {
        require(tokens > 0L) { "Tokens must be positive" }
        updateState { state, _ -> state.forceAddTokens(tokens) }
    }

    /**
     * Runs [mutation] inside an optimistic compare-and-set loop. Returns its result once the
     * mutated working copy has been published into [stateRef].
     *
     * [mutation] receives two arguments:
     * - a private working copy, already refilled up to the current time;
     * - that time, read from [nanosSupplier] once before the first attempt.
     *
     * It may be invoked several times if other threads win the race. Because this function is
     * `inline`, [mutation] may `return` from the calling function to abort without publishing.
     */
    private inline fun <R> updateState(mutation: (state: BucketState, currentTime: Nanos) -> R): R {
        var previousState = stateRef.get()
        val newState = previousState.copy()
        val currentTime = nanosSupplier()
        while (true) {
            newState.refillAll(currentTime)
            val result = mutation(newState, currentTime)
            if (stateRef.compareAndSet(previousState, newState)) {
                return result
            }
            // Lost the race: start over from the state another thread has just published.
            previousState = stateRef.get()
            newState.copyFrom(previousState)
        }
    }
}

/** A number of tokens held in, consumed from, or added to a [Bucket]. */
typealias Tokens = Long

/** A time value expressed in nanoseconds (for example the delay returned by [Bucket.forceConsume]). */
typealias Nanos = Long

/** A time value expressed in milliseconds. */
typealias Millis = Long
