package ru.yandex.http.server.async;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ThreadFactory;
import java.util.logging.Level;

import org.apache.http.impl.nio.reactor.ChannelEntry;
import org.apache.http.impl.nio.reactor.DefaultListeningIOReactor;
import org.apache.http.impl.nio.reactor.IOReactorConfig;
import org.apache.http.impl.nio.reactor.SessionRequestImpl;
import org.apache.http.nio.reactor.IOReactorException;
import org.apache.http.nio.reactor.ListenerEndpoint;

import ru.yandex.function.GenericAutoCloseable;
import ru.yandex.http.util.request.function.InetAddressValue;
import ru.yandex.http.util.server.HttpServer;
import ru.yandex.http.util.server.ImmutableBaseServerConfig;
import ru.yandex.http.util.server.Limiter;
import ru.yandex.http.util.server.LimiterResult;
import ru.yandex.logger.PrefixedLogger;
import ru.yandex.util.timesource.TimeSource;

/**
 * {@link DefaultListeningIOReactor} that gates every accepted channel through
 * the configured connections {@link Limiter} before registering it.
 *
 * <p>Registered connections are tracked through weak references. Entries
 * whose connection was garbage collected or reports itself closed are
 * removed. This cleanup runs at most once per {@code cleanupInterval}, or
 * after {@code maxGarbageConnections} channels have been added since the
 * previous cleanup.
 *
 * <p>NOTE(review): {@link #connections}, {@link #lastCleanupTime} and
 * {@link #addedSinceLastCleanup} are not synchronized. This assumes
 * {@link #addChannel} is only invoked from a single reactor thread; confirm.
 */
public class LimitingListeningIOReactor extends DefaultListeningIOReactor {
    /**
     * Placeholder address passed to {@link SessionRequestImpl}. The actual
     * peer is identified by the channel carried in the {@link ChannelEntry}.
     */
    private static final SocketAddress FAKE_ADDRESS = new InetSocketAddress(0);

    // Weak references let collected connections drop out without an explicit
    // close notification; cleanupConnections() prunes the dead entries.
    @SuppressWarnings("JdkObsolete")
    private final List<WeakReference<ActiveConnection>> connections =
        new LinkedList<>();
    private final HttpServer<?, ?> server;
    private final Limiter connectionsLimiter;
    // When true, connections over the limit are discarded immediately;
    // otherwise addChannel sleeps and retries the limiter.
    private final boolean rejectConnectionsOverLimit;
    private final int maxGarbageConnections;
    // Taken from config.timerResolution(); it is used both as the cleanup
    // period and as the sleep between limiter retries.
    private final long cleanupInterval;
    private final PrefixedLogger logger;
    private long lastCleanupTime = TimeSource.INSTANCE.currentTimeMillis();
    private int addedSinceLastCleanup = 0;

    /**
     * Creates the reactor.
     *
     * @param config server configuration. It supplies the worker count,
     *     socket timeout, select interval and cleanup interval (both
     *     {@code timerResolution()}), backlog, linger, the connections
     *     limiter, the reject-over-limit flag, the garbage threshold and the
     *     logger.
     * @param threadFactory factory for the reactor's I/O threads
     * @param server server notified about incoming, limited and discarded
     *     channels
     * @throws IOReactorException if the underlying reactor cannot be created
     */
    public LimitingListeningIOReactor(
        final ImmutableBaseServerConfig config,
        final ThreadFactory threadFactory,
        final HttpServer<?, ?> server)
        throws IOReactorException
    {
        super(
            IOReactorConfig
                .custom()
                .setIoThreadCount(config.workers())
                .setSoTimeout(config.timeout())
                .setSelectInterval(config.timerResolution())
                .setBacklogSize(config.backlog())
                .setSoLinger(config.linger())
                .setSoReuseAddress(true)
                .setTcpNoDelay(true)
                .build(),
            threadFactory);
        this.server = server;
        connectionsLimiter = config.connectionsLimiter();
        rejectConnectionsOverLimit = config.rejectConnectionsOverLimit();
        maxGarbageConnections = config.maxGarbageConnections();
        cleanupInterval = config.timerResolution();
        logger = config.loggers().preparedLoggers().asterisk();
    }

    /**
     * Prunes the tracked connection list.
     *
     * <p>It removes entries whose referent has been garbage collected. It
     * also removes entries whose connection reports {@code isClosed()}, and
     * calls {@code connectionEvicted()} on those. It then resets the cleanup
     * timestamp and the added-since-cleanup counter. A summary is logged at
     * FINE when something was removed or the list is still at or above
     * {@code maxGarbageConnections}.
     *
     * @return {@code true} if at least one entry was removed
     */
    private boolean cleanupConnections() {
        long start = TimeSource.INSTANCE.currentTimeMillis();
        int cleaned = 0;
        Iterator<WeakReference<ActiveConnection>> iter =
            connections.iterator();
        while (iter.hasNext()) {
            ActiveConnection conn = iter.next().get();
            if (conn == null) {
                iter.remove();
                ++cleaned;
            } else if (conn.isClosed()) {
                iter.remove();
                conn.connectionEvicted();
                ++cleaned;
            }
        }
        lastCleanupTime = TimeSource.INSTANCE.currentTimeMillis();
        addedSinceLastCleanup = 0;
        if ((cleaned > 0
            || connections.size() >= maxGarbageConnections)
            && logger.isLoggable(Level.FINE))
        {
            logger.fine("Connections clean up completed, cleaned: "
                + cleaned + '/' + connections.size()
                + ", time taken: " + (lastCleanupTime - start) + " ms");
        }
        return cleaned > 0;
    }

