package ru.yandex.ljinx;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.logging.Level;

import org.apache.http.Header;
import org.apache.http.HeaderIterator;
import org.apache.http.HttpException;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.RequestLine;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.routing.HttpRoute;
import org.apache.http.nio.protocol.HttpAsyncExchange;
import org.apache.http.nio.protocol.HttpAsyncRequestHandler;
import org.apache.http.protocol.HttpContext;

import ru.yandex.collection.LongPair;
import ru.yandex.concurrent.TimeFrameQueue;
import ru.yandex.http.proxy.AbstractProxySessionCallback;
import ru.yandex.http.proxy.BasicProxySession;
import ru.yandex.http.proxy.HttpResponseSendingCallback;
import ru.yandex.http.proxy.ProxySession;
import ru.yandex.http.util.BadRequestException;
import ru.yandex.http.util.DuplexFutureCallback;
import ru.yandex.http.util.HeaderUtils;
import ru.yandex.http.util.HttpHostAppender;
import ru.yandex.http.util.HttpHostComparator;
import ru.yandex.http.util.RequestErrorType;
import ru.yandex.http.util.ServiceUnavailableException;
import ru.yandex.http.util.YandexHeaders;
import ru.yandex.http.util.nio.BasicAsyncRequestProducerGenerator;
import ru.yandex.http.util.nio.BasicAsyncResponseConsumerFactory;
import ru.yandex.http.util.nio.BasicAsyncResponseProducerGenerator;
import ru.yandex.http.util.nio.EntityGenerator;
import ru.yandex.http.util.nio.NByteArrayEntityGeneratorAsyncConsumer;
import ru.yandex.http.util.nio.client.AsyncClient;
import ru.yandex.http.util.nio.client.FilterRequestsListener;
import ru.yandex.http.util.nio.client.RequestsListener;
import ru.yandex.http.util.request.function.RequestFunctionValue;
import ru.yandex.parser.searchmap.ImmutableSearchMapConfig;
import ru.yandex.parser.searchmap.SearchMap;
import ru.yandex.util.string.StringUtils;
import ru.yandex.util.timesource.TimeSource;

