package ru.yandex.msearch.collector;

import java.io.IOException;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;

import org.apache.lucene.index.IndexReader;

import ru.yandex.msearch.HnswSearchHandler;
import ru.yandex.msearch.SearchRequest;
import ru.yandex.msearch.SearchResultsConsumer;
import ru.yandex.msearch.collector.docprocessor.ModuleFieldsAggregator;
import ru.yandex.msearch.fieldscache.CacheInput;
import ru.yandex.msearch.fieldscache.FieldsCache;
import ru.yandex.msearch.knn.hnsw.HnswIndex;
import ru.yandex.msearch.knn.util.disk.DiskDistComparator;
import ru.yandex.msearch.knn.util.disk.DiskDistance;
import ru.yandex.msearch.knn.util.disk.DiskHnswIndex;
import ru.yandex.msearch.knn.util.disk.DiskItem;
import ru.yandex.msearch.knn.util.disk.DocLoader;
import ru.yandex.msearch.knn.util.disk.HnswCache;
import ru.yandex.search.prefix.Prefix;

import static ru.yandex.msearch.knn.Index.SearchResult;

public class HnswSortedPseudoCollector extends SortedCollector {
    private static final int M = 10; // number of bidirectional links
    private static final int EF = 40; // size of the dynamic list of nearest neighbours
    private static final int MAX_ITEM_COUNT = 10000001;
    private static final int N_NEIGHBOURS = 10;

    private static final String ID = "id";

    private final HnswCache hnswCache;
    private final Prefix prefix;
    private final int dimensions;
    private final boolean isNew;
    private final byte[] i2tKeyword;
    private final Integer maxDistance;
    private DiskHnswIndex index;
    private final int fieldIndex;
    private final int idField;

    private long totalCollectTime;

    public HnswSortedPseudoCollector(
            Integer maxDistance,
            int fieldIndex,
            byte[] i2tKeyword,
            SearchRequest request,
            SearchResultsConsumer consumer,
            HnswCache hnswCache) {
        super(request, consumer);
        this.dimensions = i2tKeyword.length;
        this.maxDistance = maxDistance;
        this.fieldIndex = fieldIndex;
        this.i2tKeyword = i2tKeyword;
        this.idField = request.fieldToIndex().indexFor(ID);
        this.hnswCache = hnswCache;
        this.prefix = request.prefixes().iterator().next();
        isNew = setOrCreate(hnswCache);
        request.ctx().logger().info(
                "Found" + (isNew ? " no " : " ")
                        + "hnsw index in cache"
                        + (!isNew ? " with " + index.size() + " nodes" : ""));
    }

    private boolean setOrCreate(HnswCache cache) {
        DiskHnswIndex index = cache.get(prefix);
        if (index != null) {
            this.index = index;
            return false;
        }
        this.index = new DiskHnswIndex(
                HnswIndex.newBuilder(
                                dimensions,
                                DiskDistance.INSTANCE,
                                DiskDistComparator.INSTANCE,
                                MAX_ITEM_COUNT
                        )
                        .withM(M)
                        .withEf(EF)
                        .build());
        return true;
    }

    public void setReader(IndexReader reader) {
        this.reader = reader;
    }

