package ru.yandex.http.server.sync.util;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;

import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;

import ru.yandex.function.GenericAutoCloseable;
import ru.yandex.util.timesource.TimeSource;

/**
 * Server-side TLS transport on top of a {@link ThreadSafeSocketAdaptor}.
 *
 * <p>Encrypted bytes are read from / written to the socket's streams and
 * passed through the supplied {@link SSLEngine}, which is switched to server
 * mode when the handshake runs. The handshake is performed lazily, on the
 * first call to {@link #checkHandshake()}, {@link #inputStream()} or
 * {@link #outputStream()}.
 *
 * <p>Buffer conventions used throughout this class: {@code inEncrypted} and
 * {@code inPlain} are kept in "read" mode (flipped) between calls, while
 * {@code outPlain} and {@code outEncrypted} are kept in "write" mode.
 *
 * <p>Setting the system property {@code ru.yandex.http.server.sync.debug}
 * to {@code https} (case-insensitive) dumps engine and buffer state to
 * {@code System.err}.
 */
public class SSLSocketChannelIO implements GenericAutoCloseable<IOException> {
    private static final DateTimeFormatter DATE_FORMATTER =
        DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS");
    // Zero-length source used for handshake wraps (no application data).
    private static final ByteBuffer EMPTY_BUF = ByteBuffer.allocate(0);
    private static final boolean DEBUG_HTTPS =
        "https".equalsIgnoreCase(
            System.getProperty("ru.yandex.http.server.sync.debug"));
    private static final String OUTPUT_CLOSED = "Output closed";
    private static final int INT_MASK = 0xff;

    private final InputStream inputStream = new SSLInputStream();
    private final OutputStream outputStream = new SSLOutputStream();
    private final ThreadSafeSocketAdaptor socket;
    private final SSLEngine engine;
    private final InputStream socketInput;
    private final SocketChannelOutputStream socketOutput;
    private final ByteBuffer inEncrypted;
    private final ByteBuffer outEncrypted;
    private final ByteBuffer inPlain;
    private final ByteBuffer outPlain;
    // Set once the (single) handshake attempt has finished, successfully
    // or not.
    private volatile boolean initialized = false;
    private volatile boolean inboundClosed = false;
    private volatile boolean outboundClosed = false;

    /**
     * Creates a TLS wrapper around the given socket.
     *
     * <p>Buffers are sized from the engine's current session packet and
     * application buffer sizes. All of them start empty (limit 0).
     *
     * @param socket underlying socket providing raw input/output streams
     * @param engine engine used to encrypt and decrypt traffic
     * @throws IOException if the socket streams cannot be obtained
     */
    public SSLSocketChannelIO(
        final ThreadSafeSocketAdaptor socket,
        final SSLEngine engine)
        throws IOException
    {
        this.socket = socket;
        this.engine = engine;
        socketInput = socket.getInputStream();
        socketOutput = socket.getOutputStream();

        int packedBufferSize = engine.getSession().getPacketBufferSize();
        inEncrypted = ByteBuffer.allocate(packedBufferSize);
        inEncrypted.limit(0);
        outEncrypted = ByteBuffer.allocate(packedBufferSize);
        outEncrypted.limit(0);

        int applicationBufferSize =
            engine.getSession().getApplicationBufferSize();
        inPlain = ByteBuffer.allocate(applicationBufferSize);
        inPlain.limit(0);
        outPlain = ByteBuffer.allocate(applicationBufferSize);
        outPlain.limit(0);
    }

    /**
     * Runs the TLS handshake if it has not been attempted yet.
     *
     * <p>The flag is published only after the attempt completes. Concurrent
     * callers therefore block until the handshake is over instead of
     * obtaining the streams mid-handshake. A failed attempt is not retried:
     * its exception propagates to the caller that ran it, and later calls
     * return {@code false}.
     *
     * @return {@code true} if this call performed the handshake
     * @throws IOException if the handshake fails
     */
    public boolean checkHandshake() throws IOException {
        if (initialized) {
            return false;
        }
        synchronized (this) {
            if (initialized) {
                return false;
            }
            try {
                handshake();
            } finally {
                initialized = true;
            }
            return true;
        }
    }

    /**
     * Returns the decrypting input stream, performing the handshake first if
     * needed.
     *
     * @return stream of decrypted application data
     * @throws IOException if the handshake fails
     */
    public InputStream inputStream() throws IOException {
        checkHandshake();
        return inputStream;
    }

    /**
     * Returns the encrypting output stream, performing the handshake first
     * if needed.
     *
     * @return stream whose data is encrypted before being sent
     * @throws IOException if the handshake fails
     */
    public OutputStream outputStream() throws IOException {
        checkHandshake();
        return outputStream;
    }

