package ru.yandex.msearch.collector.docprocessor;

import java.io.IOException;
import java.text.ParseException;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import java.util.Locale;
import java.util.Queue;
import java.util.Set;

import java.util.logging.Level;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.document.MapFieldSelector;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;

import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.StringHelper;

import ru.yandex.http.util.BadRequestException;
import ru.yandex.msearch.ProcessorRequestContext;
import ru.yandex.msearch.Searcher;

import ru.yandex.msearch.collector.YaDoc3;
import ru.yandex.msearch.collector.YaField;

import ru.yandex.msearch.fieldscache.CacheInput;
import ru.yandex.msearch.fieldscache.FieldsCache;

import ru.yandex.search.prefix.Prefix;
import ru.yandex.util.string.StringUtils;

public class TreeCalcDocProcessor implements DocProcessor {
    // Default for the "dp-tree-max-height" request parameter (see the
    // constructor); it bounds the traversal loop in process().
    private static final int MAX_TREE_HEIGHT = 1000000;
    // Argument format quoted in parse error messages.
    // NOTE(review): the constructor's parser only accepts
    // parent_id_field,id_field,aggr_func(...): a third top-level comma
    // raises ParseException, so "extra_condition" is not parsed here.
    private static final String FORMAT =
        "tree_calc(parent_id_field,id_field,extra_condition,"
            + "aggr_func,get_field+out_field)";
    // Sentinel meaning "no field caches available"; compared by identity
    // in readDocument().
    private static final List<CacheInput> EMPTY_CACHE = Collections.emptyList();

    // Interned name of the field whose term is matched against a node id
    // to find that node's children.
    private final String parentIdField;
    // Index of idField in the incoming YaDoc3, used to read the root id.
    private final int idFieldIndex;
    // Interned name of the field holding a document's own id.
    private final String idField;
    // Aggregation applied to every child document found during traversal.
    private final AggFunc aggFunc;
    // Limit compared against the number of processed queue nodes.
    private final int maxTreeHeight;
    protected final ProcessorRequestContext context;

    // Stored fields read for every child: idField plus whatever the
    // aggregate function registers through AggFunc.loadFields().
    private final Set<String> loadFields;
    // The single prefix of the request; the constructor rejects requests
    // with any other number of prefixes.
    private final Prefix prefix;
    // Lucene selector restricted to loadFields.
    private final FieldSelector fieldSelector;
    // FIFO queue of node ids still to be expanded by process().
    private final Queue<String> nodesQueue;
    // Acquired in the constructor, released and nulled in after().
    private Searcher searcher;
    // May be null, in which case documents are always read from the index.
    private final FieldsCache fieldsCache;
    // Lazily created per-reader memo of the caches returned by fieldsCache.
    private HashMap<IndexReader, List<CacheInput>> cacheMap;
    // Statistics reported by after(): documents loaded from the index...
    private int reads = 0;
    // ...and documents fully served from field caches.
    private int docsFromCache = 0;

    /**
     * States of the top-level argument parser in the constructor; each
     * constant names the token the parser is currently collecting.
     */
    private enum ParseState {
        // Collecting the first argument, terminated by ','.
        PARENT_ID_FIELD,
        // Collecting the second argument, terminated by ','.
        ID_FIELD,
        // Collecting the aggregate function name, terminated by '('.
        AGGREGATE_FUNC,
        // NOTE(review): never assigned anywhere in this file.
        AGGREGATE_FUNC_ARG,
        // The aggregate function was created; parsing is complete.
        END
    }

    /**
     * States of the sum_count(...) argument parser in SumAndCountFunc;
     * each constant names the token the parser is currently collecting.
     */
    private enum SumCountParseState {
        // Name of the field whose values are summed, terminated by ','.
        GET_FIELD,
        // Name of the field used for filtering, terminated by ','.
        FILTER_FIELD,
        // Value the filter field is compared with, terminated by ' '.
        FILTER_VALUE,
        // Name of the output field for the sum, terminated by ','.
        OUT_FIELD_SUM,
        // Name of the output field for the count, terminated by ')'.
        OUT_FIELD_COUNT,
        // Closing parenthesis seen; parsing is complete.
        END
    }

