package ru.yandex.search.so;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Level;

import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import org.apache.http.HttpRequest;
import org.apache.http.concurrent.FutureCallback;

import ru.yandex.collection.Pattern;
import ru.yandex.concurrent.LifoWaitBlockingQueue;
import ru.yandex.http.proxy.ProxyRequestHandlerAdapter;
import ru.yandex.http.server.async.DelegatedHttpAsyncRequestHandler;
import ru.yandex.http.util.AbstractFilterFutureCallback;
import ru.yandex.http.util.ServiceUnavailableException;
import ru.yandex.http.util.request.RequestHandlerMapper;
import ru.yandex.search.proxy.universal.UniversalSearchProxy;
import ru.yandex.stater.Stater;
import ru.yandex.stater.StatsConsumer;

/**
 * Search proxy that keeps an in-memory {@link HnswIndex} of
 * {@link MinHashMessage} items and exposes it over HTTP.
 *
 * <p>The constructor registers four handlers: {@code /knn/add_point},
 * {@code /knn/delete_point} and {@code /knn/neighbors} (POST) and
 * {@code /knn/flush} (GET). Additions and removals are queued and applied
 * by a single background {@link IndexThread}; {@link #flush()} writes a
 * snapshot of the index to {@code config.snapshotPath()}.
 *
 * <p>Locking: {@link #indexLock} is used in shared/exclusive fashion.
 * Index mutations take the read (shared) lock, while {@link #flush()} takes
 * the write (exclusive) lock so that no mutation runs while the index is
 * being saved. Searches take no lock.
 */
public class Knn extends UniversalSearchProxy<ImmutableKnnConfig> {
    // Second argument of HnswIndex.newBuilder when no snapshot exists.
    // NOTE(review): presumably the maximum item count of a fresh index;
    // confirm whether additions beyond this value fail.
    private static final int INDEX_SIZE = 1000;

    private final IndexThread indexThread;
    // Assigned in start() and read from handler, executor and stater
    // threads, hence volatile.
    private volatile HnswIndex<String, int[], MinHashMessage, Float> index =
        null;
    private final ReentrantReadWriteLock indexLock = new ReentrantReadWriteLock();
    @SuppressWarnings("HidingField")
    private final ThreadPoolExecutor executor;

    /**
     * Creates the search executor, registers the HTTP handlers and the
     * index size stater. The index itself is created in {@link #start()}.
     *
     * @param config service configuration
     * @throws IOException if the superclass constructor fails
     */
    public Knn(final ImmutableKnnConfig config)
        throws IOException
    {
        super(config);
        executor = new ThreadPoolExecutor(
            config.searchThreads(),
            config.searchThreads(),
            1,
            TimeUnit.HOURS,
            new LifoWaitBlockingQueue<Runnable>(config.searchQueueSize()),
            new TasksRejector());

        register(
            new Pattern<>("/knn/add_point", true),
            new AddPointHandler(this),
            RequestHandlerMapper.POST);
        register(
            new Pattern<>("/knn/delete_point", true),
            new DeletePointHandler(this),
            RequestHandlerMapper.POST);
        register(
            new Pattern<>("/knn/flush", true),
            new DelegatedHttpAsyncRequestHandler<HttpRequest>(
                new ProxyRequestHandlerAdapter(new FlushHandler(this), this),
                this,
                executor),
            RequestHandlerMapper.GET);
        register(
            new Pattern<>("/knn/neighbors", true),
            new NeighborsHandler(this),
            RequestHandlerMapper.POST);
        indexThread = new IndexThread();
        registerStater(new IndexStater());
    }

    /**
     * Loads the index from the snapshot file if it exists, otherwise builds
     * an empty index from the configured HNSW parameters, then starts the
     * indexing thread and the server.
     *
     * @throws IOException if the snapshot cannot be read or the superclass
     *     fails to start
     */
    @Override
    public void start() throws IOException {
        if (config.snapshotPath().exists()) {
            logger().severe("Loading snapshot: " + config.snapshotPath());
            try (BufferedInputStream is =
                new BufferedInputStream(
                    new FileInputStream(config.snapshotPath())))
            {
                index = HnswIndex.load(is);
                index.removeEnabled(true);
            } catch (Exception e) {
                logger().log(Level.SEVERE, "Index load failed", e);
                throw e;
            }
            logger().severe("Loading completed, index size: " + index.size());
        } else {
            index =
                HnswIndex.newBuilder(MinHashDistance.INSTANCE, INDEX_SIZE)
                    .withM(config.hnswM())
                    .withEf(config.hnswEf())
                    .withEfConstruction(config.hnswEfConstruction())
                    .withRemoveEnabled()
                    .build();
        }
        indexThread.start();
        super.start();
    }

