package ru.yandex.chemodan.videostreaming.framework.util;

import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.IntSupplier;

import javax.annotation.Nonnull;

import org.joda.time.Duration;
import org.joda.time.Instant;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.misc.ExceptionUtils;
import ru.yandex.misc.io.InputStreamSource;
import ru.yandex.misc.io.InputStreamX;
import ru.yandex.misc.io.InputStreamXUtils;
import ru.yandex.misc.io.IoUtils;

/**
 * {@link InputStreamSource} that caches the bytes produced by a single writer in memory and lets any number
 * of readers consume them while they are still being written.
 *
 * <p>Bytes are pushed via {@link #receiveFrom(Consumer)}, {@link #receiveFrom(InputStream)} or directly into
 * the stream returned by {@link #getOutputStream()}. Every input stream replays the cached bytes from the
 * requested offset. A stream opened before the receive completes blocks in its reads until more bytes arrive,
 * the receive completes, or a single read has waited longer than {@link #TIMEOUT}.
 *
 * <p>Only one writer is supported: {@link #receiveFrom(Consumer)} enforces this with a lock, while callers
 * using {@link #getOutputStream()} directly must guarantee it themselves. A failure recorded by the writer or
 * supplied via {@link #setExternalException(RuntimeException)} is rethrown by reads.
 *
 * @author Dmitriy Amelin (lemeh)
 */
public class AsyncCacheableInputStreamSource extends InputStreamSource {
    /** Longest time a single read waits for new bytes before failing. */
    private static final Duration TIMEOUT = Duration.standardMinutes(1);

    /**
     * Upper bound, in nanoseconds, of one park while waiting for data. The writer unparks waiting readers on
     * every write and on completion, so this only bounds a missed wake-up and the timeout/interrupt checks.
     */
    private static final long PARK_WAIT_NANOS = 100L * 1000 * 1000;

    /** Held while {@link #receiveFrom(Consumer)} runs, so that at most one writer runs at a time. */
    private final Lock receiveLock = new ReentrantLock();

    /** Cache buffer; only the first {@link #count} bytes are valid, the rest is spare capacity. */
    private volatile byte[] data = new byte[0];

    /** Number of valid bytes in {@link #data}; the writer increases it only after storing the bytes. */
    private volatile int count = 0;

    /**
     * Threads currently parked inside a read. Each reader adds itself before parking and removes itself
     * afterwards, so the list is never replaced and never accumulates stale entries.
     */
    private final ListF<Thread> readThreads = Cf.x(new CopyOnWriteArrayList<>());

    /** Set once no more bytes will be written. */
    private volatile boolean receiveComplete = false;

    /** Failure of the receive (first writer failure, or the last externally supplied one); rethrown to readers. */
    private volatile RuntimeException exception;

    /** Capacity of the buffer allocated when writing starts. */
    private final int initialSize;

    /** Creates a source whose write buffer starts with a capacity of 32 bytes. */
    @SuppressWarnings("unused")
    public AsyncCacheableInputStreamSource() {
        this(32);
    }

    /**
     * Creates a source whose write buffer starts with the given capacity.
     *
     * @param size initial buffer capacity in bytes
     * @throws IllegalArgumentException if {@code size} is negative
     */
    public AsyncCacheableInputStreamSource(int size) {
        if (size < 0) {
            throw new IllegalArgumentException("Negative initial size: " + size);
        }

        initialSize = size;
    }

    /**
     * Receives the whole content of {@code in} (copied with {@code IoUtils.copy}).
     *
     * @param in stream to copy from; it is not closed here
     */
    @SuppressWarnings("unused")
    public void receiveFrom(InputStream in) {
        receiveFrom(out -> IoUtils.copy(in, out));
    }