    /**
     * Reports whether any input is held in local buffers, either already
     * decrypted or still encrypted.
     *
     * @return {@code true} if {@code inPlain} or {@code inEncrypted} has
     *     remaining bytes
     */
    public boolean hasBufferedData() {
        return inPlain.hasRemaining() || inEncrypted.hasRemaining();
    }

    // Debug dump of the engine/buffer state to System.err (only called when
    // DEBUG_HTTPS is set).
    private void info(final SSLEngineResult result, final String message) {
        System.err.println(
            DATE_FORMATTER.print(TimeSource.INSTANCE.currentTimeMillis()) + ' '
            + Thread.currentThread().getName() + '@' + socket
            + ':' + ' ' + message);
        if (result != null) {
            System.err.println("SSLEngine status: " + result.getStatus());
            System.err.println(
                "Handshake status: "
                + result.getHandshakeStatus());
        }
        System.err.println(
            "Current handshake status: " + engine.getHandshakeStatus());
        System.err.println("Buffers status:");
        System.err.println("\tinPlain:\t" + inPlain);
        System.err.println("\tinEncrypted:\t" + inEncrypted);
        System.err.println("\toutPlain:\t" + outPlain);
        System.err.println("\toutEncrypted:\t" + outEncrypted);
        System.err.println("\tEMPTY_BUF:\t" + EMPTY_BUF);
        System.err.println("inboundClosed: " + inboundClosed);
        System.err.println("outboundClosed: " + outboundClosed);
        System.err.println();
    }

    /**
     * Closes both directions of the TLS session.
     *
     * <p>Buffered application data is flushed <em>before</em>
     * {@link SSLEngine#closeOutbound()} is called. Once the outbound side is
     * closed, {@code wrap()} reports {@code CLOSED} and no longer consumes
     * application data, so flushing afterwards would silently drop it.
     *
     * @throws IOException if flushing pending output fails
     */
    @Override
    public void close() throws IOException {
        if (engine.isInboundDone()) {
            engine.closeInbound();
        }
        inboundClosed = true;
        try {
            outputStream.flush();
        } finally {
            engine.closeOutbound();
            outboundClosed = true;
        }
    }

    /**
     * Reads from the raw socket stream into {@code bb}'s backing array,
     * starting at its position and advancing the position by the bytes read.
     *
     * @param bb heap buffer in write mode
     * @return number of bytes read, or -1 on end of stream
     */
    @SuppressWarnings("ByteBufferBackingArray")
    private int read(final ByteBuffer bb) throws IOException {
        int position = bb.position();
        int read = socketInput.read(bb.array(), position, bb.remaining());
        if (read != -1) {
            bb.position(position + read);
        }
        return read;
    }

    // sun.security.pkcs11.wrapper.PKCS11Exception is re-thrown as plain
    // RuntimeException in sun.security.ssl.Handshaker#checkThrown
    private static SSLException unwrapException(final RuntimeException e) {
        Throwable cause = e.getCause();
        if (cause == null) {
            cause = e;
        }
        return new SSLException(cause);
    }

    // SSLEngine.wrap with RuntimeExceptions converted to SSLException.
    private SSLEngineResult wrap(final ByteBuffer src, final ByteBuffer dst)
        throws SSLException
    {
        try {
            // Access already synchronized by SSLEngineImpl
            return engine.wrap(src, dst);
        } catch (RuntimeException e) {
            throw unwrapException(e);
        }
    }

    // SSLEngine.unwrap with RuntimeExceptions converted to SSLException.
    private SSLEngineResult unwrap(final ByteBuffer src, final ByteBuffer dst)
        throws SSLException
    {
        try {
            // Access already synchronized by SSLEngineImpl
            return engine.unwrap(src, dst);
        } catch (RuntimeException e) {
            throw unwrapException(e);
        }
    }

    /**
     * Produces one outgoing handshake record and writes it fully to the
     * socket. Leaves {@code outEncrypted} cleared (write mode).
     */
    private void handshakeWrap() throws IOException {
        outEncrypted.clear();
        if (DEBUG_HTTPS) {
            info(null, "Begin handshake wrap");
        }
        SSLEngineResult result = wrap(EMPTY_BUF, outEncrypted);
        if (DEBUG_HTTPS) {
            info(result, "Hanshake wrap result");
        }
        switch (result.getStatus()) {
            case OK:
                outEncrypted.flip();
                socketOutput.writeFully(outEncrypted);
                outEncrypted.clear();
                if (DEBUG_HTTPS) {
                    info(null, "Output written completedly");
                }
                break;
            case BUFFER_OVERFLOW:
                throw new SSLException("Why this buffer is not enough?");
            case CLOSED:
                throw new IOException("Socket channel closed");
            default: // BUFFER_UNDERFLOW
                throw new SSLException("Should not happen");
        }
    }