    /**
     * Parses the processor arguments and acquires a searcher for the
     * request prefix.
     *
     * <p>Accepted arguments:
     * {@code parent_id_field,id_field,aggr_func(aggr_func_args)}, where
     * {@code aggr_func} is a case-insensitive {@link AggFuncFactory}
     * constant name ({@code sum_count} or {@code print}). A backslash
     * makes the following character lose its separator meaning; the
     * backslash itself is kept in the extracted token.
     *
     * <p>The searcher obtained here is released by {@link #after()}.
     *
     * @param args raw processor arguments
     * @param context request context supplying parameters, field indexes,
     *     the prefix set, the index and the optional fields cache
     * @throws ParseException if the arguments are malformed, the aggregate
     *     function is unknown, {@code dp-tree-max-height} is invalid, the
     *     request does not carry exactly one prefix, or the searcher
     *     cannot be obtained
     */
    public TreeCalcDocProcessor(
        final String args,
        final ProcessorRequestContext context)
        throws ParseException
    {
        this.context = context;

        int argStart = 0;
        boolean escape = false;

        ParseState state = ParseState.PARENT_ID_FIELD;

        String parentIdField = null;
        String idField = null;

        AggFunc aggFunc = null;

        for (int i = 0; i < args.length(); i++) {
            char c = args.charAt(i);
            if (c == '\\') {
                escape = true;
                continue;
            }

            if (c == '('
                && !escape
                && state == ParseState.AGGREGATE_FUNC)
            {
                String funcName = args.substring(argStart, i);
                AggFuncFactory factory;
                try {
                    factory =
                        AggFuncFactory.valueOf(
                            funcName.toUpperCase(Locale.ROOT));
                } catch (IllegalArgumentException e) {
                    // Enum.valueOf reports an unknown name with an
                    // unchecked exception; surface it as the declared
                    // parse failure instead.
                    ParseException pe = new ParseException(
                        "Unknown aggregate function '" + funcName
                            + "' format is " + FORMAT,
                        argStart);
                    pe.initCause(e);
                    throw pe;
                }

                // The function parses its own arguments starting right
                // after the opening parenthesis.
                aggFunc = factory.create(context, args, i + 1);

                state = ParseState.END;
                break;
            } else if (c == ',' && !escape) {
                switch (state) {
                    case PARENT_ID_FIELD:
                        parentIdField = args.substring(argStart, i);
                        state = ParseState.ID_FIELD;
                        break;
                    case ID_FIELD:
                        idField = args.substring(argStart, i);
                        state = ParseState.AGGREGATE_FUNC;
                        break;
                    default:
                        throw new ParseException(
                            "Invalid state " + state + " format is " + FORMAT,
                            i);
                }

                argStart = i + 1;
            }

            escape = false;
        }

        // END is only reachable through PARENT_ID_FIELD -> ID_FIELD ->
        // AGGREGATE_FUNC, so both field names and aggFunc are set below.
        if (state != ParseState.END) {
            throw new ParseException(
                "End of args but state is " + state + " valid format is "
                    + FORMAT,
                args.length());
        }

        try {
            this.maxTreeHeight =
                context.params().getInt(
                    "dp-tree-max-height",
                    MAX_TREE_HEIGHT);
        } catch (BadRequestException bre) {
            ParseException pe = new ParseException(bre.getMessage(), 0);
            pe.initCause(bre);
            throw pe;
        }

        this.parentIdField = StringHelper.intern(parentIdField);
        this.idField = StringHelper.intern(idField);
        this.idFieldIndex =
            context.fieldToIndex().indexFor(this.idField);
        this.aggFunc = aggFunc;

        this.loadFields = new LinkedHashSet<>();
        loadFields.add(this.idField);

        this.aggFunc.loadFields(loadFields);
        if (context.prefix().size() != 1) {
            throw new ParseException(
                "Multi/unprefixed requests no supported by this doc processor",
                0);
        }

        this.prefix = context.prefix().iterator().next();
        this.fieldSelector = new MapFieldSelector(loadFields);
        this.nodesQueue = new LinkedList<>();
        // Resolved before the searcher is acquired so that no call can
        // fail after acquisition and leave the searcher unreleased.
        this.fieldsCache = context.fieldsCache();
        try {
            this.searcher = context.index().getSearcher(prefix, true);
        } catch (IOException e) {
            ParseException pe = new ParseException("Unhandled error", 0);
            pe.initCause(e);
            throw pe;
        }
    }

