package ru.yandex.webmaster3.storage.util;

import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.TimeUnit;

import com.google.common.io.CountingOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link OutputStream} decorator that forwards every call to the wrapped stream and
 * writes throughput statistics to the log.
 *
 * <p>A progress line is logged each time more than {@link #REPORT_BYTES} bytes have been
 * written since the previous progress line, and a summary line is logged when the stream
 * is closed. Rates are computed from {@link System#nanoTime()} deltas.
 *
 * <p>The progress bookkeeping is a plain read-then-write on two fields, so concurrent
 * writers may produce duplicated or skewed progress lines.
 *
 * @author aherman
 */
public class ProgressLogOutputStream extends OutputStream {
    private static final Logger log = LoggerFactory.getLogger(ProgressLogOutputStream.class);

    /** Number of bytes that must be written since the last report to trigger a new progress line. */
    private static final int REPORT_BYTES = 50 * 1024 * 1024;
    private static final long NANOS_PER_SECOND = TimeUnit.SECONDS.toNanos(1);
    private static final double BYTES_PER_MB = 1024.0 * 1024.0;
    private static final double BYTES_PER_GB = 1024.0 * 1024.0 * 1024.0;

    /** {@link System#nanoTime()} reading taken when this stream was created. */
    private final long startTime;
    /** Byte count at the moment of the last progress line. */
    private volatile long lastBytesCount = 0;
    /** {@link System#nanoTime()} reading at the moment of the last progress line. */
    private volatile long lastTimeNs = 0;

    private final String streamName;
    private final CountingOutputStream parent;

    /**
     * Wraps {@code parent} using the stream name {@code "UNKNOWN"} in log lines.
     *
     * @param parent stream that receives all written data
     */
    public ProgressLogOutputStream(OutputStream parent) {
        this(parent, "UNKNOWN");
    }

    /**
     * Wraps {@code parent} and starts the throughput clock.
     *
     * @param parent     stream that receives all written data
     * @param streamName label placed in every log line produced by this stream
     */
    public ProgressLogOutputStream(OutputStream parent, String streamName) {
        this.parent = new CountingOutputStream(parent);
        this.startTime = System.nanoTime();
        this.lastTimeNs = startTime;
        this.streamName = streamName;
    }

    /** Writes one byte to the wrapped stream, then logs progress if the threshold was crossed. */
    @Override
    public void write(int b) throws IOException {
        parent.write(b);
        reportProgress();
    }

    /** Writes the whole array to the wrapped stream, then logs progress if the threshold was crossed. */
    @Override
    public void write(byte[] b) throws IOException {
        parent.write(b);
        reportProgress();
    }

    /** Writes {@code len} bytes starting at {@code off}, then logs progress if the threshold was crossed. */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        parent.write(b, off, len);
        reportProgress();
    }

    /** Flushes the wrapped stream. */
    @Override
    public void flush() throws IOException {
        parent.flush();
    }

    /**
     * Closes the wrapped stream and then logs the total statistics. The summary is not
     * logged when closing the wrapped stream throws.
     */
    @Override
    public void close() throws IOException {
        parent.close();
        reportProgressUnchecked();
    }

    /**
     * Logs total bytes written and the rate over the last interval, but only when more than
     * {@link #REPORT_BYTES} bytes were written since the previous progress line.
     */
    private void reportProgress() {
        long bytesCount = parent.getCount();
        long bytesDiff = bytesCount - lastBytesCount;
        if (bytesDiff > REPORT_BYTES) {
            long nowNs = System.nanoTime();

            log.info("Stream stat[{}]: {}/{}",
                    streamName,
                    prettyPrintBytes(bytesCount, ""),
                    prettyPrintBytes(bytesPerSecond(bytesDiff, nowNs - lastTimeNs), "/s")
            );

            lastBytesCount = bytesCount;
            lastTimeNs = nowNs;
        }
    }

    /** Logs total bytes written and the average rate since construction, unconditionally. */
    private void reportProgressUnchecked() {
        long bytesCount = parent.getCount();
        log.info("Stream total stat[{}]: {} {}",
                streamName,
                prettyPrintBytes(bytesCount, ""),
                prettyPrintBytes(bytesPerSecond(bytesCount, System.nanoTime() - startTime), "/s")
        );
    }

    /**
     * Computes a transfer rate in bytes per second.
     *
     * <p>The elapsed time is kept in nanoseconds rather than being truncated to whole
     * milliseconds or seconds first, and is clamped to at least one nanosecond so that a
     * zero delta cannot produce an infinite rate.
     *
     * @param bytes     number of bytes transferred
     * @param elapsedNs elapsed time in nanoseconds
     * @return bytes per second as a finite value
     */
    private static double bytesPerSecond(long bytes, long elapsedNs) {
        return bytes * (double) NANOS_PER_SECOND / Math.max(elapsedNs, 1L);
    }

    /**
     * Formats a byte quantity with two decimals, in "Mb" below 1024^3 bytes and in "Gb" otherwise.
     *
     * @param count  quantity in bytes (or bytes per second)
     * @param suffix text appended right after the unit, e.g. {@code "/s"}
     * @return formatted value
     */
    private String prettyPrintBytes(double count, String suffix) {
        if (count < BYTES_PER_GB) {
            return String.format("%2.2f Mb%s", count / BYTES_PER_MB, suffix);
        } else {
            return String.format("%2.2f Gb%s", count / BYTES_PER_GB, suffix);
        }
    }
}
