package ru.yandex.msearch;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import java.nio.charset.StandardCharsets;

import java.util.List;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryProducer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;

import ru.yandex.util.string.StringUtils;

/**
 * Reduces the memory held by long-lived {@link QueryProducer}s by replacing
 * the produced {@link Query} tree with a compact byte encoding that is decoded
 * back into an equivalent query on demand.
 *
 * <p>Only trees built exclusively from plain {@link TermQuery} and
 * {@link BooleanQuery} nodes are encoded. Any other query type, including
 * subclasses of those two, makes {@link #slimDownQuery} return the original
 * producer unchanged.
 *
 * <p>Encoding (every int is 4 bytes, big-endian):
 * <pre>
 * TermQuery:    TERM_QUERY_ID, boost (float bits), field length,
 *               field UTF-8 bytes, term length, term bytes
 * BooleanQuery: BOOLEAN_QUERY_ID, boost (float bits), disableCoord (1 byte),
 *               minimumNumberShouldMatch, clause count,
 *               then per clause: Occur ordinal, encoded sub-query
 * </pre>
 * The encoding depends on {@link BooleanClause.Occur} ordinals, so it is only
 * valid inside the JVM that produced it and must not be persisted.
 */
public class SlimDownQueryHelper {
    private static final int TERM_QUERY_ID = 0;
    private static final int BOOLEAN_QUERY_ID = 1;
    private static final BooleanClause.Occur[] OCCURS =
        BooleanClause.Occur.values();

    // Per-thread scratch buffer, reset and reused for every serialization.
    // NOTE: it is never shrunk, so each thread keeps a buffer as large as the
    // biggest query it has ever serialized.
    private static final ThreadLocal<ByteArrayOutputStream> OUT =
        ThreadLocal.withInitial(ByteArrayOutputStream::new);

    /**
     * Wraps {@code queryProducer} in a producer that keeps only a serialized
     * form of its query (plus the already-produced query until
     * {@link DeserializingQueryProducer#clearCached()} is called).
     *
     * @param queryProducer producer whose query should be slimmed down
     * @return a {@link DeserializingQueryProducer} when the produced query is
     *     fully serializable, otherwise {@code queryProducer} itself
     */
    public static QueryProducer slimDownQuery(
        final QueryProducer queryProducer)
    {
        Query query = queryProducer.produceQuery();
        byte[] serialized;
        try {
            serialized = serializeQuery(query);
        } catch (IOException ignored) {
            // Writing to a ByteArrayOutputStream does not fail in practice;
            // if it ever does, keep the original producer instead of failing.
            serialized = null;
        }
        if (serialized == null) {
            return queryProducer;
        }
        return new DeserializingQueryProducer(serialized, query);
    }

    /**
     * Serializes {@code query} using the thread-local scratch buffer.
     *
     * @return the encoded bytes, or {@code null} if some node of the tree is
     *     not supported
     */
    private static byte[] serializeQuery(final Query query) throws IOException {
        ByteArrayOutputStream out = OUT.get();
        out.reset();
        if (serializeQuery(query, out)) {
            return out.toByteArray();
        } else {
            return null;
        }
    }

    /**
     * Writes {@code query} to {@code out}. On {@code false} the stream may
     * already contain a partial encoding and must be discarded.
     *
     * @return {@code true} if the whole tree was encoded
     */
    private static boolean serializeQuery(
        final Query query,
        final OutputStream out)
        throws IOException
    {
        if (query == null) {
            return false;
        }
        // Exact class match on purpose: a subclass may override behavior
        // that decoding it as the base class would silently drop.
        final Class<?> type = query.getClass();
        if (type == TermQuery.class) {
            return serializeTermQuery((TermQuery) query, out);
        } else if (type == BooleanQuery.class) {
            return serializeBooleanQuery((BooleanQuery) query, out);
        } else {
            return false;
        }
    }

    private static boolean serializeTermQuery(
        final TermQuery query,
        final OutputStream out)
        throws IOException
    {
        writeInt(out, TERM_QUERY_ID);
        // Boost must survive the round trip, otherwise scoring changes.
        writeInt(out, Float.floatToIntBits(query.getBoost()));
        final Term term = query.getTerm();
        final String field = term.field();
        final BytesRef bytes = term.bytes();
        final byte[] fieldBytes = field.getBytes(StandardCharsets.UTF_8);
        writeInt(out, fieldBytes.length);
        out.write(fieldBytes);
        writeInt(out, bytes.length);
        out.write(bytes.bytes, bytes.offset, bytes.length);
        return true;
    }

