package ru.yandex.travel.commons.rate;

import com.google.common.base.Preconditions;
import lombok.extern.java.Log;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
 * Lock-free admission control that combines a per-bucket rate limit with a concurrency limit.
 *
 * <p>The {@code nowMs} timeline passed to {@link #acquire(long)} is cut into buckets of
 * {@code bucket} length, and each bucket admits at most {@code rateLimit} requests. Bucket
 * counters live in a ring of {@code ceil(window / bucket)} slots, indexed by bucket number modulo
 * the ring size. Independently, at most {@code semaphoreLimit} admitted requests may be
 * outstanding between {@link #acquire(long)} and {@link #release()}.
 *
 * <p>All mutable state is held in atomics updated with compare-and-set loops, so instances may be
 * shared between threads. The two limits are not updated atomically together: a request rejected
 * for rate briefly occupies a concurrency slot before it is rolled back.
 *
 * <p>NOTE(review): slot counts are never summed, so {@code rateLimit} applies per bucket, not per
 * window. The window only sizes the ring, which presumably lets callers whose timestamps lag by
 * less than one window still be counted in their own bucket. Confirm this is intended.
 */
@Log
public class Throttler {
    /** Failed CAS attempts after which a possible live-lock is reported (once per call). */
    private static final long LIVE_LOCK_WARN_ATTEMPTS = 100;

    private final long rateLimit;
    private final long semaphoreLimit;
    private final long bucketMs;
    private final long windowMs;
    /** Ring of per-bucket counters; {@code null} marks a slot that has never been used. */
    private final AtomicReferenceArray<Slot> rates;
    /** Number of admitted requests that have not been released yet. */
    private final AtomicLong semaphore;

    /** Outcome of {@link Throttler#acquire(long)}. */
    public enum EDecision {
        /** Admitted; the caller must eventually call {@link Throttler#release()}. */
        PASS,
        /** Rejected because the bucket is full or its ring slot already holds a newer bucket. */
        RATE_LIMIT,
        /** Rejected because {@code semaphoreLimit} requests are already outstanding. */
        CONCURRENCY_LIMIT,
    }

    /**
     * Immutable (bucket number, admitted count) pair stored in one ring slot.
     *
     * <p>Kept as an object rather than packed into a single {@code long} so the bucket number
     * keeps its full 64-bit range: {@code nowMs / bucketMs} does not fit in 32 bits for
     * epoch-millisecond timestamps with sub-second buckets.
     */
    private static final class Slot {
        final long bucket;
        final long count;

        Slot(long bucket, long count) {
            this.bucket = bucket;
            this.count = count;
        }
    }

    /**
     * Creates a throttler.
     *
     * @param rateLimit      maximum admissions per bucket; must be positive
     * @param semaphoreLimit maximum outstanding admissions; must be positive
     * @param bucket         bucket length; must be at least 1 ms
     * @param window         span covered by the counter ring; must be at least 1 ms
     * @throws IllegalArgumentException if a limit is not positive or a duration is below 1 ms
     * @throws ArithmeticException      if the ring would need more than
     *                                  {@code Integer.MAX_VALUE} slots
     */
    public Throttler(long rateLimit, long semaphoreLimit, Duration bucket, Duration window) {
        Preconditions.checkArgument(rateLimit > 0);
        Preconditions.checkArgument(semaphoreLimit > 0);
        this.rateLimit = rateLimit;
        this.semaphoreLimit = semaphoreLimit;
        bucketMs = bucket.toMillis();
        windowMs = window.toMillis();
        // Sub-millisecond durations truncate to 0. Reject them here instead of failing with a
        // division/modulo by zero on the first acquire().
        Preconditions.checkArgument(bucketMs > 0, "bucket must be at least 1 ms");
        Preconditions.checkArgument(windowMs > 0, "window must be at least 1 ms");
        int length = Math.toIntExact((windowMs + bucketMs - 1) / bucketMs);
        rates = new AtomicReferenceArray<>(length);
        semaphore = new AtomicLong();
    }

    /**
     * Adds {@code delta} to the counter of the bucket containing {@code nowMs}.
     *
     * @return {@code true} if the counter was updated; {@code false} if the new count would fall
     *     outside {@code [0, rateLimit]} or the ring slot already belongs to a newer bucket
     */
    private boolean updateRate(long nowMs, int delta) {
        // floorDiv/floorMod keep the bucket number and slot index well-defined (and the index
        // non-negative) for any nowMs, including negative values.
        long bucket = Math.floorDiv(nowMs, bucketMs);
        int index = (int) Math.floorMod(bucket, (long) rates.length());
        while (true) {
            Slot previous = rates.get(index);
            long count;
            if (previous == null || previous.bucket < bucket) {
                // Unused slot, or one left over from an older bucket: start a fresh count.
                count = delta;
            } else if (previous.bucket == bucket) {
                count = previous.count + delta;
            } else {
                // The slot already tracks a newer bucket; this timestamp lags too far behind.
                return false;
            }
            // Validate before publishing so a rejected update never leaves a bad count behind.
            if (count < 0 || count > rateLimit) {
                return false;
            }
            if (rates.compareAndSet(index, previous, new Slot(bucket, count))) {
                return true;
            }
        }
    }

    /**
     * Adds {@code delta} to the outstanding-request counter.
     *
     * @return {@code true} if the counter was updated; {@code false} if the new value would fall
     *     outside {@code [0, semaphoreLimit]}
     */
    private boolean updateSemaphore(int delta) {
        long attempt = 0;
        while (true) {
            long previous = semaphore.get();
            long next = previous + delta;
            // Check both bounds before the CAS so an unbalanced release cannot drive the
            // counter negative (which would silently grant an extra permit).
            if (next < 0 || next > semaphoreLimit) {
                return false;
            }
            if (semaphore.compareAndSet(previous, next)) {
                return true;
            }
            attempt++;
            if (attempt == LIVE_LOCK_WARN_ATTEMPTS) {
                // java.util.logging does not expand "{}" placeholders; build the text directly.
                log.warning("Unable to update semaphore in " + attempt
                        + " attempts, live-lock expected");
            }
        }
    }

    /**
     * Tries to admit one request at time {@code nowMs}.
     *
     * @param nowMs current time in milliseconds, on the same timeline for every call
     * @return {@link EDecision#PASS} if admitted (the caller must later call {@link #release()}),
     *     otherwise the limit that rejected the request
     */
    public EDecision acquire(long nowMs) {
        if (!updateSemaphore(1)) {
            return EDecision.CONCURRENCY_LIMIT;
        }
        if (!updateRate(nowMs, 1)) {
            // Give back the concurrency slot taken above. This can only fail if an unbalanced
            // release() elsewhere already drained the counter to zero; retrying would then spin
            // forever, so a single attempt is made.
            updateSemaphore(-1);
            return EDecision.RATE_LIMIT;
        }
        return EDecision.PASS;
    }

    /**
     * Releases one request previously admitted by {@link #acquire(long)}.
     *
     * @throws IllegalStateException if there is no outstanding admission to release; the
     *     counter is left unchanged in that case
     */
    public void release() {
        Preconditions.checkState(updateSemaphore(-1), "cannot release semaphore");
    }

    /** Returns the number of admitted requests that have not been released yet. */
    public long getSemaphoreValue() {
        return semaphore.get();
    }

    /** Returns the configured maximum number of outstanding admissions. */
    public long getSemaphoreLimit() {
        return semaphoreLimit;
    }
}
