package ru.yandex.msearch;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;

import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.entity.ContentProducer;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.EntityTemplate;
import org.apache.http.protocol.HttpContext;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.StringHelper;

import ru.yandex.collection.PatternMap;
import ru.yandex.http.server.sync.JsonContentProducerWriter;
import ru.yandex.http.util.CharsetUtils;
import ru.yandex.http.util.YandexHeaders;
import ru.yandex.http.util.request.RequestInfo;
import ru.yandex.json.writer.JsonType;
import ru.yandex.json.writer.JsonTypeExtractor;
import ru.yandex.json.writer.Utf8JsonValue;
import ru.yandex.json.writer.Utf8JsonWriter;
import ru.yandex.msearch.collector.HnswSortedPseudoCollector;
import ru.yandex.msearch.collector.YaField;
import ru.yandex.msearch.collector.outergroup.OuterGroupFunctionFactory;
import ru.yandex.msearch.knn.util.disk.DiskItem;
import ru.yandex.msearch.knn.util.disk.HnswCache;
import ru.yandex.msearch.util.Compress;
import ru.yandex.msearch.util.IOStater;
import ru.yandex.parser.uri.CgiParams;
import ru.yandex.search.prefix.Prefix;
import ru.yandex.util.string.UnhexStrings;

import static ru.yandex.msearch.knn.Index.SearchResult;

/**
 * Search handler that answers HNSW nearest-neighbour queries.
 *
 * <p>Requests without a {@code dp} parameter are delegated unchanged to
 * {@link SearchHandler}. A {@code clear} parameter drops the handler-local
 * {@link HnswCache} and returns immediately.
 */
public class HnswSearchHandler extends SearchHandler {
    /** Name of the optional second {@code dp} function that carries an integer filter. */
    private static final String FILTER = "filter_cmp";

    /** Cache owned by this handler instance and passed to every collector it creates. */
    private final HnswCache hnswCache = new HnswCache();

    public HnswSearchHandler(
            DatabaseManager dbManager,
            Config config,
            OuterGroupFunctionFactory outerGroupFunctionFactory,
            PatternMap<RequestInfo, IOStater> ioStaters) {
        super(dbManager, config, outerGroupFunctionFactory, ioStaters);
    }