    /**
     * Lets {@code receiver} write the content into this source, then completes the receive.
     *
     * <p>If the receiver (or the final flush) fails, the failure is recorded for readers <em>before</em> the
     * receive is marked complete, so concurrent readers never mistake truncated data for a successful receive.
     *
     * @param receiver writes the content into the supplied stream; must not close it asynchronously
     * @throws IllegalStateException if another writer is active, or data was already received
     * @throws RuntimeException the (translated) failure of {@code receiver}
     */
    public void receiveFrom(Consumer<OutputStream> receiver) {
        if (!receiveLock.tryLock()) {
            throw new IllegalStateException("Only one writer is allowed");
        }
        try {
            // Created outside the failure handler: a precondition violation here must not poison the source.
            BufferedOutputStream out = getBufferedOutputStream();
            try {
                receiver.accept(out);
                out.close(); // flushes the buffer and marks the receive complete
            } catch (IOException | RuntimeException ex) {
                RuntimeException translated = ExceptionUtils.translate(ex);
                failReceive(translated);
                throw translated;
            } catch (Error error) {
                // Unblock readers instead of leaving them waiting until TIMEOUT.
                failReceive(new RuntimeException("Receive failed", error));
                throw error;
            }
        } finally {
            receiveLock.unlock();
        }
    }

    private BufferedOutputStream getBufferedOutputStream() {
        return new BufferedOutputStream(getOutputStream());
    }

    /**
     * Starts a receive and returns the raw stream the writer must write to; closing it completes the receive.
     * Not guarded by the writer lock: the caller must be the only writer.
     *
     * @return stream appending to this source's buffer
     * @throws IllegalStateException if data was already received
     */
    public AsyncOutputStream getOutputStream() {
        if (count != 0) {
            throw new IllegalStateException("Data was already received");
        }

        data = new byte[initialSize];
        return new AsyncOutputStream();
    }

    /**
     * Records {@code exception} as the failure of this source and completes the receive; subsequent reads
     * throw it.
     *
     * @param exception failure to report to readers
     */
    public void setExternalException(RuntimeException exception) {
        this.exception = exception;
        completeReceive();
    }

    private void unparkReadThreads() {
        for (Thread readThread : readThreads) {
            LockSupport.unpark(readThread);
        }
    }

    /**
     * Returns a stream over the whole content, including bytes still to be received.
     *
     * @return input stream starting at offset 0
     */
    @Override
    public InputStream getInput() {
        // getInput(offset, length) clamps to the received size, so no separate read of receiveComplete is needed.
        return getInput(0, Integer.MAX_VALUE);
    }

    /**
     * Same as {@link #getInput(int, int)}, wrapped into an {@link InputStreamX}.
     *
     * @param offset index of the first byte to read
     * @param length maximum number of bytes to read
     * @return wrapped input stream
     */
    public InputStreamX getInputStreamX(int offset, int length) {
        return InputStreamXUtils.wrap(getInput(offset, length));
    }

    /**
     * Returns a stream over at most {@code length} bytes starting at {@code offset}. The stream ends at the
     * end of the received content if that comes first.
     *
     * @param offset index of the first byte to read
     * @param length maximum number of bytes to read; {@code Integer.MAX_VALUE} means "until the end"
     * @return input stream over the requested range
     * @throws IllegalArgumentException if {@code offset} or {@code length} is negative
     */
    public InputStream getInput(int offset, int length) {
        if (offset < 0 || length < 0) {
            throw new IllegalArgumentException("Negative offset or length: offset=" + offset + ", length=" + length);
        }
        // read barrier: receiveComplete is read before exception, count and data
        if (!receiveComplete || exception != null) {
            // AsyncInputStream waits for pending data and rethrows a recorded failure on read.
            return new AsyncInputStream(offset, length);
        }
        int size = count;
        int from = Math.min(offset, size);
        // Clamp to the valid bytes: ByteArrayInputStream alone would also expose the buffer's spare capacity.
        return new ByteArrayInputStream(data, from, Math.min(length, size - from));
    }

    private void completeReceive() {
        receiveComplete = true; // write barrier
        unparkReadThreads();
    }

    /** Records the first writer failure (if none is recorded yet), then completes the receive. */
    private void failReceive(RuntimeException failure) {
        if (exception == null) {
            exception = failure;
        }
        completeReceive();
    }