    /**
     * Saves a snapshot of the index, then shuts down the search executor
     * and closes the superclass. The executor and the superclass are
     * released even if the snapshot fails.
     *
     * @throws IOException if the snapshot or the superclass close fails
     */
    @Override
    public void close() throws IOException {
        try {
            flush();
        } finally {
            try {
                // Core threads of this pool never time out, so it has to
                // be shut down explicitly.
                executor.shutdown();
            } finally {
                super.close();
            }
        }
    }

    /**
     * Queues {@code msg} for addition to the index.
     *
     * @param msg item to add
     * @param callback receives the index size once the item was added, or
     *     the failure if the addition threw
     */
    public void addPoint(
        final MinHashMessage msg,
        final FutureCallback<Integer> callback)
    {
        indexThread.indexQueue.add(
            new CallbackWithMessage(callback, msg, false));
    }

    /**
     * Queues removal of the item with the id of {@code msg} from the index.
     *
     * @param msg item whose id should be removed
     * @param callback receives the index size once the removal was applied,
     *     or the failure if the removal threw
     */
    public void deletePoint(
        final MinHashMessage msg,
        final FutureCallback<Integer> callback)
    {
        indexThread.indexQueue.add(
            new CallbackWithMessage(callback, msg, true));
    }

    /**
     * @return current value of {@code index.size()}
     */
    public int indexSize() {
        return index.size();
    }

    /**
     * Synchronously searches the index on the calling thread.
     *
     * @param vector query vector
     * @param k number of neighbours requested
     * @return result of {@code index.findNearest(vector, k)}
     */
    public List<SearchResult<MinHashMessage, Float>> neighbors(
        final int[] vector,
        final int k)
    {
        return index.findNearest(vector, k);
    }

    /**
     * Schedules a neighbours search on the search executor.
     *
     * <p>If {@code callback} is itself a {@link Runnable} it is submitted
     * as is and is expected to perform the search and complete itself.
     * Otherwise a task is submitted which runs
     * {@link #neighbors(int[], int)} and passes the outcome to
     * {@code callback}. A {@code null} callback is ignored. When the
     * executor queue is full, {@link TasksRejector} fails the callback.
     *
     * @param vector query vector
     * @param k number of neighbours requested
     * @param callback receiver of the search result
     */
    public void neighbors(
        final int[] vector,
        final int k,
        final FutureCallback<
            List<SearchResult<MinHashMessage, Float>>> callback)
    {
        if (callback instanceof Runnable) {
            executor.execute((Runnable) callback);
        } else if (callback != null) {
            executor.execute(new NeighborsTask(vector, k, callback));
        }
    }

    /**
     * Writes the index to {@code <snapshotPath>.tmp} and renames that file
     * over the snapshot. Does nothing if the index has not been created yet.
     *
     * <p>The whole write-and-rename sequence runs under the exclusive lock,
     * so concurrent flushes cannot share the temporary file and no index
     * mutation runs while the index is being saved.
     *
     * @throws IOException if the index cannot be written or the temporary
     *     file cannot be renamed to the snapshot path
     */
    public void flush() throws IOException {
        final HnswIndex<String, int[], MinHashMessage, Float> current =
            index;
        if (current == null) {
            // start() has not built or loaded an index: nothing to save
            return;
        }
        final File snapshot = config.snapshotPath();
        final File tmpFile = new File(snapshot.getCanonicalPath() + ".tmp");
        indexLock.writeLock().lock();
        try {
            try (BufferedOutputStream os =
                new BufferedOutputStream(
                    new FileOutputStream(tmpFile)))
            {
                current.save(os);
            }
            // The stream is closed at this point, so the temporary file is
            // complete before it replaces the snapshot.
            if (!tmpFile.renameTo(snapshot)) {
                throw new IOException(
                    "Failed to rename " + tmpFile + " to " + snapshot);
            }
        } finally {
            indexLock.writeLock().unlock();
        }
    }

    /**
     * Daemon thread that applies queued additions and removals to the index
     * one at a time. Exits when interrupted.
     */
    private final class IndexThread extends Thread {
        private final LinkedBlockingQueue<CallbackWithMessage> indexQueue =
            new LinkedBlockingQueue<>();