    /**
     * Feeds one incoming handshake record to the engine, reading from the
     * socket when {@code inEncrypted} is empty or holds an incomplete record.
     *
     * @throws EOFException if the socket reaches end of stream
     */
    private void handshakeUnwrap() throws IOException {
        if (!inEncrypted.hasRemaining()) {
            inEncrypted.clear();
            if (read(inEncrypted) == -1) {
                throw new EOFException();
            }
            inEncrypted.flip();
        }
        inPlain.compact();
        if (DEBUG_HTTPS) {
            info(null, "Unwrap data read");
        }
        SSLEngineResult result = unwrap(inEncrypted, inPlain);
        inEncrypted.compact();
        inPlain.flip();
        if (DEBUG_HTTPS) {
            info(result, "Hanshake unwrap result");
        }
        switch (result.getStatus()) {
            case OK:
                inEncrypted.flip();
                break;
            case BUFFER_UNDERFLOW:
                // try read more data
                if (read(inEncrypted) == -1) {
                    throw new EOFException();
                }
                inEncrypted.flip();
                if (DEBUG_HTTPS) {
                    info(null, "Additional data read for unwrap");
                }
                break;
            case CLOSED:
                throw new IOException("SSL session closed on unwrap");
            default: // BUFFER_OVERFLOW
                throw new SSLException(
                    "Application buffer overflow on unwrap");
        }
    }

    // Runs every pending engine task on the calling thread.
    private void runDelegatedTask() throws SSLException {
        try {
            // Access already synchronized by SSLEngineImpl.
            Runnable r = engine.getDelegatedTask();
            while (r != null) {
                if (DEBUG_HTTPS) {
                    info(null, "Running delegated task: " + r);
                    long start = TimeSource.INSTANCE.currentTimeMillis();
                    r.run();
                    long end = TimeSource.INSTANCE.currentTimeMillis();
                    info(null, "Time taken: " + (end - start));
                } else {
                    r.run();
                }
                r = engine.getDelegatedTask();
            }
        } catch (RuntimeException e) {
            throw unwrapException(e);
        }
    }

    /**
     * Drives a server-mode handshake until the engine reports it is no
     * longer handshaking. Records that follow the handshake and were already
     * read stay in {@code inEncrypted}.
     */
    private void handshake() throws IOException {
        engine.setUseClientMode(false);
        engine.beginHandshake();
        while (true) {
            switch (engine.getHandshakeStatus()) {
                case NEED_WRAP:
                    handshakeWrap();
                    break;
                case NEED_UNWRAP:
                    handshakeUnwrap();
                    break;
                case NEED_TASK:
                    runDelegatedTask();
                    break;
                default: // NOT_HANDSHAKING, FINISHED
                    return;
            }
        }
    }

    /**
     * Ensures {@code inPlain} has decrypted bytes, reading and unwrapping
     * socket data until it does.
     *
     * <p>NOTE(review): {@code result.getHandshakeStatus()} is not inspected
     * here, so post-handshake messages that would require NEED_WRAP /
     * NEED_TASK processing are not driven — confirm this is acceptable.
     *
     * @return {@code true} if decrypted data is available, {@code false} if
     *     the inbound side was closed locally
     * @throws EOFException if the socket reaches end of stream
     */
    private boolean fillInputBuffer() throws IOException {
        if (DEBUG_HTTPS) {
            info(null, "Trying to fill input buffer");
        }
        while (!inboundClosed && !inPlain.hasRemaining()) {
            if (!inEncrypted.hasRemaining()) {
                inEncrypted.clear();
                if (read(inEncrypted) == -1) {
                    throw new EOFException();
                }
                inEncrypted.flip();
            }
            if (DEBUG_HTTPS) {
                info(null, "Trying to decode input buffer");
            }
            inPlain.clear();
            SSLEngineResult result = unwrap(inEncrypted, inPlain);
            inEncrypted.compact();
            inPlain.flip();
            if (DEBUG_HTTPS) {
                info(result, "InputBuffer decoded");
            }
            switch (result.getStatus()) {
                case OK:
                    inEncrypted.flip();
                    break;
                case BUFFER_UNDERFLOW:
                    // try read more data
                    if (read(inEncrypted) == -1) {
                        throw new EOFException();
                    }
                    inEncrypted.flip();
                    if (DEBUG_HTTPS) {
                        info(null, "Additional data read for decode");
                    }
                    break;
                case CLOSED:
                    throw new IOException("SSL session closed");
                default: // Buffer overflow
                    throw new SSLException(
                        "Application buffer overflow");
            }
        }
        return inPlain.hasRemaining();
    }