    /**
     * Handles a search request.
     *
     * <p>The first {@code dp} value must look like {@code name(field hexKeyword)}:
     * the text inside the parentheses is split on single spaces into the field
     * name and a hex-encoded keyword (the function name itself is not checked).
     * An optional second value of the form {@code filter_cmp(a,b,value)} supplies
     * an integer filter taken from its third comma-separated argument.
     *
     * @throws HttpException if the first {@code dp} value, or a {@code filter_cmp}
     *     value, is malformed
     * @throws IOException if the index cannot be searched
     */
    @Override
    public void handle(
            HttpRequest request,
            HttpResponse response,
            HttpContext context) throws HttpException, IOException {
        CgiParams params = new CgiParams(request);

        List<String> dpParams = params.get("dp");
        List<String> clearCache = params.get("clear");

        // TODO: remove
        // Debug hook: no entity or status code is set on this path.
        if (clearCache != null) {
            hnswCache.clear();
            return;
        }

        if (dpParams == null || dpParams.isEmpty()) {
            super.handle(request, response, context);
            return;
        }

        Index index = dbManager.indexOrException(params, BRE_GEN);

        SearchRequest searchRequest =
            new NewSearchRequest(
                params,
                context,
                index,
                index.config(),
                outerGroupFunctionFactory);

        String first = dpParams.get(0);
        String firstArgs = dpArguments(first);
        String[] splitFirst = firstArgs == null ? null : firstArgs.split(" ");
        if (splitFirst == null || splitFirst.length < 2) {
            throw new HttpException("Malformed dp parameter: " + first);
        }
        String fieldName = splitFirst[0];
        String keyword = splitFirst[1];
        Integer filterVal = null;
        if (dpParams.size() > 1) {
            filterVal = parseFilterValue(dpParams.get(1));
        }

        byte[] i2tKeyword = UnhexStrings.unhex(keyword);
        int fieldIndex = searchRequest.fieldToIndex().indexFor(StringHelper.intern(fieldName));

        Charset charset = CharsetUtils.acceptedCharset(request);

        EntityTemplate entity;
        SearchResultsConsumer consumer;
        if (charset.equals(StandardCharsets.UTF_8)) {
            // UTF-8 results are rendered into memory while the search runs and
            // replayed verbatim by HnswHandlerEntity when the response is sent.
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            HnswSearchResultProducer producer =
                new HnswSearchResultProducer(
                    JsonTypeExtractor.NORMAL.extract(params),
                    searchRequest.skipNulls(),
                    out);
            entity = new HnswHandlerEntity(producer, out);
            consumer = producer;
            producer.setWriter(new Utf8JsonWriter(out));
        } else {
            SearchResultsProducer producer =
                new SearchResultsProducer(
                    null,
                    searchRequest.skipNulls());
            entity =
                new EntityTemplate(
                    new JsonContentProducerWriter(
                        producer,
                        JsonTypeExtractor.NORMAL.extract(params),
                        charset));
            consumer = producer;
        }
        entity.setChunked(true);
        entity.setContentType(
            ContentType.APPLICATION_JSON.withCharset(charset).toString());

        consumer.startResults();
        Compress.resetStats();

        Prefix prefix = searchRequest.prefixes().iterator().next();

        // NOTE(review): the meaning of the boolean flag, and why a null result is
        // retried below with syncSearcher(), are not visible from this file.
        // The unchecked wrapping is kept as-is to preserve the exception type.
        Searcher searcher;
        try {
            searcher = index.getSearcher(
                    prefix, true);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        // From here on the searcher must be released on every exit path,
        // including failures during collection.
        try {
            HnswSortedPseudoCollector collector =
                new HnswSortedPseudoCollector(
                    filterVal,
                    fieldIndex,
                    i2tKeyword,
                    searchRequest,
                    consumer,
                    hnswCache
                );

            if (searchRequest.updatePrefixActivity()) {
                index.updatePrefixActivity(prefix);
            }
            if (searcher == null) {
                searcher = index.getSearcher(
                    prefix,
                    searchRequest.syncSearcher());
            }

            collector.setPrefix(prefix);
            collector.setFieldsCache(index.fieldsCache());
            IndexSearcher indexSearcher = searcher.searcher();
            Query query = queryWithPrefix(prefix);

            IndexReader.AtomicReaderContext[] leaves =
                indexSearcher.getTopReaderContext().leaves();
            Weight.ScorerContext scorerContext =
                Weight.ScorerContext.def().scoreDocsInOrder(true).topScorer(true);
            Weight weight = query.weight(indexSearcher);
            if (collector.isNew()) {
                for (IndexReader.AtomicReaderContext leaf : leaves) {
                    Scorer scorer = weight.scorer(leaf, scorerContext);
                    collector.setReader(leaf.reader);
                    if (scorer != null) {
                        try {
                            scorer.score(collector);
                        } finally {
                            scorer.close();
                        }
                    }
                    collector.setNextReader(leaf);
                    scorerContext = scorerContext.scoreDocsInOrder(
                            !collector.acceptsDocsOutOfOrder());
                }
            }
            collector.close();

            int count = collector.uniqCount();
            if (searchRequest.ctx().logger().isLoggable(Level.INFO)) {
                searchRequest.ctx().logger().info("Total docs found: " + count);
            }
            consumer.endResults();
        } finally {
            if (searcher != null) {
                searcher.free();
            }
        }

        final String service = params.getString("service", null);
        if (service != null && searchRequest.prefixes().size() == 1) {
            final QueueShard queueShard = new QueueShard(service, prefix);
            response.setHeader(
                YandexHeaders.ZOO_QUEUE_ID,
                Long.toString(index.queueId(queueShard, false)));
        }

        response.setEntity(entity);
        response.setStatusCode(HttpStatus.SC_OK);
    }

    /**
     * Returns the text between the first '(' in {@code dp} and the first ')'
     * that follows it, or {@code null} if either bracket is missing.
     */
    private static String dpArguments(final String dp) {
        int open = dp.indexOf('(');
        if (open < 0) {
            return null;
        }
        int close = dp.indexOf(')', open + 1);
        if (close < 0) {
            return null;
        }
        return dp.substring(open + 1, close);
    }

    /**
     * Extracts the integer filter from the second {@code dp} value.
     *
     * @param dp second {@code dp} value
     * @return the trimmed third comma-separated argument of {@code filter_cmp(...)}
     *     parsed as an integer, or {@code null} when {@code dp} is not a
     *     {@code filter_cmp} call or that argument is not an integer
     * @throws HttpException if {@code dp} is a {@code filter_cmp} call without a
     *     closing parenthesis or with fewer than three arguments
     */
    private static Integer parseFilterValue(final String dp) throws HttpException {
        int open = dp.indexOf('(');
        if (open < 0 || !FILTER.equals(dp.substring(0, open))) {
            return null;
        }
        String args = dpArguments(dp);
        String[] parts = args == null ? null : args.split(",");
        if (parts == null || parts.length < 3) {
            throw new HttpException("Malformed " + FILTER + " parameter: " + dp);
        }
        try {
            return Integer.valueOf(parts[2].trim());
        } catch (NumberFormatException ignored) {
            // Best-effort, as before: a non-numeric value disables the filter.
            return null;
        }
    }

    /**
     * Entity that replays a response body already rendered into an in-memory
     * buffer. The producer is only handed to the {@link EntityTemplate}
     * constructor; {@link #writeTo(OutputStream)} never calls it.
     */
    private static final class HnswHandlerEntity extends EntityTemplate {
        private final ByteArrayOutputStream out;

        public HnswHandlerEntity(ContentProducer producer, ByteArrayOutputStream out) {
            super(producer);
            this.out = out;
        }

        @Override
        public void writeTo(OutputStream outStream) throws IOException {
            // Streams the buffer directly instead of copying it via toByteArray().
            out.writeTo(outStream);
        }
    }

    /**
     * Builds the query matching documents whose {@code has_i2t} term is
     * {@code <prefix>#1}.
     */
    private static Query queryWithPrefix(Prefix prefix) {
        return new TermQuery(new Term("has_i2t", prefix.toStringFast() + "#1"));
    }

    /**
     * UTF-8 result producer that can also serialize HNSW search hits.
     */
    public static class HnswSearchResultProducer extends Utf8SearchResultProducer {

        public HnswSearchResultProducer(
                final JsonType jsonType,
                final boolean skipNulls,
                final OutputStream out)
        {
            super(jsonType, skipNulls, out);
        }

        /**
         * Writes one hit as a JSON object containing every field of its
         * {@link DiskItem}.
         *
         * <p>NOTE(review): each field value is cast to {@link Utf8JsonValue};
         * this assumes every {@link YaField} stored on a {@code DiskItem}
         * implements it — confirm.
         */
        public void document(
                SearchResult<DiskItem, Integer> searchResult) throws IOException
        {
            writer.startObject();

            DiskItem item = searchResult.getItem();
            for (Map.Entry<String, YaField> fieldEntry: item.getFields().entrySet()) {
                writer.key(fieldEntry.getKey());
                writer.value((Utf8JsonValue) fieldEntry.getValue());
            }
            writer.endObject();
        }
    }
}