        IndexThread() {
            super("IndexThread");
            setDaemon(true);
        }

        @Override
        public void run() {
            while (true) {
                final CallbackWithMessage cbMsg;
                try {
                    cbMsg = indexQueue.take();
                } catch (InterruptedException e) {
                    logger().log(
                        Level.SEVERE,
                        "Index thread interrupted, stopping",
                        e);
                    Thread.currentThread().interrupt();
                    return;
                }
                try {
                    process(cbMsg);
                } catch (Throwable t) {
                    // Only failures of the callback itself end up here
                    logger().log(Level.SEVERE, "Indexing error", t);
                }
            }
        }

        /**
         * Applies one queued operation under the shared lock and then, with
         * the lock released, reports the outcome to the callback.
         */
        private void process(final CallbackWithMessage cbMsg) {
            Throwable failure = null;
            indexLock.readLock().lock();
            try {
                if (cbMsg.delete) {
                    index.remove(cbMsg.msg.id(), 0);
                } else {
                    index.add(cbMsg.msg);
                }
            } catch (Throwable t) {
                failure = t;
            } finally {
                indexLock.readLock().unlock();
            }
            if (failure == null) {
                cbMsg.completed(indexSize());
            } else {
                logger().log(Level.SEVERE, "Indexing error", failure);
                if (failure instanceof Exception) {
                    cbMsg.reportFailure((Exception) failure);
                } else {
                    cbMsg.reportFailure(new RuntimeException(failure));
                }
            }
        }
    }

    /**
     * Queue entry: the message to index, whether it is a removal, and the
     * callback to notify. The wrapped callback may be {@code null}.
     */
    private static class CallbackWithMessage
        extends AbstractFilterFutureCallback<Integer, Integer>
    {
        private final MinHashMessage msg;
        private final boolean delete;

        CallbackWithMessage(
            final FutureCallback<Integer> callback,
            final MinHashMessage msg,
            final boolean delete)
        {
            super(callback);
            this.msg = msg;
            this.delete = delete;
        }

        @Override
        public void completed(final Integer result) {
            if (callback != null) {
                callback.completed(result);
            }
        }

        /**
         * Null-safe counterpart of {@link #completed(Integer)} for failures.
         */
        void reportFailure(final Exception e) {
            if (callback != null) {
                callback.failed(e);
            }
        }
    }

    /**
     * Executor task that runs a synchronous search and hands the outcome to
     * the wrapped callback. It also implements {@link FutureCallback} so
     * that {@link TasksRejector} can fail it when the queue is full.
     */
    private final class NeighborsTask
        implements Runnable,
            FutureCallback<List<SearchResult<MinHashMessage, Float>>>
    {
        private final int[] vector;
        private final int k;
        private final FutureCallback<
            List<SearchResult<MinHashMessage, Float>>> callback;

        NeighborsTask(
            final int[] vector,
            final int k,
            final FutureCallback<
                List<SearchResult<MinHashMessage, Float>>> callback)
        {
            this.vector = vector;
            this.k = k;
            this.callback = callback;
        }

        @Override
        public void run() {
            final List<SearchResult<MinHashMessage, Float>> result;
            try {
                result = Knn.this.neighbors(vector, k);
            } catch (RuntimeException e) {
                callback.failed(e);
                return;
            }
            callback.completed(result);
        }

        @Override
        public void completed(
            final List<SearchResult<MinHashMessage, Float>> result)
        {
            callback.completed(result);
        }

        @Override
        public void failed(final Exception e) {
            callback.failed(e);
        }

        @Override
        public void cancelled() {
            callback.cancelled();
        }
    }

    /**
     * Rejection handler of the search executor: a rejected task that is
     * also a {@link FutureCallback} is failed with
     * {@link ServiceUnavailableException}; any other task is dropped.
     */
    private static class TasksRejector implements RejectedExecutionHandler {
        @Override
        public void rejectedExecution(
            final Runnable r,
            final ThreadPoolExecutor executor)
        {
            if (r instanceof FutureCallback) {
                ((FutureCallback<?>) r).failed(
                    new ServiceUnavailableException("Search queue full"));
            }
        }
    }

    /**
     * Reports the current index size as the {@code index-size_ammm} stat.
     */
    private class IndexStater implements Stater {
        @Override
        public <E extends Exception> void stats(
            final StatsConsumer<? extends E> statsConsumer)
            throws E
        {
            statsConsumer.stat(
                "index-size_ammm",
                indexSize());
        }
    }
}
