package ru.yandex.http.server.sync.util;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;

import ru.yandex.util.timesource.TimeSource;

/**
 * {@link OutputStream} over a {@link SocketChannel} that performs each write
 * with the channel temporarily switched to non-blocking mode, polling (with
 * 1 ms sleeps) until data can be written or the inactivity timeout expires.
 *
 * <p>The channel is switched back to blocking mode after every write call.
 * Instances are not safe for concurrent use by multiple writing threads: the
 * cached array wrapper and single-byte buffer are mutated without
 * synchronization.
 */
public class SocketChannelOutputStream extends OutputStream {
    private final Socket socket;
    private final SocketChannel channel;
    // Inactivity timeout compared against TimeSource.currentTimeMillis()
    // deltas (presumably milliseconds); a non-positive value disables it.
    private volatile int timeout;
    // Wrapper around the most recently written array, reused while callers
    // keep passing the same array to avoid re-wrapping on every call.
    // NOTE: this retains a reference to the caller's array until another
    // array is written.
    private ByteBuffer wrapper = null;
    private byte[] wrappedArray = null;
    // Lazily allocated scratch buffer for write(int).
    private byte[] singleByte = null;

    /**
     * Creates a stream writing to {@code channel}.
     *
     * @param socket socket closed by {@link #close()}
     * @param channel channel data is written to
     * @param timeout initial inactivity timeout; non-positive disables it
     */
    public SocketChannelOutputStream(
        final Socket socket,
        final SocketChannel channel,
        final int timeout)
    {
        this.socket = socket;
        this.channel = channel;
        this.timeout = timeout;
    }

    /**
     * Updates the inactivity timeout used by subsequent (and in-progress)
     * writes.
     *
     * @param timeout new timeout; non-positive disables it
     */
    public void setSoTimeout(final int timeout) {
        this.timeout = timeout;
    }

    /**
     * Writes all remaining bytes of {@code bb} to the channel.
     *
     * <p>The timeout is measured from the last write that made progress, so a
     * slow but steadily draining peer does not trigger it.
     *
     * @param bb buffer to drain; its position is advanced by the bytes written
     * @throws SocketTimeoutException if no progress was made within the
     *     timeout; {@code bytesTransferred} holds the bytes already written
     * @throws InterruptedIOException if the thread was interrupted while
     *     waiting; {@code bytesTransferred} holds the bytes already written
     * @throws IOException on any other channel error
     */
    public void writeFully(final ByteBuffer bb) throws IOException {
        final int remaining = bb.remaining();
        if (remaining == 0) {
            return;
        }
        channel.configureBlocking(false);
        try {
            long lastActivity = TimeSource.INSTANCE.currentTimeMillis();
            while (true) {
                if (channel.write(bb) > 0) {
                    if (!bb.hasRemaining()) {
                        return;
                    }
                    lastActivity = TimeSource.INSTANCE.currentTimeMillis();
                } else {
                    awaitWritable(lastActivity, remaining - bb.remaining());
                }
            }
        } finally {
            restoreBlocking();
        }
    }

    /**
     * Writes at least one byte of {@code bb} to the channel (a partial
     * write, unlike {@link #writeFully(ByteBuffer)}), waiting until the
     * channel accepts data or the timeout expires.
     *
     * @param bb buffer to write from; its position is advanced by the bytes
     *     written
     * @throws SocketTimeoutException if nothing could be written within the
     *     timeout
     * @throws InterruptedIOException if the thread was interrupted while
     *     waiting
     * @throws IOException on any other channel error
     */
    public void write(final ByteBuffer bb)
        throws IOException
    {
        final int remaining = bb.remaining();
        if (remaining == 0) {
            return;
        }
        channel.configureBlocking(false);
        try {
            final long lastActivity = TimeSource.INSTANCE.currentTimeMillis();
            while (channel.write(bb) == 0) {
                awaitWritable(lastActivity, remaining - bb.remaining());
            }
        } finally {
            restoreBlocking();
        }
    }

    @Override
    public void write(final byte[] b) throws IOException {
        write(b, 0, b.length);
    }

    /**
     * Writes {@code len} bytes of {@code b} starting at {@code off}, blocking
     * (by polling) until all of them are written.
     *
     * @throws IndexOutOfBoundsException if {@code off}/{@code len} are out of
     *     the bounds of {@code b}
     */
    @Override
    public void write(final byte[] b, final int off, final int len)
        throws IOException
    {
        final int limit = off + len;
        // limit < 0 catches int overflow of off + len.
        if (off < 0 || len < 0 || limit < 0 || limit > b.length) {
            throw new IndexOutOfBoundsException();
        }
        if (len == 0) {
            return;
        }
        if (b != wrappedArray) {
            wrapper = ByteBuffer.wrap(b);
            wrappedArray = b;
        }
        // Set limit before position: position must never exceed the limit.
        wrapper.limit(limit);
        wrapper.position(off);
        writeFully(wrapper);
    }

    @Override
    public void write(final int b) throws IOException {
        if (singleByte == null) {
            singleByte = new byte[1];
        }
        singleByte[0] = (byte) b;
        write(singleByte, 0, 1);
    }

    /** Closes the underlying socket. */
    @Override
    public void close() throws IOException {
        socket.close();
    }

    /**
     * Called after a write attempt made no progress: fails if the inactivity
     * timeout has elapsed since {@code lastActivity}, otherwise sleeps 1 ms.
     *
     * @param lastActivity time of the last progress (TimeSource clock)
     * @param transferred bytes written so far by the current call, reported
     *     via {@code bytesTransferred} on the thrown exception
     */
    private void awaitWritable(final long lastActivity, final int transferred)
        throws InterruptedIOException
    {
        // Read the volatile field once so the check is self-consistent.
        final int currentTimeout = timeout;
        if (currentTimeout > 0
            && TimeSource.INSTANCE.currentTimeMillis() - lastActivity
                > currentTimeout)
        {
            SocketTimeoutException e = new SocketTimeoutException(
                "Write timed out after " + currentTimeout
                + " ms without progress");
            e.bytesTransferred = transferred;
            throw e;
        }
        try {
            Thread.sleep(1L);
        } catch (InterruptedException e) {
            // Restore the interrupt flag for callers up the stack.
            Thread.currentThread().interrupt();
            InterruptedIOException ex = new InterruptedIOException();
            ex.bytesTransferred = transferred;
            ex.initCause(e);
            throw ex;
        }
    }

    /**
     * Best-effort switch of the channel back to blocking mode; failures
     * (e.g. the channel was closed meanwhile) are ignored so they do not mask
     * the outcome of the write itself.
     */
    private void restoreBlocking() {
        try {
            channel.configureBlocking(true);
        } catch (IOException | RuntimeException ignored) {
            // Ignore: nothing useful can be done here.
        }
    }
}

