package ru.yandex.msearch.jobs;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

import java.lang.ref.WeakReference;

import java.util.concurrent.atomic.AtomicLong;

/**
 * Throttles the aggregate read throughput of every input stream wrapped by this limiter.
 *
 * <p>The rate limit (in MiB per second) is converted into a byte budget per tick of
 * {@code LIMITER_RESOLUTION} milliseconds. Each successful read is charged against that
 * budget; once the budget goes negative (debt), the reading thread blocks until a
 * background daemon timer thread has paid the debt back. Unused budget is not carried
 * over to the next tick.
 *
 * <p>The timer thread holds the limiter only through a weak reference between ticks and
 * stops once the limiter has been garbage collected.
 *
 * <p>A non-positive rate limit never refills the budget, so readers block indefinitely
 * once it is exhausted.
 */
public class DownloadLimiter {
    private static final int MB_BYTES = 1024 * 1024;
    private static final int SEC = 1000;
    private static final int LIMITER_RESOLUTION = 100;
    private static final int RATE_DIVISOR = SEC / LIMITER_RESOLUTION;

    // Written by rateLimitMb() and read by the timer thread, hence volatile.
    private volatile long bytesLimitPerTick;
    private final Object lock = new Object();
    // Remaining budget of the current tick; a negative value is debt still to be paid.
    private final AtomicLong bytesCountDown = new AtomicLong(0);
    // Bytes read since the last currentRate() call.
    private final AtomicLong bytesCopied = new AtomicLong(0);
    private long prevRateCheck;

    /**
     * Creates a limiter and starts its daemon timer thread.
     *
     * @param rateLimitMb allowed aggregate throughput in MiB per second
     */
    public DownloadLimiter(final int rateLimitMb) {
        this.bytesLimitPerTick = bytesPerTick(rateLimitMb);
        bytesCountDown.set(bytesLimitPerTick);
        // Start the rate measurement window now rather than at the epoch.
        prevRateCheck = System.currentTimeMillis();
        TimerHelper timer = new TimerHelper(this);
        timer.setDaemon(true);
        timer.start();
    }

    /**
     * Changes the allowed throughput and resets the current tick's budget to the new limit.
     *
     * @param rateLimitMb allowed aggregate throughput in MiB per second
     */
    public void rateLimitMb(final int rateLimitMb) {
        final long limit = bytesPerTick(rateLimitMb);
        this.bytesLimitPerTick = limit;
        bytesCountDown.set(limit);
        System.err.println("BytesLimitPerTick: " + limit);
        // The budget was just refilled; let blocked readers re-check without
        // waiting for the next timer tick.
        notifyWaiters();
    }

    /**
     * Wraps {@code stream} so that every byte read from it is charged to this limiter.
     *
     * @param stream the stream to throttle
     * @return a throttled view of {@code stream}
     */
    public InputStream wrapInputStream(final InputStream stream) {
        return new RateLimitInputStream(stream, this);
    }

    /**
     * Returns the average throughput in MiB/s since the previous call (or since
     * construction, for the first call) and resets the byte counter.
     *
     * @return observed throughput in MiB per second
     */
    public double currentRate() {
        final long currentTime = System.currentTimeMillis();
        final double copied = bytesCopied.getAndSet(0);
        // Clamp to avoid division by zero when called twice within the same millisecond.
        final double elapsedSec =
            Math.max(
                0.000001,
                (double) (currentTime - prevRateCheck) / (double) SEC);
        prevRateCheck = currentTime;
        final double rate = copied / elapsedSec;
        return rate / MB_BYTES;
    }

    /**
     * Charges {@code len} bytes to the budget and blocks while the budget is in debt.
     *
     * <p>If the calling thread is interrupted while blocked, the interrupt flag is
     * restored and the method returns without waiting further.
     *
     * @param len number of bytes that were just read
     */
    protected void rateLimit(final int len) {
        bytesCopied.addAndGet(len);
        if (bytesCountDown.addAndGet(-len) >= 0) {
            return;
        }
        synchronized (lock) {
            // Guarded wait: wait() may wake spuriously, so re-check the budget each time.
            // timerTick() updates the budget before taking the lock to notify, so a
            // refill cannot slip in between this check and wait().
            while (bytesCountDown.get() < 0) {
                try {
                    lock.wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    }

    /**
     * Refills the budget for a new tick and wakes blocked readers once no debt remains.
     *
     * <p>Outstanding debt is paid back by one tick's worth of bytes; a non-negative
     * leftover budget is replaced by a fresh full tick (it does not accumulate).
     */
    public void timerTick() {
        final long limit = bytesLimitPerTick;
        long current;
        long next;
        // CAS loop so that reads charged concurrently with the refill are not lost.
        do {
            current = bytesCountDown.get();
            next = current < 0 ? current + limit : limit;
        } while (!bytesCountDown.compareAndSet(current, next));
        // Readers block only while the budget is negative, so zero already frees them.
        if (next >= 0) {
            notifyWaiters();
        }
    }

    private void notifyWaiters() {
        synchronized (lock) {
            lock.notifyAll();
        }
    }

    // Converts MiB/s into bytes per tick. Widen to long first: the int product
    // overflows for limits of 2048 MiB/s and above.
    private static long bytesPerTick(final int rateLimitMb) {
        return ((long) rateLimitMb * MB_BYTES) / RATE_DIVISOR;
    }

    /** Input stream that charges every successfully read byte to a {@link DownloadLimiter}. */
    private static class RateLimitInputStream extends FilterInputStream {
        private final DownloadLimiter limiter;

        public RateLimitInputStream(
            final InputStream in,
            final DownloadLimiter limiter)
        {
            super(in);
            this.limiter = limiter;
        }

        @Override
        public int read() throws IOException {
            final int c = super.read();
            if (c >= 0) {
                limiter.rateLimit(1);
            }
            return c;
        }

        @Override
        public int read(final byte[] b) throws IOException {
            // Route through the 3-arg overload, which does the charging exactly once.
            // (FilterInputStream.read(byte[]) also dispatches there, so charging here as
            // well would count every byte twice.)
            return read(b, 0, b.length);
        }

        @Override
        public int read(final byte[] b, final int off, final int len)
            throws IOException
        {
            final int red = super.read(b, off, len);
            if (red > 0) {
                limiter.rateLimit(red);
            }
            return red;
        }
    }

    /** Daemon thread that calls {@link DownloadLimiter#timerTick()} every tick. */
    private static class TimerHelper extends Thread {
        private final WeakReference<DownloadLimiter> limiterRef;

        public TimerHelper(final DownloadLimiter limiter) {
            super("TimerHelper");
            this.limiterRef = new WeakReference<>(limiter);
        }

        @Override
        public void run() {
            while (true) {
                try {
                    Thread.sleep(LIMITER_RESOLUTION);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                if (!tick()) {
                    return;
                }
            }
        }

        // Performs one tick; returns false once the limiter has been collected.
        // The strong reference lives only in this frame, so the limiter is not kept
        // reachable by the sleeping timer thread.
        private boolean tick() {
            final DownloadLimiter limiter = limiterRef.get();
            if (limiter == null) {
                return false;
            }
            try {
                limiter.timerTick();
            } catch (RuntimeException e) {
                e.printStackTrace();
            }
            return true;
        }
    }
}