    private static boolean serializeBooleanQuery(
        final BooleanQuery query,
        final OutputStream out)
        throws IOException
    {
        writeInt(out, BOOLEAN_QUERY_ID);
        writeInt(out, Float.floatToIntBits(query.getBoost()));
        writeBoolean(out, query.isCoordDisabled());
        writeInt(out, query.getMinimumNumberShouldMatch());
        List<BooleanClause> clauses = query.clauses();
        writeInt(out, clauses.size());
        for (int i = 0; i < clauses.size(); i++) {
            BooleanClause clause = clauses.get(i);
            writeInt(out, clause.getOccur().ordinal());
            if (!serializeQuery(clause.getQuery(), out)) {
                return false;
            }
        }
        return true;
    }

    /** Writes {@code i} as 4 big-endian bytes. */
    private static void writeInt(final OutputStream out, final int i)
        throws IOException
    {
        out.write((byte) (i >> 24));
        out.write((byte) (i >> 16));
        out.write((byte) (i >>  8));
        out.write((byte) i);
    }

    /** Writes {@code b} as a single byte: 1 for true, 0 for false. */
    private static void writeBoolean(final OutputStream out, final boolean b)
        throws IOException
    {
        if (b) {
            out.write((byte) 1);
        } else {
            out.write((byte) 0);
        }
    }

    /**
     * Reads a 4-byte big-endian int at {@code off[0]} and advances the
     * offset by 4.
     */
    private static int readInt(final byte[] in, final int[] off) {
        return ((in[off[0]++] & 0xFF) << 24) | ((in[off[0]++] & 0xFF) << 16)
            | ((in[off[0]++] & 0xFF) <<  8) | (in[off[0]++] & 0xFF);
    }

    /** Reads one byte at {@code off[0]}; only the value 1 maps to true. */
    private static boolean readBoolean(final byte[] in, final int[] off) {
        return in[off[0]++] == (byte) 1;
    }

    /**
     * Decodes one query node (recursively) starting at {@code off[0]},
     * advancing the offset past it.
     *
     * @throws IllegalStateException if an unknown query id is found
     */
    private static Query deserializeQuery(final byte[] data, final int[] off) {
        int queryClassId = readInt(data, off);
        switch (queryClassId) {
            case TERM_QUERY_ID:
                return deserializeTermQuery(data, off);
            case BOOLEAN_QUERY_ID:
                return deserializeBooleanQuery(data, off);
            default:
                throw new IllegalStateException("Unexpected query class id: "
                    + queryClassId);
        }
    }

    private static Query deserializeBooleanQuery(
        final byte[] data,
        final int[] off)
    {
        float boost = Float.intBitsToFloat(readInt(data, off));
        boolean disableCoord = readBoolean(data, off);
        int minimumNumberShouldMatch = readInt(data, off);
        int clauses = readInt(data, off);
        BooleanQuery query = new BooleanQuery(disableCoord);
        query.setBoost(boost);
        query.setMinimumNumberShouldMatch(minimumNumberShouldMatch);
        for (int i = 0; i < clauses; i++) {
            int ordinal = readInt(data, off);
            BooleanClause.Occur occur = OCCURS[ordinal];
            query.add(deserializeQuery(data, off), occur);
        }
        return query;
    }

    private static TermQuery deserializeTermQuery(
        final byte[] data,
        final int[] off)
    {
        float boost = Float.intBitsToFloat(readInt(data, off));
        int fieldLen = readInt(data, off);
        // Field names repeat across many queries; interning shares them.
        String field =
            new String(data, off[0], fieldLen, StandardCharsets.UTF_8).intern();
        off[0] += fieldLen;
        int valueLen = readInt(data, off);
        // The term bytes are a view into the shared encoded array, not a copy.
        BytesRef valueRef = new BytesRef(data, off[0], valueLen);
        off[0] += valueLen;
        TermQuery query = new TermQuery(new Term(field, valueRef));
        query.setBoost(boost);
        return query;
    }

    /**
     * Producer holding the encoded query. It serves the decoded (or
     * originally supplied) query from a cache until {@link #clearCached()}
     * drops it, after which the next call decodes it again.
     *
     * <p>NOTE(review): {@code cachedQuery} is read and written without
     * synchronization, so concurrent callers may each decode their own copy.
     */
    static class DeserializingQueryProducer implements QueryProducer {
        private final byte[] data;
        private Query cachedQuery;

        DeserializingQueryProducer(final byte[] data, final Query query) {
            this.data = data;
            this.cachedQuery = query;
        }

        @Override
        public Query produceQuery() {
            if (cachedQuery == null) {
                cachedQuery = deserializeQuery(data, new int[] {0});
            }
            return cachedQuery;
        }

        /** Drops the cached query so only the encoded bytes stay in memory. */
        public void clearCached() {
            cachedQuery = null;
        }

        @Override
        public String toString() {
            return StringUtils.concat(
                getClass().getSimpleName(),
                '@',
                String.valueOf(cachedQuery));
        }
    }
}