    /**
     * Logs read statistics and releases the searcher acquired by the
     * constructor. Calling it again after the searcher has been released
     * only logs the statistics.
     */
    @Override
    public void after() {
        final String stats =
            "TreeCalcDocProcessor: docsRead: " + reads
                + ", docsFromCache: " + docsFromCache;
        context.ctx().logger().info(stats);

        if (searcher == null) {
            // Already released.
            return;
        }

        try {
            searcher.free();
        } catch (IOException e) {
            // Releasing is best effort: the failure is logged, not thrown.
            context.ctx().logger().log(
                Level.SEVERE,
                "Searcher close error",
                e);
        }
        cacheMap = null;
        searcher = null;
    }

    /**
     * Returns the values of the fields to load for one document, served
     * from field caches when every cache can seek to the document and
     * from the stored document otherwise.
     *
     * <p>The returned map is an {@link IdentityHashMap}, so lookups only
     * succeed with the very same {@code String} instances that were used
     * as keys.
     *
     * @param reader segment reader the document id belongs to
     * @param docId document id relative to {@code reader}
     * @param fs selector passed to the reader when loading stored fields
     * @return field name to value map; stored-field values may be null
     * @throws IOException if reading from the index fails
     */
    protected Map<String, String> readDocument(
        final IndexReader reader,
        final int docId,
        final FieldSelector fs)
        throws IOException
    {
        final Map<String, String> cached =
            readFromCaches(cachesFor(reader), docId);
        if (cached != null) {
            docsFromCache++;
            return cached;
        }

        final Map<String, String> loaded = new IdentityHashMap<>();
        final Document stored = reader.document(docId, fs);
        for (String name: loadFields) {
            loaded.put(name, stored.get(name));
        }
        reads++;
        return loaded;
    }

    /**
     * Resolves the field caches for a reader, memoizing the answer per
     * reader. Yields EMPTY_CACHE when there is no fields cache or it
     * returns null for this reader.
     */
    private List<CacheInput> cachesFor(final IndexReader reader)
        throws IOException
    {
        if (fieldsCache == null) {
            return EMPTY_CACHE;
        }
        if (cacheMap == null) {
            cacheMap = new HashMap<>();
        }

        final List<CacheInput> known = cacheMap.get(reader);
        if (known != null) {
            return known;
        }

        List<CacheInput> resolved =
            fieldsCache.getCachesFor(reader, loadFields);
        if (resolved == null) {
            resolved = EMPTY_CACHE;
        }
        cacheMap.put(reader, resolved);
        return resolved;
    }

    /**
     * Collects one value per cache for the given document. Returns null
     * when the list is empty or any cache cannot seek to the document, in
     * which case the caller falls back to the stored document.
     */
    private static Map<String, String> readFromCaches(
        final List<CacheInput> caches,
        final int docId)
        throws IOException
    {
        Map<String, String> values = null;
        for (CacheInput cache: caches) {
            if (!cache.seek(docId)) {
                // Partial results are discarded.
                return null;
            }
            final YaField cachedField = cache.field();
            if (values == null) {
                values = new IdentityHashMap<>();
            }
            values.put(cache.fieldname(), cachedField.toString());
        }
        return values;
    }