public class ProxyPassHandler
    implements HttpAsyncRequestHandler<EntityGenerator>
{
    // Pre-boxed samples pushed into the siblingsRemovals statistics queue.
    private static final Long ZERO = Long.valueOf(0L);
    private static final Long ONE = Long.valueOf(1L);

    // Headers attached to responses served from cache / from upstream.
    private static final Header CACHE_HIT =
        HeaderUtils.createHeader(YandexHeaders.X_CACHE_STATUS, "HIT");
    private static final Header CACHE_MISS =
        HeaderUtils.createHeader(YandexHeaders.X_CACHE_STATUS, "MISS");
    // Added to requests forwarded to siblings; its presence on an incoming
    // request makes handle() skip the siblings fan-out.
    private static final Header IGNORE_SIBLINGS =
        HeaderUtils.createHeader(YandexHeaders.X_LJINX_IGNORE_SIBLINGS, "Yes");
    // Default classification, except that HTTP errors are reported as
    // non-retriable. Used for the siblings client only.
    private static final Function<Exception, RequestErrorType>
        ERROR_CLASSIFIER = e -> {
            RequestErrorType classified =
                RequestErrorType.ERROR_CLASSIFIER.apply(e);
            return classified == RequestErrorType.HTTP
                ? RequestErrorType.NON_RETRIABLE
                : classified;
        };
    // One pre-built X_CACHE_HIT_TYPE header per cache type.
    private static final Map<CacheResponse.CacheType, Header> HIT_TYPES;

    static {
        Map<CacheResponse.CacheType, Header> headers =
            new EnumMap<>(CacheResponse.CacheType.class);
        for (CacheResponse.CacheType cacheType
            : CacheResponse.CacheType.values())
        {
            Header header = HeaderUtils.createHeader(
                YandexHeaders.X_CACHE_HIT_TYPE,
                cacheType.toString());
            headers.put(cacheType, header);
        }
        HIT_TYPES = headers;
    }

    // Cache key -> callbacks waiting for the single in-flight upstream
    // request for that key. Entries are added in proxyPass() and removed by
    // ProxyPassCallback once the upstream request completes or fails.
    private final Map<RequestFunctionValue, LockedWaiters> keyLockMap =
        new ConcurrentHashMap<>();
    // Sibling host -> timestamp of its last recorded failure, expressed in
    // units of siblingsConfig.errorsTimeFrameGranularity() milliseconds.
    // Written by FailureCallback, read and expired by shuffleHosts().
    private final Map<HttpHost, Long> siblingLastFailure =
        new ConcurrentHashMap<>();
    private final ImmutableProxyPassConfig config;
    private final Ljinx ljinx;
    // Client used for requests to the proxied upstream.
    private final AsyncClient client;
    // Both siblings* fields are null when no siblings are configured.
    private final ImmutableSiblingsConfig siblingsConfig;
    private final AsyncClient siblingsClient;
    private final CacheStorage cacheStorage;
    // Default upstream host, used when the request carries no absolute URI.
    private final HttpHost host;
    // Lower-cased names of response headers which forbid caching.
    private final Set<String> noCacheHeadersLC;
    // Statistics queues owned by cacheStorage.
    private final TimeFrameQueue<CacheResponse.CacheType> cacheStats;
    private final TimeFrameQueue<Long> siblingsRemovals;
    // Null when no search map is configured.
    private final SearchMap searchMap;

    /**
     * Creates a handler for a single proxy pass location.
     *
     * @param config proxy pass configuration
     * @param ljinx owning server, supplies clients, cache storages and logger
     * @param name location name, used as a suffix of the client names
     * @throws IOException if the configured cache storage does not exist or
     *     the search map cannot be built
     */
    public ProxyPassHandler(
        final ImmutableProxyPassConfig config,
        final Ljinx ljinx,
        final String name)
        throws IOException
    {
        this.config = config;
        this.ljinx = ljinx;
        client = ljinx.client("proxy-pass-client-" + name, config);

        // Siblings are optional: without a config there is no client either
        siblingsConfig = config.siblingsConfig();
        if (siblingsConfig != null) {
            siblingsClient = ljinx.client(
                "siblings-client-" + name,
                siblingsConfig,
                ERROR_CLASSIFIER);
        } else {
            siblingsClient = null;
        }

        host = config.host();

        cacheStorage = ljinx.cacheStorage(config.cacheStorage());
        if (cacheStorage == null) {
            throw new IOException(
                "CacheStorage with name <"
                + config.cacheStorage() + "> is not configured");
        }

        // Header names are matched case-insensitively, normalize them once
        Set<String> lowerCasedHeaders = new HashSet<>();
        for (String headerName: config.noCacheHeaders()) {
            lowerCasedHeaders.add(headerName.toLowerCase(Locale.ROOT));
        }
        noCacheHeadersLC = lowerCasedHeaders;
        ljinx.logger().fine(lowerCasedHeaders.toString());

        cacheStats = cacheStorage.cacheStats();
        siblingsRemovals = cacheStorage.siblingsRemovals();

        ImmutableSearchMapConfig mapConfig = config.searchMapConfig();
        if (mapConfig != null) {
            try {
                searchMap = mapConfig.build();
            } catch (ParseException e) {
                throw new IOException(e);
            }
        } else {
            searchMap = null;
        }
    }

    /**
     * @return the configuration this handler was created with
     */
    public ImmutableProxyPassConfig config() {
        return this.config;
    }

    /**
     * @return the search map built from the configuration, or {@code null}
     *     if no search map is configured
     */
    public SearchMap searchMap() {
        return this.searchMap;
    }

    /**
     * @return the cache storage backing this handler, never {@code null}
     */
    public CacheStorage cacheStorage() {
        return this.cacheStorage;
    }

    /**
     * Reorders {@code hosts} in place so that hosts without recent failures
     * come first.
     *
     * <p>When failure tracking is enabled ({@code errorsTimeFrame > 0}) and
     * there is more than one host, every host is ranked by the timestamp of
     * its last recorded failure (0 if there is none, or if it is older than
     * the time frame, in which case the record is dropped). Hosts are sorted
     * by that rank, and every run of hosts with an equal rank is shuffled
     * independently. Otherwise the whole list is shuffled.
     *
     * <p>A sample is pushed to {@code siblingsRemovals} on every call:
     * {@code 1} if hosts were split into more than one rank group,
     * {@code 0} otherwise.
     *
     * <p>An empty list is left untouched, and a handler without siblings
     * config behaves as if failure tracking was disabled.
     *
     * @param hosts mutable list of hosts, reordered in place
     * @param seed seed which makes the resulting order deterministic
     * @param randomShuffle if {@code true}, groups are shuffled with
     *     {@code new Random(seed)}, otherwise they are rotated by
     *     {@code seed % hosts.size()}
     */
    public void shuffleHosts(
        final List<HttpHost> hosts,
        final long seed,
        final boolean randomShuffle)
    {
        int size = hosts.size();
        // Guard the modulo, it throws ArithmeticException on an empty list
        final int subListSeed;
        if (size == 0) {
            subListSeed = 0;
        } else {
            subListSeed = (int) (seed % size);
        }
        int siblingsGroups = 1;
        // Without siblings config there are no failure records to rank by
        long errorsTimeFrame;
        if (siblingsConfig == null) {
            errorsTimeFrame = 0L;
        } else {
            errorsTimeFrame = siblingsConfig.errorsTimeFrame();
        }
        Consumer<List<HttpHost>> shuffler;
        if (randomShuffle) {
            shuffler = x -> Collections.shuffle(x, new Random(seed));
        } else {
            shuffler = x -> Collections.rotate(x, subListSeed);
        }
        if (size > 1 && errorsTimeFrame > 0) {
            // Failure timestamps are stored in granularity units, so the
            // time frame boundary is converted to the same units
            long timeFrameStart =
                (TimeSource.INSTANCE.currentTimeMillis() - errorsTimeFrame)
                    / siblingsConfig.errorsTimeFrameGranularity();
            HostTimeFrameChecker checker =
                new HostTimeFrameChecker(timeFrameStart);
            List<LongPair<HttpHost>> rankedHosts = new ArrayList<>(size);
            for (HttpHost host: hosts) {
                // Drops the record if it is older than the time frame
                Long lastFailure =
                    siblingLastFailure.computeIfPresent(host, checker);
                long timestamp;
                if (lastFailure == null) {
                    timestamp = 0;
                } else {
                    timestamp = lastFailure.longValue();
                }
                rankedHosts.add(new LongPair<>(timestamp, host));
            }
            rankedHosts.sort(RankedHostsComparator.INSTANCE);
            hosts.clear();
            LongPair<HttpHost> host = rankedHosts.get(0);
            hosts.add(host.second());
            long currentLastFailure = host.first();
            // Index of the first host of the current equal-rank group
            int first = 0;
            int i = 1;
            while (i < size) {
                host = rankedHosts.get(i);
                long lastFailure = host.first();
                if (lastFailure != currentLastFailure) {
                    // Group [first, i) is complete, shuffle it if it has
                    // more than one host
                    if (first + 1 < i) {
                        shuffler.accept(hosts.subList(first, i));
                    }
                    first = i;
                    currentLastFailure = lastFailure;
                    ++siblingsGroups;
                }
                hosts.add(host.second());
                ++i;
            }
            // Trailing group
            if (first + 1 < size) {
                shuffler.accept(hosts.subList(first, size));
            }
        } else if (size > 1) {
            shuffler.accept(hosts);
        }
        if (siblingsGroups > 1) {
            siblingsRemovals.accept(ONE);
        } else {
            siblingsRemovals.accept(ZERO);
        }
    }

    /**
     * Reloads the search map, if one is configured and it is reloadable.
     * Does nothing otherwise.
     *
     * @throws IOException propagated from {@code SearchMap.reload()}
     * @throws ParseException propagated from {@code SearchMap.reload()}
     */
    public void reloadSearchMap() throws IOException, ParseException {
        if (searchMap == null || !searchMap.reloadable()) {
            return;
        }
        searchMap.reload();
    }

    /**
     * Creates a fresh consumer for the body of every incoming request.
     * Neither the request nor the context are inspected.
     */
    @Override
    public NByteArrayEntityGeneratorAsyncConsumer processRequest(
        final HttpRequest request,
        final HttpContext context)
    {
        NByteArrayEntityGeneratorAsyncConsumer bodyConsumer =
            new NByteArrayEntityGeneratorAsyncConsumer();
        return bodyConsumer;
    }

    /**
     * Handles an incoming request.
     *
     * <ol>
     * <li>Determines the upstream URI and, for absolute-form request URIs,
     * the target host. The configured pattern/replacement rewrite is applied
     * to the URI afterwards.</li>
     * <li>Prepares the upstream request, copying {@code Accept-Charset},
     * {@code Accept-Encoding} and all configured pass headers.</li>
     * <li>If siblings are not configured, or the request carries the
     * {@code X_LJINX_IGNORE_SIBLINGS} header, serves the request locally:
     * a request with {@code X_CACHE_INVALIDATE} header removes the cache
     * entry, any other request goes through the cache lookup.</li>
     * <li>Otherwise forwards the request to the siblings computed for this
     * request, marking it with {@code X_LJINX_IGNORE_SIBLINGS} so siblings
     * serve it locally.</li>
     * </ol>
     *
     * @param requestBody request body
     * @param exchange HTTP exchange to respond to
     * @param context HTTP context of the exchange
     * @throws HttpException {@link BadRequestException} if the request URI
     *     is malformed or has no host, or siblings cannot be computed;
     *     {@link ServiceUnavailableException} if siblings list is empty
     */
    @Override
    @SuppressWarnings("FutureReturnValueIgnored")
    public void handle(
        final EntityGenerator requestBody,
        final HttpAsyncExchange exchange,
        final HttpContext context)
        throws HttpException
    {
        ProxySession session = new BasicProxySession(ljinx, exchange, context);
        HttpRequest request = session.request();
        HttpHost targetHost;
        String uri;
        RequestLine requestLine = request.getRequestLine();
        String path = session.uri().rawPath();
        if (path.length() > 0 && path.charAt(0) != '/') {
            // Request URI is not origin-form, treat it as an absolute URI
            // and take the target host from it
            try {
                URI parsedUri = new URI(requestLine.getUri());
                if (parsedUri.getHost() == null) {
                    // HttpHost rejects null host names with
                    // IllegalArgumentException, report client error instead
                    throw new URISyntaxException(
                        requestLine.getUri(),
                        "Host is missing in absolute request URI");
                }
                targetHost = new HttpHost(
                    parsedUri.getHost(),
                    parsedUri.getPort(),
                    parsedUri.getScheme());
                String rawPath = parsedUri.getRawPath();
                // Raw query must be used here: decoded one would lose
                // percent-encoding of characters like '&', '=' and '%'
                String rawQuery = parsedUri.getRawQuery();
                if (rawQuery == null) {
                    uri = rawPath;
                } else {
                    uri = StringUtils.concat(rawPath, '?', rawQuery);
                }
            } catch (URISyntaxException e) {
                throw new BadRequestException(e);
            }
        } else {
            targetHost = null;
            uri = requestLine.getUri();
        }

        if (config.pattern() != null && config.replacement() != null) {
            uri = config.pattern().matcher(uri)
                .replaceAll(config.replacement());
        }

        String method = requestLine.getMethod();
        BasicAsyncRequestProducerGenerator producerGenerator =
            new BasicAsyncRequestProducerGenerator(uri, requestBody, method);
        producerGenerator.copyHeader(request, HttpHeaders.ACCEPT_CHARSET);
        producerGenerator.copyHeader(request, HttpHeaders.ACCEPT_ENCODING);
        for (String headerName: config.passHeaders()) {
            producerGenerator.copyHeader(request, headerName);
        }

        boolean ignoreSiblings =
            siblingsClient == null
                || request.getFirstHeader(
                    YandexHeaders.X_LJINX_IGNORE_SIBLINGS) != null;
        ProxyPassSession ppSession = new ProxyPassSession(
            this,
            session,
            uri,
            producerGenerator,
            targetHost,
            ignoreSiblings);

        if (ignoreSiblings) {
            if (request.getFirstHeader(
                YandexHeaders.X_CACHE_INVALIDATE) != null)
            {
                cacheStorage.remove(
                    ppSession,
                    new CacheRemoveCallback(session));
            } else {
                cacheGet(ppSession);
            }
        } else {
            List<HttpHost> hosts;
            try {
                hosts =
                    siblingsConfig.hosts().value(ppSession).hostListValue();
            } catch (ExecutionException e) {
                throw new BadRequestException(
                    "Failed to compute siblings hosts",
                    e);
            }
            if (hosts.isEmpty()) {
                throw new ServiceUnavailableException("No siblings found");
            }

            // Log siblings order along with their last failure timestamps
            StringBuilder sb =
                new StringBuilder("Will go through siblings in order: ");
            for (int i = 0; i < hosts.size(); ++i) {
                if (i != 0) {
                    sb.append(',');
                    sb.append(' ');
                }
                HttpHost sibling = hosts.get(i);
                HttpHostAppender.appendTo(sb, sibling);
                sb.append(' ');
                sb.append('(');
                sb.append(siblingLastFailure.get(sibling));
                sb.append(')');
            }
            session.logger().info(sb.toString());

            HttpHost target = host(targetHost);
            // Siblings must serve the request locally, without forwarding
            producerGenerator.addHeader(IGNORE_SIBLINGS);
            for (String header: siblingsConfig.passHeaders()) {
                producerGenerator.copyHeader(request, header);
            }
            AsyncClient adjustedClient =
                siblingsClient.adjust(session.context());
            FutureCallback<HttpResponse> callback =
                new HttpResponseSendingCallback(session);
            RequestsListener listener;
            if (siblingsConfig.errorsTimeFrame() > 0) {
                // Failure tracking is enabled, record siblings failures
                listener = new FailureListener(session);
            } else {
                listener = session.listener();
            }
            Supplier<? extends HttpClientContext> contextGenerator =
                listener.createContextGeneratorFor(adjustedClient);
            if (targetHost == null) {
                // The request came in origin-form, so it is forwarded to
                // the sibling in the same form. The final host is taken
                // from the proxy pass config and has to be resolved on the
                // last ljinx node, because it could contain BSCONFIG_IPORT
                adjustedClient.execute(
                    hosts,
                    producerGenerator,
                    BasicAsyncResponseConsumerFactory.INSTANCE,
                    contextGenerator,
                    callback);
            } else {
                // The request came with an absolute URI, keep its target
                adjustedClient.execute(
                    hosts,
                    () -> producerGenerator.apply(target),
                    BasicAsyncResponseConsumerFactory.INSTANCE,
                    contextGenerator,
                    callback);
            }
        }
    }

    /**
     * Selects the upstream host for a request.
     *
     * @param targetHost host taken from the request URI, may be null
     * @return {@code targetHost} if it is not null, configured host otherwise
     */
    private HttpHost host(final HttpHost targetHost) {
        return targetHost != null ? targetHost : host;
    }

    /**
     * Starts a cache lookup for the session using a new callback.
     */
    private void cacheGet(final ProxyPassSession session) {
        cacheStorage.get(session, new CacheResponseCallback(session));
    }

    /**
     * Starts a cache lookup for the session, the outcome is reported to the
     * given callback.
     */
    private void cacheGet(
        final ProxyPassSession session,
        final CacheResponseCallback callback)
    {
        this.cacheStorage.get(session, callback);
    }

    /**
     * Called on cache miss. Ensures that only one upstream request is in
     * flight per cache key: the first caller registers its waiters list in
     * {@code keyLockMap} and sends the request, subsequent callers for the
     * same key append their callbacks to that list and wait for the result.
     *
     * @param session session of the request which missed the cache
     * @param callback callback to notify once the response is available
     */
    private void proxyPass(
        final ProxyPassSession session,
        final CacheResponseCallback callback)
    {
        RequestFunctionValue cacheKey = session.cacheKey();
        // Optimistically prepare our own waiters list, with this callback
        // as its first element
        LockedWaiters newWaiters = new LockedWaiters(cacheKey);
        newWaiters.add(callback);
        LockedWaiters oldWaiters =
            keyLockMap.putIfAbsent(cacheKey, newWaiters);
        if (oldWaiters != null) {
            // Another request for the same key is already in flight
            synchronized (oldWaiters) {
                if (!oldWaiters.isEmpty()) {
                    // Reuse the cache key instance of the in-flight request,
                    // so our own equal instance can be GC'ed
                    session.cacheKey(oldWaiters.cacheKey());
                    oldWaiters.add(callback);
                    return;
                }
            }
            // The waiters list is empty, which means that the in-flight
            // request has just finished and its waiters were already taken
            // for notification. Try to get the response from cache again
            cacheGet(session, callback);
        } else {
            // We own the key now, send the request to upstream
            proxyPassRequest(session, new ProxyPassCallback(session));
        }
    }

    /**
     * Sends the request of the given session to the upstream host.
     *
     * @param ppSession session which supplies the target host and the
     *     request producer
     * @param callback callback to notify with the upstream response
     */
    @SuppressWarnings("FutureReturnValueIgnored")
    private void proxyPassRequest(
        final ProxyPassSession ppSession,
        final FutureCallback<HttpResponse> callback)
    {
        ProxySession proxySession = ppSession.session();
        AsyncClient adjustedClient =
            this.client.adjust(proxySession.context());
        HttpHost upstream = host(ppSession.targetHost());
        adjustedClient.execute(
            upstream,
            ppSession.releaseProducerGenerator(),
            BasicAsyncResponseConsumerFactory.INSTANCE,
            proxySession.listener().createContextGeneratorFor(adjustedClient),
            callback);
    }

    /**
     * List of callbacks waiting for the response of a single in-flight
     * upstream request, together with the cache key of that request.
     */
    private static class LockedWaiters
        extends ArrayList<FutureCallback<CacheResponse>>
    {
        private static final long serialVersionUID = 0L;
        private static final int INITIAL_CAPACITY = 4;

        private final transient RequestFunctionValue cacheKey;

        LockedWaiters(final RequestFunctionValue cacheKey) {
            super(INITIAL_CAPACITY);
            this.cacheKey = cacheKey;
        }

        /**
         * @return cache key this list was created for
         */
        public RequestFunctionValue cacheKey() {
            return this.cacheKey;
        }
    }

    /**
     * Receives the outcome of a cache lookup: a found response is sent to
     * the client right away, a miss is passed on to {@code proxyPass}.
     */
    private class CacheResponseCallback
        extends AbstractProxySessionCallback<CacheResponse>
    {
        private final ProxyPassSession ppSession;

        CacheResponseCallback(final ProxyPassSession ppSession) {
            super(ppSession.session());
            this.ppSession = ppSession;
        }

        @Override
        public void completed(final CacheResponse response) {
            if (response == null) {
                // Cache miss, go to upstream (or join an in-flight request)
                proxyPass(ppSession, this);
                return;
            }
            ppSession.logger().info("Got from cache: " + response.response());
            cacheStats.accept(response.cacheType());
            Header hitTypeHeader = HIT_TYPES.get(response.cacheType());
            ppSession.session().response(
                response.response().get(CACHE_HIT, hitTypeHeader));
        }
    }

    /**
     * Receives the upstream response for the request which owns the cache
     * key lock. Stores cacheable responses, releases the lock, hands the
     * response over to all requests which were waiting for the same key and
     * finally responds to its own client.
     *
     * NOTE(review): cancelled() is not overridden here, so it looks like
     * the keyLockMap entry is not released on cancellation; confirm against
     * AbstractProxySessionCallback.
     */
    private class ProxyPassCallback
        extends AbstractProxySessionCallback<HttpResponse>
    {
        private final ProxyPassSession ppSession;

        ProxyPassCallback(final ProxyPassSession ppSession) {
            super(ppSession.session());
            this.ppSession = ppSession;
        }

        /**
         * Releases the cache key lock and drains its waiters.
         *
         * @return snapshot of waiters (own callback comes first) if anyone
         *     besides this request was waiting, {@code null} otherwise
         */
        private List<FutureCallback<CacheResponse>> waitersCopyOrNull() {
            LockedWaiters waiters = keyLockMap.remove(ppSession.cacheKey());
            if (waiters == null) {
                // The lock has been released already, e.g. this callback
                // was notified twice. Nobody is left to notify, and failing
                // with NPE here would only mask the original problem
                return null;
            }
            List<FutureCallback<CacheResponse>> waitersCopy;
            synchronized (waiters) {
                if (waiters.size() > 1) {
                    waitersCopy = new ArrayList<>(waiters);
                } else {
                    waitersCopy = null;
                }
                // Empty list tells late comers in proxyPass() that this
                // request is finished
                waiters.clear();
            }
            return waitersCopy;
        }

        /**
         * Decides whether the response should be stored in cache: its code
         * must either have a TTL or be listed in cache codes, and none of
         * its headers may be listed in no-cache headers.
         */
        private boolean cacheable(
            final HttpResponse response,
            final int httpCode,
            final Long ttl)
        {
            if (ttl == null && !config.cacheCodes().contains(httpCode)) {
                ppSession.logger().info("HTTP code not cacheable: " + httpCode);
                return false;
            }
            if (!noCacheHeadersLC.isEmpty()) {
                HeaderIterator headerIter = response.headerIterator();
                while (headerIter.hasNext()) {
                    Header header = headerIter.nextHeader();
                    if (noCacheHeadersLC.contains(
                        header.getName().toLowerCase(Locale.ROOT)))
                    {
                        ppSession.logger().info(
                            "matched non cacheable header: " + header);
                        return false;
                    }
                }
            }
            return true;
        }

        @Override
        public void completed(final HttpResponse response) {
            BasicAsyncResponseProducerGenerator responseGenerator;
            try {
                responseGenerator =
                    new BasicAsyncResponseProducerGenerator(response);
            } catch (IOException e) {
                failed(new ServiceUnavailableException(
                    "Unexpected error while trying to copy response", e));
                return;
            }

            final int httpCode = response.getStatusLine().getStatusCode();
            Long ttl = config.cacheCodesTTL().get(httpCode);
            if (cacheable(response, httpCode, ttl)) {
                ppSession.ttl(ttl);
                ppSession.logger().info("Would cache");
                cacheStorage.put(
                    ppSession,
                    responseGenerator,
                    new CachePutCallback(ppSession.session()));
            }
            List<FutureCallback<CacheResponse>> waitersCopy =
                waitersCopyOrNull();
            if (waitersCopy != null) {
                // Waiters get the very same response, marked as LOCK hit.
                // The first waiter is this request, skip it
                final CacheResponse cacheResponse =
                    new CacheResponse(
                        responseGenerator,
                        CacheResponse.CacheType.LOCK,
                        TimeSource.INSTANCE.currentTimeMillis());
                for (int i = 1; i < waitersCopy.size(); ++i) {
                    waitersCopy.get(i).completed(cacheResponse);
                }
            }
            cacheStats.accept(CacheResponse.CacheType.MISS);
            ppSession.session().response(
                responseGenerator.get(
                    CACHE_MISS,
                    HIT_TYPES.get(CacheResponse.CacheType.MISS)));
        }

        @Override
        public void failed(final Exception e) {
            List<FutureCallback<CacheResponse>> waitersCopy =
                waitersCopyOrNull();
            if (waitersCopy != null) {
                // The first waiter is this request, it is failed below
                for (int i = 1; i < waitersCopy.size(); ++i) {
                    waitersCopy.get(i).failed(e);
                }
            }
            super.failed(e);
        }
    }

    /**
     * Logs the outcome of a cache put. The client response does not depend
     * on it, so failures are only logged.
     */
    private static class CachePutCallback
        extends AbstractProxySessionCallback<Void>
    {
        CachePutCallback(final ProxySession session) {
            super(session);
        }

        @Override
        public void completed(final Void ignored) {
            this.session.logger().info("Cache put completed");
        }

        @Override
        public void failed(final Exception cause) {
            this.session.logger().log(
                Level.SEVERE,
                "Cache put failed",
                cause);
        }
    }

    /**
     * Reports the outcome of a cache invalidation to the client: 200 OK on
     * success, the default failure handling of the parent class on error.
     */
    private static class CacheRemoveCallback
        extends AbstractProxySessionCallback<Void>
    {
        CacheRemoveCallback(final ProxySession session) {
            super(session);
        }

        @Override
        public void completed(final Void ignored) {
            this.session.logger().info("Cache remove completed");
            this.session.response(HttpStatus.SC_OK);
        }

        @Override
        public void failed(final Exception cause) {
            this.session.logger().log(
                Level.SEVERE,
                "Cache remove failed",
                cause);
            super.failed(cause);
        }
    }

    /**
     * Remapping function for {@code Map.computeIfPresent}: keeps failure
     * records which are inside the time frame and drops (by returning
     * {@code null}) the ones which are older than its start.
     */
    private static class HostTimeFrameChecker
        implements BiFunction<HttpHost, Long, Long>
    {
        private final long timeFrameStart;

        HostTimeFrameChecker(final long timeFrameStart) {
            this.timeFrameStart = timeFrameStart;
        }

        @Override
        public Long apply(final HttpHost host, final Long lastFailure) {
            boolean expired = lastFailure.longValue() < timeFrameStart;
            return expired ? null : lastFailure;
        }
    }

    /**
     * Orders ranked hosts by last failure timestamp, ascending. Hosts with
     * equal timestamps are ordered by {@code HttpHostComparator}.
     */
    private enum RankedHostsComparator
        implements Comparator<LongPair<HttpHost>>
    {
        INSTANCE;

        @Override
        public int compare(
            final LongPair<HttpHost> lhs,
            final LongPair<HttpHost> rhs)
        {
            int byTimestamp = Long.compare(lhs.first(), rhs.first());
            if (byTimestamp != 0) {
                return byTimestamp;
            }
            return HttpHostComparator.INSTANCE.compare(
                lhs.second(),
                rhs.second());
        }
    }

    /**
     * Wraps the session listener so that every request callback is paired
     * with a {@code FailureCallback} tracking failures of the route host.
     */
    private class FailureListener extends FilterRequestsListener {
        private final ProxySession session;

        FailureListener(final ProxySession session) {
            super(session.listener());
            this.session = session;
        }

        @Override
        public <T> FutureCallback<T> createCallbackFor(
            final HttpRoute route,
            final HttpRequest request,
            final HttpContext context)
        {
            FutureCallback<T> failureTracker =
                new FailureCallback<>(session, route);
            FutureCallback<T> wrapped =
                super.createCallbackFor(route, request, context);
            return new DuplexFutureCallback<>(failureTracker, wrapped);
        }
    }

    /**
     * Records the time of the last failure of a sibling host. A request
     * counts as failed if it ends with an exception, or with an HTTP code
     * which is listed neither in cache codes nor in good codes.
     */
    private class FailureCallback<T> implements FutureCallback<T> {
        private final ProxySession session;
        private final HttpHost target;

        FailureCallback(final ProxySession session, final HttpRoute route) {
            this.session = session;
            // The proxy host of the route takes precedence when present
            HttpHost proxy = route.getProxyHost();
            target = proxy != null ? proxy : route.getTargetHost();
        }

        @Override
        public void cancelled() {
        }

        /**
         * Logs the failure and stores its timestamp, expressed in
         * granularity units, for the target host.
         */
        private void recordFailure() {
            long now = TimeSource.INSTANCE.currentTimeMillis();
            long timestamp = now / siblingsConfig.errorsTimeFrameGranularity();
            StringBuilder message = new StringBuilder("Recording ");
            HttpHostAppender.appendTo(message, target);
            message.append(" failure at ").append(timestamp);
            session.logger().warning(message.toString());
            siblingLastFailure.put(target, timestamp);
        }

        @Override
        public void failed(final Exception e) {
            recordFailure();
        }

        @Override
        public void completed(final T result) {
            if (!(result instanceof HttpResponse)) {
                return;
            }
            HttpResponse response = (HttpResponse) result;
            int httpCode = response.getStatusLine().getStatusCode();
            boolean expectedCode =
                config.cacheCodes().contains(httpCode)
                    || config.goodCodes().contains(httpCode);
            if (!expectedCode) {
                recordFailure();
            }
        }
    }
}