    /** Reader over the cache that blocks while the requested bytes have not been received yet. */
    private class AsyncInputStream extends InputStream {
        private int offset;

        /** Exclusive end of the readable range (saturated at Integer.MAX_VALUE). */
        private final int maxCount;

        AsyncInputStream(int offset, int length) {
            this.offset = offset;
            // Saturate instead of overflowing when callers pass Integer.MAX_VALUE as "everything".
            this.maxCount = length > Integer.MAX_VALUE - offset ? Integer.MAX_VALUE : offset + length;
        }

        @Override
        public int read() throws IOException {
            return waitUntilAvailableAndRead(() -> data[offset++] & 0xFF);
        }

        @Override
        public int read(@Nonnull byte[] b, int off, int len) throws IOException {
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
            }
            if (len == 0) {
                return 0; // InputStream contract: a zero-length read returns 0 without blocking
            }
            return waitUntilAvailableAndRead(() -> {
                int resultLen = Math.min(len, getActualCount() - offset);
                System.arraycopy(data, offset, b, off, resultLen);
                offset += resultLen;
                return resultLen;
            });
        }

        private int getActualCount() {
            return Math.min(maxCount, count);
        }

        /** Waits for data, rethrows a recorded failure, then reads via {@code supplier} or returns -1 at the end. */
        private int waitUntilAvailableAndRead(IntSupplier supplier) throws InterruptedIOException {
            waitUntilAvailable();
            if (exception != null) {
                throw exception;
            }
            return !noMoreData() ? supplier.getAsInt() : -1;
        }

        private boolean noMoreData() {
            return (offset == maxCount) || (receiveComplete && offset >= count);
        }

        /** True while the next byte is inside the range but has not been written and more may still come. */
        private boolean mustWait() {
            return offset < maxCount && offset >= count && !receiveComplete; // count, receiveComplete - read barrier
        }

        private void waitUntilAvailable() throws InterruptedIOException {
            if (!mustWait()) {
                return;
            }
            Thread current = Thread.currentThread();
            // Register before re-checking mustWait(): the writer updates count/receiveComplete before unparking,
            // so either it sees this thread in the list or this thread sees the new state.
            readThreads.add(current);
            try {
                Instant start = Instant.now();
                while (mustWait()) {
                    if (current.isInterrupted()) {
                        // parkNanos returns immediately for an interrupted thread; fail instead of spinning.
                        throw new InterruptedIOException("Interrupted while waiting for data");
                    }
                    Duration duration = new Duration(start, Instant.now());
                    if (duration.isLongerThan(TIMEOUT)) {
                        throw new IllegalStateException("No data for " + duration);
                    }

                    // This doesn't affect latency because producer thread calls unpark when new data is available.
                    LockSupport.parkNanos(PARK_WAIT_NANOS);
                }
            } finally {
                readThreads.remove(current);
            }
        }
    }

    /** Writer side: appends into {@link #data}, publishes via {@link #count} and wakes waiting readers. */
    private class AsyncOutputStream extends OutputStream {
        @Override
        public void write(int b) {
            write(() -> data[count] = (byte) b, 1);
        }

        @Override
        public void write(@Nonnull byte[] b, int off, int len) {
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
            }
            write(() -> System.arraycopy(b, off, data, count, len), len);
        }

        /** Completes the receive; readers then stop waiting for more bytes. */
        @Override
        public void close() {
            completeReceive();
        }

        private void write(Runnable store, int len) {
            if (len == 0) {
                return;
            }

            ensureExtraCapacity(len);
            store.run();
            count += len; // write barrier: publishes the stored bytes to readers
            unparkReadThreads();
        }

        private void ensureExtraCapacity(int extraCapacity) {
            int neededCapacity = count + extraCapacity;
            if (neededCapacity > data.length) {
                grow(neededCapacity);
            }
        }

        /** Grows geometrically (at least to {@code capacity}); the copy keeps every byte written so far. */
        private void grow(int capacity) {
            data = Arrays.copyOf(data, Math.max(data.length << 1, capacity));
        }
    }
}