    /**
     * Walks the tree rooted at the given document and writes the
     * aggregate over its descendants into the document.
     *
     * <p>The root id is read from the document's id field. Ids are taken
     * from a FIFO queue; for each one, all documents whose parent-id term
     * equals it are read, handed to the aggregate function, and their own
     * ids are queued in turn. The result is flushed into {@code doc} once
     * the queue is empty.
     *
     * <p>Nothing is written when the document has no id field or when the
     * top reader context reports no leaves.
     * NOTE(review): in the latter case the aggregate is not flushed at
     * all; confirm that leaving the output fields unset is intended.
     *
     * @param doc document to enrich with the aggregate function's output
     * @throws IOException if reading from the index fails or the number
     *     of processed nodes exceeds the configured limit
     */
    @Override
    public void process(final YaDoc3 doc) throws IOException {
        aggFunc.clear();
        nodesQueue.clear();

        YaField field = doc.getField(idFieldIndex);
        if (field == null) {
            return;
        }

        String idValue = field.toString();
        // Template term; only its field name is reused by createTerm().
        Term term = new Term(parentIdField, idValue);

        nodesQueue.add(idValue);
        IndexSearcher indexSearcher = searcher.searcher();
        IndexReader.AtomicReaderContext[] leaves =
            indexSearcher.getTopReaderContext().leaves();
        if (leaves == null) {
            return;
        }
        // Sort a private copy so the searcher's own array is not reordered.
        leaves = leaves.clone();
        Arrays.sort(
            leaves,
            (x, y) -> Long.compare(x.docBase, y.docBase));

        int nodes = 0;
        while (!nodesQueue.isEmpty()) {
            idValue = nodesQueue.poll();
            String preparedIdValue = idValue;
            // optimize allocations here, do mutable term query?
            if (context.isPrefixedAnalyzer()) {
                preparedIdValue = StringUtils.concat(
                    context.prefixedAnalyzer().getPrefix(),
                    context.prefixedAnalyzer().getSeparator(),
                    idValue);
            }
            term = term.createTerm(preparedIdValue);
            Query query = new TermQuery(term);
            Weight weight = query.weight(indexSearcher);

            for (IndexReader.AtomicReaderContext leaf: leaves) {
                DocIdSetIterator docIds =
                    weight.scorer(leaf, Weight.ScorerContext.def());
                if (docIds == null) {
                    // No scorer for this segment.
                    continue;
                }

                for (int docId = docIds.nextDoc();
                    docId != DocIdSetIterator.NO_MORE_DOCS;
                    docId = docIds.nextDoc())
                {
                    //FIXME rewrite on field visitor
                    Map<String, String> document = readDocument(
                        leaf.reader,
                        docId,
                        fieldSelector);

                    // The unprefixed id is passed as the parent key.
                    aggFunc.apply(idValue, document);
                    String nextId = document.get(idField);
                    if (nextId != null) {
                        nodesQueue.offer(nextId);
                        if (context.debug()) {
                            context.ctx().logger().info(
                                "Node fetched " + nextId);
                        }
                    }
                }
            }

            // The limit counts every processed queue entry and is checked
            // after the entry's children have been collected.
            if (nodes > maxTreeHeight) {
                throw new IOException(
                    "Too much non leaf node " + nodes);
            }

            nodes += 1;
        }

        aggFunc.flush(doc);
    }

    /**
     * Registers with the aggregator the fields this processor loads and
     * the fields its aggregate function produces.
     *
     * @param aggregator collector of processor field requirements
     */
    @Override
    public void apply(final ModuleFieldsAggregator aggregator) {
        final Set<String> produced = aggFunc.outFields();
        aggregator.add(loadFields, produced);
    }

    /**
     * Registry of the aggregate functions. The constant name, compared in
     * upper case, is the function name used in the processor arguments.
     */
    private enum AggFuncFactory {
        SUM_COUNT(SumAndCountFunc::new),
        PRINT(PrintFunc::new);

        private final Builder builder;

        AggFuncFactory(final Builder builder) {
            this.builder = builder;
        }

        /**
         * Creates the aggregate function for this constant.
         *
         * @param context request context handed to the function
         * @param args full processor argument string
         * @param offset index in {@code args} of the first character
         *     after the function's opening parenthesis
         * @return the configured aggregate function
         * @throws ParseException if the function rejects its arguments
         */
        AggFunc create(
            final ProcessorRequestContext context,
            final String args,
            final int offset)
            throws ParseException
        {
            return builder.build(context, args, offset);
        }

        /** Constructor shape shared by all aggregate functions. */
        @FunctionalInterface
        private interface Builder {
            AggFunc build(
                ProcessorRequestContext context,
                String args,
                int offset)
                throws ParseException;
        }
    }

    /**
     * Aggregation over the documents found while walking a tree. For
     * every root document the processor calls {@link #clear()}, then
     * {@link #apply} once per descendant, then {@link #flush}.
     */
    private interface AggFunc {
        /** Adds the names of the stored fields this function reads. */
        void loadFields(Collection<String> fields);

        /** Names of the fields written by {@link #flush}. */
        Set<String> outFields();

        /** Resets the accumulated state before a new root document. */
        void clear();

        /**
         * Accumulates one descendant document.
         *
         * @param parentId id of the node under which the document was
         *     found
         * @param value the document's loaded field values
         */
        void apply(String parentId, Map<String, String> value);