    /**
     * Builds the per-key limiter key from the channel's remote address.
     *
     * <p>It returns {@code null} when the address cannot be obtained or is
     * not an {@link InetSocketAddress}. That is the same key used when
     * per-key limits are disabled, so the connection is still limited, just
     * not per address.
     *
     * @param channel accepted channel
     * @return limiter key, or {@code null} if no address-based key is
     *     available
     */
    private InetAddressValue remoteAddressKey(final SocketChannel channel) {
        SocketAddress remote;
        try {
            remote = channel.getRemoteAddress();
        } catch (IOException e) {
            // Best effort: fall back to the keyless limit, but leave a trace
            // instead of silently swallowing the failure.
            if (logger.isLoggable(Level.FINE)) {
                logger.fine(
                    "Failed to obtain remote address for limiter key: " + e);
            }
            return null;
        }
        // getRemoteAddress() may return null (unconnected channel) or a
        // non-inet address. Neither can be cast safely.
        if (remote instanceof InetSocketAddress) {
            return new InetAddressValue(
                ((InetSocketAddress) remote).getAddress());
        }
        return null;
    }

    /**
     * Admits an accepted channel through the connections limiter and then
     * registers it with the reactor.
     *
     * <p>If the limiter rejects the channel, any releaser it returned is
     * closed and the server is notified via {@code onLimitedConnection}.
     * Then one of two things happens:
     * <ul>
     * <li>If {@code rejectConnectionsOverLimit} is set, the channel is
     * discarded.</li>
     * <li>Otherwise the thread sleeps {@code cleanupInterval} ms and retries.
     * If it is interrupted while sleeping, the channel is discarded and the
     * interrupt status is restored.</li>
     * </ul>
     *
     * <p>Once admitted, the channel is wrapped in an {@link ActiveConnection}
     * holding the limiter's releaser, registered via the superclass, and
     * tracked for later cleanup.
     *
     * @param entry entry carrying the accepted channel
     */
    @Override
    public void addChannel(final ChannelEntry entry) {
        SocketChannel channel = entry.getChannel();
        server.onIncomingConnection(channel);
        long now = TimeSource.INSTANCE.currentTimeMillis();
        if (now - lastCleanupTime > cleanupInterval
            || addedSinceLastCleanup >= maxGarbageConnections)
        {
            cleanupConnections();
        }
        InetAddressValue limiterKey = null;
        if (connectionsLimiter.perKeyLimitsEnabled()) {
            limiterKey = remoteAddressKey(channel);
        }
        GenericAutoCloseable<RuntimeException> resourcesReleaser = null;
        while (true) {
            // NOTE(review): the meaning of the -1L argument is defined by
            // Limiter.acquire; confirm there.
            LimiterResult limiterResult =
                connectionsLimiter.acquire(-1L, limiterKey);
            String limiterMessage = limiterResult.message();
            GenericAutoCloseable<RuntimeException> releaser =
                limiterResult.resourcesReleaser();
            if (limiterMessage == null) {
                // Admitted: ownership of the releaser moves to the
                // ActiveConnection below.
                resourcesReleaser = releaser;
                break;
            } else {
                if (releaser != null) {
                    releaser.close();
                }
                server.onLimitedConnection(channel, limiterMessage);
                if (rejectConnectionsOverLimit) {
                    server.discardChannel(channel);
                    return;
                }
            }
            try {
                Thread.sleep(cleanupInterval);
            } catch (InterruptedException e) {
                server.discardChannel(channel);
                Thread.currentThread().interrupt();
                return;
            }
        }
        ActiveConnection conn =
            new ActiveConnection(channel, resourcesReleaser);
        super.addChannel(
            new ChannelEntry(
                channel,
                new SessionRequestImpl(
                    FAKE_ADDRESS,
                    null,
                    conn,
                    null)));
        connections.add(new WeakReference<>(conn));
        ++addedSinceLastCleanup;
    }

    /**
     * Starts listening on {@code address} and blocks until the endpoint
     * request completes.
     *
     * <p>{@code processEvents(0)} is called right after {@code listen}. This
     * is presumably so the pending listen request is handled without waiting
     * for the reactor loop; NOTE(review): confirm.
     *
     * @param address local address to bind
     * @return the bound listener endpoint
     * @throws IOReactorException if waiting is interrupted (the thread's
     *     interrupt status is restored) or the endpoint reports a failure
     */
    public ListenerEndpoint bind(final SocketAddress address)
        throws IOReactorException
    {
        ListenerEndpoint result = listen(address);
        processEvents(0);
        try {
            result.waitFor();
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers before translating it.
            Thread.currentThread().interrupt();
            throw new IOReactorException("Endpoint binding interrupted", e);
        }
        Exception e = result.getException();
        if (e != null) {
            throw new IOReactorException("Endpoint binding failed", e);
        }
        return result;
    }
}