    protected int buildHnsw(
            final IndexReader reader,
            final List<SortedCollector.Collectable> getDocs,
            final FieldsCache fieldsCache)
    {
        if (getDocs.isEmpty()) {
            ctx.logger().info("HnswSortedCollector: no doc found in this leaf");
            return 0;
        }

        List<CacheInput> caches = fieldsCache == null ? null : fieldsCache.getCachesFor(reader, step1ReadFields);
        if (caches == null) {
            ctx.logger().fine("useCache = false");
            long sortStart = System.currentTimeMillis();
            if (ctx.logger().isLoggable(Level.FINE)) {
                ctx.logger().fine(
                        "HnswSortedCollector: populateDocs: "
                                + "Collections.sort(getDocs) took "
                                + (System.currentTimeMillis() - sortStart) + " ms");
            }
        }

        DocLoader loader = new DocLoader(step1Visitor, reader, caches);

        ctx.logger().info("Initializing hnsw index with " + getDocs.size() + " docs");
        int totalAdded = 0;
        for (int i = 0; i < getDocs.size(); i++) {
            final SortedCollector.Collectable collected = getDocs.get(i);
            final SortedCollector.YaDoc3Delayed yadoc = collected.doc();
            loader.apply(yadoc);
            if (!collected.processed) {
                String id = yadoc.getField(idField).toString();
                byte[] vector = vectorFromDoc(yadoc, fieldIndex);
                DiskItem item = new DiskItem(id, vector, getFields, yadoc, fieldToIndex);
                int curAdded = (index.add(item) ? 1 : 0);
                totalAdded += curAdded;
                if (curAdded == 0) {
                    ctx.logger().info("Failed to add vector with dimensionality " + vector.length);
                }
            }
            collected.processed = true;
            getDocs.set(i, null);
        }
        getDocs.clear();
        ctx.logger().info("Added " + totalAdded + " nodes");
        return totalAdded;
    }

    public List<SearchResult<DiskItem, Integer>> findRelevant()
            throws IOException {
        long collectStartTime = System.currentTimeMillis();

        List<SearchResult<DiskItem, Integer>> items;
        if (maxDistance == null) {
            items = index.findNearest(i2tKeyword, (int) N_NEIGHBOURS);
        } else {
            items = index.findNearest(i2tKeyword, (Integer) maxDistance);
        }
        items.sort(Comparator.comparingInt(x -> - x.getDistance()));
        ctx.logger().info("HnswSortedCollector: Found " + items.size() + " items");

        final long collectTime =
                (System.currentTimeMillis() - collectStartTime);
        if (ctx.logger().isLoggable(Level.FINE)) {
            ctx.logger().fine("HnswSortedCollector::Collect time: " + collectTime);
        }
        totalCollectTime += collectTime;

        return items;
    }

        @Override
    protected Set<String> fieldsFromFieldsFunction(FieldsFunction function) {
        Set<String> fields = new HashSet<>(function.loadFields());
        fields.addAll(getFields);
        return fields;
    }

    @Override
    protected Set<String> fieldsFromFieldsAggregator(ModuleFieldsAggregator aggregator) {
        Set<String> fields = new HashSet<>(aggregator.loadFields());
        fields.addAll(getFields);
        return fields;
    }

    @Override
    public void close() throws IOException {
        List<SearchResult<DiskItem, Integer>> items = findRelevant();
        HnswSearchHandler.HnswSearchResultProducer consumer = (HnswSearchHandler.HnswSearchResultProducer) this.consumer;
        if (consumer != null) {
            consumer.uniqHitsCount(uniqCount());
            consumer.totalHitsCount(getTotalCount());
            consumer.startHits();
            int pos = 0;
            int added = 0;
            for (SearchResult<DiskItem, Integer> item: items) {
                if (added >= length) {
                    break;
                }
                if (pos >= offset) {
                    consumer.document(item);
                    ++added;
                }
                pos++;
            }
            consumer.endHits();
        }
        if (ctx.logger().isLoggable(Level.CONFIG)) {
            ctx.logger().config("Total collect time: " + totalCollectTime);
        }
        if (isNew) {
            hnswCache.put(prefix, index);
        }
    }

    @Override
    protected int populateDocs(
            final IndexReader reader,
            final List<Collectable> getDocs,
            final FieldsCache fieldsCache)
            throws IOException
    {
        return buildHnsw(reader, getDocs, fieldsCache);
    }

    @Override
    public void flush() {
        throw new UnsupportedOperationException();
    }

    public boolean isNew() {
        return isNew;
    }

    private static byte[] vectorFromDoc(SortedCollector.YaDoc3Delayed doc, int fieldIndex) {
        YaField rawField = doc.getField(fieldIndex);
        YaField.ByteArrayYaField field = (YaField.ByteArrayYaField) rawField;
        if (field == null) {
            return new byte[0];
        }
        return field.getValue();
    }
}