        /** Writes the accumulated result into the given document. */
        void flush(YaDoc3 doc3);

        /** Offset in the argument string recorded while parsing. */
        int parsedOffset();

        /** Textual form of the currently accumulated result. */
        String stringValue();
    }

    /**
     * Aggregate function that lists the paths of all descendants whose
     * {@code type} equals {@code file} (ignoring case), one path per line,
     * in the {@code files} output field.
     *
     * <p>The field names are fixed ({@code fid}, {@code name},
     * {@code type}, {@code files}); the constructor does not read its
     * argument string.
     */
    private static class PrintFunc implements AggFunc {
        // Path of every non-file node seen so far, keyed by its id.
        private final LinkedHashMap<String, String> cache;
        // Accumulated output: newline-terminated file paths.
        private final StringBuilder value;

        private final String idField;
        private final String nameField;
        private final String filterField;
        private final String filterValue;
        private final String outField;
        private final int outFieldIndex;
        private final String separator = "/";
        private final int offset;

        private PrintFunc(
            final ProcessorRequestContext context,
            final String args,
            final int offset)
            throws ParseException
        {
            this.idField = StringHelper.intern("fid");
            this.nameField = StringHelper.intern("name");
            this.filterField = StringHelper.intern("type");
            this.filterValue = "file";
            this.outField = StringHelper.intern("files");
            this.outFieldIndex = context.fieldToIndex().indexFor(this.outField);
            this.offset = offset + 1;
            this.cache = new LinkedHashMap<>();
            this.value = new StringBuilder();
        }

        @Override
        public void loadFields(final Collection<String> fields) {
            fields.add(this.idField);
            fields.add(this.nameField);
            fields.add(this.filterField);
        }

        @Override
        public Set<String> outFields() {
            return Collections.singleton(outField);
        }

        @Override
        public int parsedOffset() {
            return offset;
        }

        @Override
        public void clear() {
            cache.clear();
            value.setLength(0);
        }

        /**
         * Builds the document's path from its parent's remembered path
         * (empty when the parent is unknown), then either emits it, for a
         * file, or remembers it under the document's id for its children.
         */
        @Override
        public void apply(final String parent, final Map<String, String> doc) {
            final String path = StringUtils.concat(
                cache.getOrDefault(parent, ""),
                separator,
                doc.get(nameField));

            if (filterValue.equalsIgnoreCase(doc.get(filterField))) {
                value.append(path).append('\n');
                return;
            }

            final String nodeId = doc.get(idField);
            if (nodeId != null) {
                cache.put(nodeId, path);
            }
        }

        @Override
        public void flush(final YaDoc3 doc3) {
            final String listing = value.toString();
            doc3.setField(
                outFieldIndex,
                new YaField.StringYaField(StringUtils.getUtf8Bytes(listing)));
        }

        @Override
        public String stringValue() {
            return value.toString();
        }
    }

    /**
     * Aggregate function that sums a numeric field over the descendants
     * passing a filter and counts how many values were summed.
     *
     * <p>Arguments, starting at the given offset:
     * {@code get,filter_field,filter_value out_field_sum,out_field_cnt)}.
     * The separator between {@code filter_value} and
     * {@code out_field_sum} is a single space. A backslash makes the
     * following character lose its separator meaning; the backslash
     * itself is kept in the extracted token.
     */
    private static class SumAndCountFunc implements AggFunc {
        private long sum = 0;
        // Starts at 0, the same value clear() resets it to.
        private long count = 0;

        private final String field;
        private final String filterField;
        private final String filterValue;
        private final String outFieldSum;
        private final int outFieldSumIndex;
        private final String outFieldCount;
        private final int outFieldCountIndex;
        private final int parsedOffset;
        private final Set<String> outFields;
        private final ProcessorRequestContext context;