    /**
     * Encrypts everything buffered in {@code outPlain}, writing encrypted
     * bytes to the socket as space in {@code outEncrypted} runs out. Leaves
     * {@code outPlain} cleared. Encrypted bytes may still remain in
     * {@code outEncrypted}; {@link #flushOutputBuffer()} sends them.
     */
    private void writeOutputBuffer() throws IOException {
        if (DEBUG_HTTPS) {
            info(null, "Trying to write output buffer");
        }
        while (!outboundClosed && outPlain.position() > 0) {
            outPlain.flip();
            SSLEngineResult result = wrap(outPlain, outEncrypted);
            outPlain.compact();
            if (DEBUG_HTTPS) {
                info(result, "Output buffer wrap result");
            }
            switch (result.getStatus()) {
                case OK:
                case BUFFER_OVERFLOW:
                    outEncrypted.flip();
                    socketOutput.write(outEncrypted);
                    outEncrypted.compact();
                    break;
                case CLOSED:
                    flushOutputBuffer();
                    outboundClosed = true;
                    break;
                default: // BUFFER_UNDERFLOW
                    throw new SSLException("Buffer underflow");
            }
        }
        outPlain.clear();
    }

    // Writes all pending encrypted bytes to the socket unless output is
    // already closed.
    private void flushOutputBuffer() throws IOException {
        if (!outboundClosed) {
            outEncrypted.flip();
            socketOutput.writeFully(outEncrypted);
            outEncrypted.clear();
        }
    }

    /** Input stream serving decrypted application data from inPlain. */
    private class SSLInputStream extends InputStream {
        @Override
        public int available() {
            return inPlain.remaining();
        }

        @Override
        public void close() throws IOException {
            SSLSocketChannelIO.this.close();
        }

        @Override
        public void mark(final int readLimit) {
        }

        @Override
        public boolean markSupported() {
            return false;
        }

        @Override
        public int read() throws IOException {
            if (fillInputBuffer()) {
                return inPlain.get() & INT_MASK;
            } else {
                return -1;
            }
        }

        @Override
        public int read(final byte[] buf) throws IOException {
            return read(buf, 0, buf.length);
        }

        /**
         * Copies up to {@code len} decrypted bytes into {@code buf}.
         * A zero-length request returns 0 immediately without blocking, as
         * required by the {@link InputStream} contract.
         */
        @Override
        public int read(final byte[] buf, final int off, final int len)
            throws IOException
        {
            if (len == 0) {
                return 0;
            }
            if (fillInputBuffer()) {
                int length = Math.min(inPlain.remaining(), len);
                inPlain.get(buf, off, length);
                return length;
            } else {
                return -1;
            }
        }

        @Override
        public void reset() throws IOException {
            throw new IOException("mark/reset not supported");
        }

        /**
         * Discards at most {@code n} decrypted bytes by advancing
         * {@code inPlain}'s position. It never consumes more than requested.
         *
         * @return number of bytes actually skipped (0 if {@code n <= 0})
         */
        @Override
        public long skip(final long n) throws IOException {
            long left = n;
            while (left > 0L && fillInputBuffer()) {
                int step = (int) Math.min(left, inPlain.remaining());
                inPlain.position(inPlain.position() + step);
                left -= step;
            }
            return n - left;
        }
    }

    /** Output stream buffering plain data in outPlain before encryption. */
    private class SSLOutputStream extends OutputStream {
        @Override
        public void close() throws IOException {
            flush();
            outboundClosed = true;
            SSLSocketChannelIO.this.close();
        }

        @Override
        public void flush() throws IOException {
            writeOutputBuffer();
            flushOutputBuffer();
        }

        @Override
        public void write(final int b) throws IOException {
            if (outboundClosed) {
                throw new IOException(OUTPUT_CLOSED);
            }
            while (!outPlain.hasRemaining()) {
                writeOutputBuffer();
            }
            if (DEBUG_HTTPS) {
                info(null, "Trying to put one byte to output buffer");
            }
            outPlain.put((byte) b);
        }

        @Override
        public void write(final byte[] buf) throws IOException {
            write(buf, 0, buf.length);
        }

        @Override
        public void write(final byte[] buf, final int off, final int len)
            throws IOException
        {
            if (outboundClosed) {
                throw new IOException(OUTPUT_CLOSED);
            }
            int offset = off;
            int length = len;
            while (length > 0) {
                int remaining = outPlain.remaining();
                if (remaining > 0) {
                    int transferSize = Math.min(length, remaining);
                    if (DEBUG_HTTPS) {
                        info(
                            null,
                            "Trying to put " + transferSize
                            + " bytes to output buffer");
                    }
                    outPlain.put(buf, offset, transferSize);
                    offset += transferSize;
                    length -= transferSize;
                } else {
                    writeOutputBuffer();
                }
            }
        }
    }
}

