package ru.yandex.webmaster3.storage.util;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.TimeUnit;

import com.google.common.io.CountingInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link InputStream} decorator that writes read progress to the log.
 *
 * <p>While data is being read, an info-level line with the total number of bytes consumed so far
 * and the throughput since the previous progress line is emitted at most once per
 * {@link #LOG_PERIOD_MILLIS}. When the end of the stream is reached, or the stream is closed,
 * a final summary line with the total size and the average throughput since construction
 * is emitted.
 *
 * <p>Progress is only evaluated from within {@code read(...)} and {@code close()} calls, so nothing
 * is logged while the caller is not reading. Instances keep unsynchronized mutable state.
 *
 * @author aherman
 */
public class ProgressLogInputStream extends FilterInputStream {
    private static final Logger log = LoggerFactory.getLogger(ProgressLogInputStream.class);
    /** Minimal interval between two intermediate progress lines. */
    private static final long LOG_PERIOD_MILLIS = TimeUnit.SECONDS.toMillis(10);

    /** Same object as {@link #in}, kept with its concrete type to avoid a cast on every read. */
    private final CountingInputStream counter;
    private final String streamName;
    private final long startTimeNs;

    /** Byte count at the moment the previous intermediate progress line was written. */
    private long lastBytesCount = 0;
    /** {@link System#nanoTime()} value at the moment the previous intermediate line was written. */
    private long lastTimeNs;
    /** Set once the final summary is written; cleared again if a later read does not hit EOF. */
    private boolean summaryLogged = false;

    /**
     * Wraps an arbitrary stream, counting the bytes read through this wrapper.
     *
     * @param parent     stream to read from
     * @param streamName label that prefixes every progress line
     */
    public ProgressLogInputStream(InputStream parent, String streamName) {
        this(new CountingInputStream(parent), streamName);
    }

    /**
     * Wraps an already counting stream; its current count is used as the reported total.
     *
     * @param parent     counting stream to read from
     * @param streamName label that prefixes every progress line
     */
    public ProgressLogInputStream(CountingInputStream parent, String streamName) {
        super(parent);
        this.counter = parent;
        this.startTimeNs = System.nanoTime();
        this.lastTimeNs = this.startTimeNs;
        this.streamName = streamName;
    }

    @Override
    public int read() throws IOException {
        int value = in.read();
        reportProgress(value == -1);
        return value;
    }

    @Override
    public int read(byte[] b) throws IOException {
        int value = in.read(b);
        reportProgress(value == -1);
        return value;
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        int value = in.read(b, off, len);
        reportProgress(value == -1);
        return value;
    }

    /**
     * Writes the final summary (unless it was already written when EOF was reached) and closes
     * the underlying stream.
     */
    @Override
    public void close() throws IOException {
        try {
            reportProgress(true);
        } finally {
            // The underlying stream must be released even if reporting fails unexpectedly.
            in.close();
        }
    }

    /**
     * Emits a progress line if one is due.
     *
     * @param last {@code true} when the stream is exhausted or being closed: the overall summary
     *             is written (once); {@code false} for an ordinary read: an intermediate line is
     *             written only if more than {@link #LOG_PERIOD_MILLIS} passed since the previous one
     */
    private void reportProgress(boolean last) {
        long bytesCount = counter.getCount();
        long nowNs = System.nanoTime();

        if (last) {
            if (summaryLogged) {
                // EOF was already reported; repeated read()/close() calls must not spam the log.
                return;
            }
            summaryLogged = true;
            // A stream consumed in under a millisecond would otherwise divide by zero and
            // print an "Infinity"/"NaN" rate.
            long timeDiff = Math.max(TimeUnit.NANOSECONDS.toMillis(nowNs - startTimeNs), 1L);
            logProgress(bytesCount, bytesCount, timeDiff);
            return;
        }

        // Reading continued after a summary (e.g. following reset()), so a new summary is due later.
        summaryLogged = false;

        long timeDiff = TimeUnit.NANOSECONDS.toMillis(nowNs - lastTimeNs);
        if (timeDiff > LOG_PERIOD_MILLIS) {
            logProgress(bytesCount, bytesCount - lastBytesCount, timeDiff);
            lastBytesCount = bytesCount;
            lastTimeNs = nowNs;
        }
    }

    /**
     * Writes one progress line: stream name, total size and rate.
     *
     * @param bytesCount total bytes read so far
     * @param bytesDiff  bytes read during the measured interval
     * @param timeDiff   length of the measured interval in milliseconds, must be positive
     */
    private void logProgress(long bytesCount, long bytesDiff, long timeDiff) {
        log.info("{}: {} {}",
                streamName,
                prettyPrintBytes(bytesCount, ""),
                prettyPrintBytes(bytesDiff * 1000.0 / timeDiff, "/s")
        );
    }

    /**
     * Formats a byte amount with two decimals, in mebibytes below 1 GiB and in gibibytes otherwise.
     *
     * @param count  amount in bytes
     * @param suffix text appended after the unit
     */
    private String prettyPrintBytes(double count, String suffix) {
        if (count < 1024L * 1024L * 1024L) {
            return String.format("%2.2fMb%s", count * 1.0 / (1024L * 1024L), suffix);
        } else {
            return String.format("%2.2fGb%s", count * 1.0 / (1024L * 1024L * 1024L), suffix);
        }
    }
}