        /**
         * Parses the function arguments.
         *
         * @param context request context used to resolve field indexes
         * @param args full processor argument string
         * @param offset index of the first character after the function's
         *     opening parenthesis
         * @throws ParseException if a separator appears where the grammar
         *     does not allow it or the closing parenthesis is missing
         */
        private SumAndCountFunc(
            final ProcessorRequestContext context,
            final String args,
            final int offset)
            throws ParseException
        {
            this.context = context;
            SumCountParseState state = SumCountParseState.GET_FIELD;

            boolean escape = false;
            String getField = null;
            String filterField = null;
            String filterValue = null;
            String outFieldSum = null;
            String outFieldCount = null;

            int i = offset;
            int argStart = i;
            for (; i < args.length(); i++) {
                char c = args.charAt(i);
                if (c == '\\') {
                    escape = true;
                    continue;
                }

                if (escape) {
                    // An escaped character is never a separator.
                    escape = false;
                    continue;
                }

                if (c == ')') {
                    // Only legal after every other argument was read;
                    // otherwise some field would stay null.
                    if (state != SumCountParseState.OUT_FIELD_COUNT) {
                        throw invalidState(state, i);
                    }
                    outFieldCount = args.substring(argStart, i);
                    state = SumCountParseState.END;
                    break;
                }

                if (c == ' ') {
                    // The space only separates filter_value from
                    // out_field_sum.
                    if (state != SumCountParseState.FILTER_VALUE) {
                        throw invalidState(state, i);
                    }
                    filterValue = args.substring(argStart, i);
                    state = SumCountParseState.OUT_FIELD_SUM;
                    argStart = i + 1;
                } else if (c == ',') {
                    switch (state) {
                        case GET_FIELD:
                            getField = args.substring(argStart, i);
                            state = SumCountParseState.FILTER_FIELD;
                            break;
                        case FILTER_FIELD:
                            filterField = args.substring(argStart, i);
                            state = SumCountParseState.FILTER_VALUE;
                            break;
                        case OUT_FIELD_SUM:
                            outFieldSum = args.substring(argStart, i);
                            state = SumCountParseState.OUT_FIELD_COUNT;
                            break;
                        default:
                            throw invalidState(state, i);
                    }

                    argStart = i + 1;
                }
            }

            if (state != SumCountParseState.END) {
                throw new ParseException(
                    "End of args but state is " + state + " valid format is "
                        + FORMAT,
                    args.length());
            }

            this.field = StringHelper.intern(getField);
            // The returned indexes of the two input fields are not kept.
            context.fieldToIndex().indexFor(this.field);
            this.filterField = StringHelper.intern(filterField);
            context.fieldToIndex().indexFor(this.filterField);
            this.filterValue = StringHelper.intern(filterValue);
            this.outFieldSum = StringHelper.intern(outFieldSum);
            this.outFieldSumIndex =
                context.fieldToIndex().indexFor(this.outFieldSum);
            this.outFieldCount = StringHelper.intern(outFieldCount);
            this.outFieldCountIndex =
                context.fieldToIndex().indexFor(this.outFieldCount);
            // Index of the closing parenthesis.
            this.parsedOffset = i;
            this.outFields = new LinkedHashSet<>(2);
            this.outFields.add(this.outFieldSum);
            this.outFields.add(this.outFieldCount);
        }

        /** Builds the error for a separator met in the wrong state. */
        private static ParseException invalidState(
            final SumCountParseState state,
            final int position)
        {
            return new ParseException(
                "Invalid state "
                    + state
                    + " format is sum_count("
                    + "get,filter_field,filter_value+"
                    + "out_field_sum,out_field_cnt)",
                position);
        }

        @Override
        public int parsedOffset() {
            return parsedOffset;
        }

        @Override
        public void flush(final YaDoc3 doc3) {
            doc3.setField(outFieldSumIndex, new YaField.LongYaField(sum));
            doc3.setField(outFieldCountIndex, new YaField.LongYaField(count));
        }

        @Override
        public void loadFields(final Collection<String> fields) {
            fields.add(filterField);
            fields.add(field);
        }

        @Override
        public Set<String> outFields() {
            return outFields;
        }

        /**
         * Adds the document's value to the sum and increments the count
         * when its filter field equals the filter value (ignoring case)
         * and the summed field is present.
         *
         * @throws NumberFormatException if the summed field's value is
         *     not a valid long
         */
        @Override
        public void apply(final String parent, final Map<String, String> doc) {
            String value = doc.get(filterField);
            if (filterValue.equalsIgnoreCase(value)) {
                value = doc.get(field);
                if (value != null) {
                    sum += Long.parseLong(value);
                    count += 1;
                }
            }
        }

        @Override
        public void clear() {
            sum = 0;
            count = 0;
        }

        @Override
        public String stringValue() {
            return String.valueOf(sum) + "," + count;
        }
    }
}