package ru.yandex.msearch.collector.docprocessor;

import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.document.MapFieldSelector;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.StringHelper;
import ru.yandex.http.util.BadRequestException;
import ru.yandex.msearch.ProcessorRequestContext;
import ru.yandex.msearch.collector.YaDoc3;
import ru.yandex.msearch.collector.YaField;
import ru.yandex.parser.string.CollectionParser;
import ru.yandex.parser.string.NonEmptyValidator;
import ru.yandex.util.string.StringUtils;

/**
 * Resolves a document's position in a tree by following parent links up to
 * the root, much like walking a small file system from a file up to its root
 * folder in order to build the file's full path.
 * This is the Disk-specific variant.
 */
public class DiskTreeRootDocProcessor
    extends AbstractJoinDocProcessor
    implements DocProcessor
{
    private static final int MAX_TREE_HEIGHT = 1000;
    private static final int DEFAULT_CACHE_SIZE = 50000;
    private static final String FORMAT =
        "tree_root_rev(parent_id_field,id_field,extra_condition,"
            + "aggr_func,get_field+out_field)";

    private final String parentIdField;
    private final int parentIdFieldIndex;
    private final String idField;
    private final String extraCondition;
    private final AggFunc aggFunc;
    private final FieldSelector fieldSelector;
    // value of the "tree-root-keep-null" parameter: whether documents for
    // which no result could be built are kept
    private final boolean keepOnNull;
    // accumulated wall-clock time spent in processWithFilter, milliseconds
    private long totalTime = 0;

    private final Set<String> loadFields;

    // parsed lazily from extraCondition on first use, then reused
    private Query extraQuery;

    private enum ParseState {
        PARENT_ID_FIELD,
        ID_FIELD,
        EXTRA_CONDITION,
        AGGREGATE_FUNC,
        AGGREGATE_FUNC_ARG,
        GET_FIELD,
        OUT_FIELD
    }

    /**
     * Parses processor arguments of the form
     * {@code parent_id_field,id_field,extra_condition,aggr_func,get_fields out_fields}.
     * <p>
     * {@code aggr_func} is an {@link AggFuncFactory} constant name (matched
     * case-insensitively), optionally followed by a parenthesised argument,
     * e.g. {@code disk_concat(/)}; the closing parenthesis must immediately
     * precede the next comma. {@code get_fields} and {@code out_fields} are
     * comma-separated lists separated from each other by a single space.
     * A backslash prevents the following character from being treated as a
     * delimiter; the backslash itself stays in the parsed value.
     *
     * @param args processor arguments
     * @param context request context supplying field indexes and parameters
     * @throws ParseException if {@code args} does not match the format, the
     *     aggregate function is unknown, or a request parameter is invalid
     */
    public DiskTreeRootDocProcessor(
        final String args,
        final ProcessorRequestContext context)
        throws ParseException
    {
        super(context);

        int argStart = 0;
        boolean escape = false;

        ParseState state = ParseState.PARENT_ID_FIELD;

        String parentIdField = null;
        String idField = null;
        String extraCondition = null;
        List<String> getFields = new ArrayList<>();
        List<String> outFields = new ArrayList<>();
        String aggFunc = null;
        String aggFuncArgs = null;
        for (int i = 0; i < args.length(); i++) {
            char c = args.charAt(i);
            if (c == '\\') {
                escape = true;
                continue;
            }

            if (state == ParseState.AGGREGATE_FUNC && c == '(' && !escape) {
                aggFunc = args.substring(argStart, i);
                argStart = i + 1;
                state = ParseState.AGGREGATE_FUNC_ARG;
            }

            if (c == ' '
                && !escape
                && state == ParseState.GET_FIELD)
            {
                getFields.add(StringHelper.intern(args.substring(argStart, i)));

                state = ParseState.OUT_FIELD;
                argStart = i + 1;
            } else if (c == ',' && !escape) {
                switch (state) {
                    case PARENT_ID_FIELD:
                        parentIdField = args.substring(argStart, i);
                        state = ParseState.ID_FIELD;
                        break;
                    case ID_FIELD:
                        idField = args.substring(argStart, i);
                        state = ParseState.EXTRA_CONDITION;
                        break;
                    case EXTRA_CONDITION:
                        extraCondition = args.substring(argStart, i);
                        state = ParseState.AGGREGATE_FUNC;
                        break;
                    case AGGREGATE_FUNC:
                        aggFunc = args.substring(argStart, i);
                        state = ParseState.GET_FIELD;
                        break;
                    case GET_FIELD:
                        getFields.add(
                            StringHelper.intern(args.substring(argStart, i)));
                        argStart = i + 1;
                        break;
                    case OUT_FIELD:
                        outFields.add(
                            StringHelper.intern(args.substring(argStart, i)));
                        argStart = i + 1;
                        break;
                    case AGGREGATE_FUNC_ARG:
                        // i - 1 drops the closing ')' right before the comma
                        aggFuncArgs = args.substring(argStart, i - 1);
                        state = ParseState.GET_FIELD;
                        break;
                    default:
                        throw new ParseException(
                            "Invalid state " + state + " format is " + FORMAT,
                            i);
                }

                argStart = i + 1;
            }

            escape = false;
        }

        if (state != ParseState.OUT_FIELD) {
            throw new ParseException(
                "End of args but state is " + state + " valid format is "
                    + FORMAT,
                args.length());
        }

        // the last out field is terminated by the end of args, not a comma
        outFields.add(StringHelper.intern(args.substring(argStart)));

        this.parentIdField = StringHelper.intern(parentIdField);
        this.parentIdFieldIndex =
            context.fieldToIndex().indexFor(this.parentIdField);
        this.idField = StringHelper.intern(idField);
        this.extraCondition = extraCondition;
        this.aggFunc = resolveAggFunc(aggFunc)
            .create(aggFuncArgs, getFields, outFields, context);

        this.loadFields = new LinkedHashSet<>();
        loadFields.add(this.parentIdField);
        Collections.addAll(loadFields, this.aggFunc.getFields());

        this.fieldSelector = new MapFieldSelector(loadFields);
        try {
            this.keepOnNull =
                context.params().getBoolean(
                    "tree-root-keep-null",
                    false);
        } catch (BadRequestException bre) {
            throw parseError(bre.toString(), bre);
        }
    }

    /**
     * Looks up the aggregate function factory by name.
     *
     * @param name function name from the args, matched case-insensitively
     * @return the matching factory
     * @throws ParseException if no such function exists
     */
    private static AggFuncFactory resolveAggFunc(final String name)
        throws ParseException
    {
        try {
            return AggFuncFactory.valueOf(name.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            throw parseError(
                "Unknown aggregate function " + name + ", format is " + FORMAT,
                e);
        }
    }

    /**
     * Builds a {@link ParseException} at offset 0 that keeps {@code cause}.
     */
    private static ParseException parseError(
        final String message,
        final Throwable cause)
    {
        ParseException e = new ParseException(message, 0);
        e.initCause(cause);
        return e;
    }

    @Override
    public Set<String> loadFields() {
        return loadFields;
    }

    @Override
    public void after() {
        super.after();
        StringBuilder sb = new StringBuilder("TreeRoot totalProcessTime: ");
        sb.append(totalTime);
        sb.append(" SubQueries: ");
        sb.append(subQueries);
        context.ctx().logger().info(sb.toString());
    }

    /**
     * Resolves the tree path for {@code doc} and accounts the time spent.
     *
     * @return {@code true} to keep the document; when no result could be
     *     built this is the value of the "tree-root-keep-null" parameter
     * @throws IOException on search failure or if extraCondition is not a
     *     valid query
     */
    @Override
    public boolean processWithFilter(final YaDoc3 doc) throws IOException {
        boolean keepDoc = true;
        long start = System.currentTimeMillis();
        try {
            if (!extract(doc)) {
                keepDoc = keepOnNull;
            }
        } catch (org.apache.lucene.queryParser.ParseException pe) {
            throw new IOException("Bad query " + extraCondition, pe);
        }

        totalTime += System.currentTimeMillis() - start;
        return keepDoc;
    }

    /**
     * Walks from {@code doc} towards the root, one join query per level
     * (at most {@link #MAX_TREE_HEIGHT} levels). The walk stops when a node
     * has no parent or when the aggregate function reports it needs no more
     * levels. The aggregate then writes its result into {@code doc}.
     *
     * @return {@code true} if a result was written into {@code doc}
     */
    protected boolean extract(final YaDoc3 doc)
        throws IOException,
        org.apache.lucene.queryParser.ParseException
    {
        aggFunc.clear();

        YaField field = doc.getField(parentIdFieldIndex);
        if (field == null) {
            if (context.debug()) {
                context.ctx().logger().info(
                    "Empty parent id for doc " + doc);
            }

            return false;
        }

        int[] fieldsIndexes = aggFunc.getFieldsIndexes();

        String parentId = field.toString();
        for (int i = 0; i < fieldsIndexes.length; i++) {
            aggFunc.leaf(parentId, i, doc.getField(fieldsIndexes[i]));
        }

        if (context.debug()) {
            context.ctx().logger().info(
                "Processing " + doc + " " + parentId);
        }

        if (!aggFunc.checkContinue(parentId)) {
            return aggFunc.get(doc);
        }

        if (!extraCondition.isEmpty()) {
            if (extraQuery == null) {
                extraQuery =
                    context.queryParser().parse(extraCondition);
            }
        }

        String stepParentId = parentId;
        for (int i = 0; i < MAX_TREE_HEIGHT; i++) {
            Query joinQuery = buildJoinQuery(stepParentId, idField);
            if (extraQuery != null) {
                BooleanQuery bq = new BooleanQuery();
                bq.add(joinQuery, BooleanClause.Occur.MUST);
                bq.add(extraQuery, BooleanClause.Occur.MUST);
                joinQuery = bq;
            }

            Map<String, YaField> document =
                extractDoc(joinQuery, fieldSelector);

            // reverse order: DiskFunc relies on field 0 being applied last
            // for each level (same ordering as the original implementation)
            for (int j = fieldsIndexes.length - 1; j >= 0; j--) {
                YaField yaField = document.get(aggFunc.getFields()[j]);

                aggFunc.apply(stepParentId, j, yaField);
            }

            YaField yaField = document.get(parentIdField);
            if (yaField != null) {
                stepParentId = yaField.toString();
            } else {
                stepParentId = null;
            }

            if (context.debug()) {
                StringBuilder logStr = new StringBuilder();
                if (stepParentId == null) {
                    logStr.append("Parent fid not found");
                } else {
                    logStr.append("Extracted value ");
                    logStr.append(document.get(aggFunc.getFields()[0]));
                    logStr.append(" new parent ");
                    logStr.append(document.get(parentIdField));
                }

                logStr.append(" JoinQuery ");
                logStr.append(joinQuery.toString());
                context.ctx().logger().info(logStr.toString());
            }

            if (stepParentId == null || !aggFunc.checkContinue(stepParentId)) {
                break;
            }
        }

        return aggFunc.get(doc);
    }

    @Override
    public void apply(final ModuleFieldsAggregator aggregator) {
        aggregator.add(loadFields, aggFunc.outFields());
    }

    /**
     * Named constructors for aggregate functions; the constant name is what
     * appears (case-insensitively) as {@code aggr_func} in the args.
     */
    private enum AggFuncFactory {
        DISK_CONCAT {
            @Override
            AggFunc create(
                final String separator,
                final List<String> getFields,
                final List<String> outFields,
                final ProcessorRequestContext context)
                throws ParseException
            {
                return new DiskFunc(separator, getFields, outFields, context);
            }
        };

        abstract AggFunc create(
            final String separator,
            final List<String> getFields,
            final List<String> outFields,
            final ProcessorRequestContext context)
            throws ParseException;
    }

    /**
     * Per-document accumulator fed with field values while walking the tree.
     */
    private interface AggFunc {
        /**
         * Records the value of get field number {@code index} for the node
         * identified by {@code id}; {@code field} may be {@code null}.
         */
        void apply(
            final String id,
            final int index,
            final YaField field);

        /**
         * Same as {@link #apply} but for the processed document itself.
         */
        default void leaf(
            final String id,
            final int index,
            final YaField field)
        {
            apply(id, index, field);
        }

        String[] getFields();

        Set<String> outFields();

        int[] getFieldsIndexes();

        /** Resets all per-document state. */
        void clear();

        /**
         * Writes the result into {@code doc}.
         *
         * @return {@code false} if no result could be built
         */
        boolean get(final YaDoc3 doc);

        /**
         * @return {@code true} if the walk should continue to {@code parentId}
         */
        boolean checkContinue(final String parentId);
    }

    /**
     * Builds a path from get field 0 values (root first, each preceded by
     * the separator) and takes the maximum of get field 1 values as a long.
     * The results are written to out fields 0 and 1.
     * <p>
     * Partial paths are cached by node id, and unresolvable nodes are cached
     * as {@link DiskCacheItem#NULL}. The topmost value reached must be listed
     * in the "disk-allowed-roots" request parameter, otherwise no result is
     * produced.
     * NOTE(review): if no separator argument is given, {@code separator} is
     * {@code null} and StringBuilder appends the text "null".
     */
    private static class DiskFunc implements AggFunc {
        private static final CollectionParser<String, Set<String>, Exception>
            ALLOWED_ROOTS_PARSER =
            new CollectionParser<>(
                NonEmptyValidator.TRIMMED, LinkedHashSet::new);

        protected final DocProcessorQueryCache<DiskCacheItem> cache;

        protected final StringBuilder sb;
        protected final String separator;
        protected final List<Long> maxRevisions;
        protected final List<String> values;
        protected final List<String> ids;
        protected final String[] getFields;
        protected final Set<String> outFields;
        protected final int[] getFieldsIndexes;
        protected final int[] outFieldsIndexes;
        protected final ProcessorRequestContext context;

        private final Set<String> allowedRoots;
        private boolean badRoot = false;
        private boolean cacheHit = false;
        private long maxRevision = Long.MIN_VALUE;

        /**
         * @throws ParseException unless there are exactly two get fields
         *     (value, revision) and at least two out fields (path, max
         *     revision), or if a request parameter is invalid
         */
        public DiskFunc(
            final String separator,
            final List<String> getFields,
            final List<String> outFields,
            final ProcessorRequestContext context)
            throws ParseException
        {
            // apply() and get() index these positions directly; validating
            // here turns a per-document runtime crash into a parse error
            if (getFields.size() != 2) {
                throw new ParseException(
                    "disk_concat expects exactly 2 get fields "
                        + "(value, revision) but got " + getFields,
                    0);
            }
            if (outFields.size() < 2) {
                throw new ParseException(
                    "disk_concat expects 2 out fields "
                        + "(path, max revision) but got " + outFields,
                    0);
            }

            this.outFields = new LinkedHashSet<>(outFields);
            this.getFields = new String[getFields.size()];
            this.getFieldsIndexes = new int[getFields.size()];
            for (int i = 0; i < getFields.size(); i++) {
                String field = getFields.get(i);
                this.getFields[i] = field;
                this.getFieldsIndexes[i] =
                    context.fieldToIndex().indexFor(field);
            }

            this.outFieldsIndexes = new int[outFields.size()];
            for (int i = 0; i < outFields.size(); i++) {
                String field = outFields.get(i);
                this.outFieldsIndexes[i] =
                    context.fieldToIndex().indexFor(field);
            }

            this.separator = separator;
            this.context = context;
            this.sb = new StringBuilder();
            this.ids = new ArrayList<>(MAX_TREE_HEIGHT);
            this.values = new ArrayList<>(MAX_TREE_HEIGHT);
            this.maxRevisions = new ArrayList<>(MAX_TREE_HEIGHT);

            try {
                int maxCacheSize =
                    context.params().getInt(
                        "join-cache-size",
                        DEFAULT_CACHE_SIZE);

                this.cache = new DocProcessorQueryCache<>(maxCacheSize);

                this.allowedRoots =
                    context.params().get(
                        "disk-allowed-roots",
                        Collections.emptySet(),
                        ALLOWED_ROOTS_PARSER);
            } catch (BadRequestException bre) {
                throw parseError(bre.toString(), bre);
            }
        }

        @Override
        public Set<String> outFields() {
            return outFields;
        }

        @Override
        public String[] getFields() {
            return getFields;
        }

        @Override
        public int[] getFieldsIndexes() {
            return getFieldsIndexes;
        }

        @Override
        public void apply(
            final String id,
            final int index,
            final YaField field)
        {
            if (index == 0) {
                // a node without a value cannot be part of a valid path
                if (field == null) {
                    badRoot = true;
                } else {
                    ids.add(id);
                    values.add(field.toString());
                }
            } else {
                if (field != null) {
                    maxRevisions.add(field.longValue());
                } else {
                    maxRevisions.add(Long.MIN_VALUE);
                }
            }
        }

        @Override
        public void clear() {
            sb.setLength(0);
            ids.clear();
            values.clear();
            maxRevisions.clear();
            this.badRoot = false;
            this.cacheHit = false;
            this.maxRevision = Long.MIN_VALUE;
        }

        @Override
        public boolean checkContinue(final String parentId) {
            if (badRoot) {
                return false;
            }

            DiskCacheItem cached = cache.get(parentId);
            if (cached != null) {
                cacheHit = true;
                if (cached == DiskCacheItem.NULL) {
                    badRoot = true;
                } else {
                    // seed the path with the cached prefix of this ancestor
                    sb.append(cached.value);
                    if (cached.maxRevision > maxRevision) {
                        maxRevision = cached.maxRevision;
                    }
                }

                return false;
            }

            return true;
        }

        @Override
        public boolean get(final YaDoc3 doc) {
            if (!cacheHit) {
                if (values.isEmpty()) {
                    // the processed document had no value itself (badRoot is
                    // already set by apply); previously this path crashed on
                    // values.get(-1)
                    badRoot = true;
                } else {
                    // the last collected value belongs to the topmost node
                    String root = values.get(values.size() - 1);
                    if (!allowedRoots.contains(root)) {
                        badRoot = true;
                        if (context.debug()) {
                            context.ctx().logger().info(
                                "BadRoot " + root);
                        }
                    }
                }
            }

            if (badRoot) {
                // index 0 is the processed document, which is not cached
                for (int i = ids.size() - 1; i >= 1; i--) {
                    cache.put(ids.get(i), DiskCacheItem.NULL);
                }

                return false;
            }

            // build the path top-down, caching every ancestor's prefix
            for (int i = ids.size() - 1; i >= 1; i--) {
                if (maxRevisions.get(i) > maxRevision) {
                    maxRevision = maxRevisions.get(i);
                }

                sb.append(separator);
                sb.append(values.get(i));
                DiskCacheItem cached =
                    new DiskCacheItem(sb.toString(), maxRevision);

                cache.put(ids.get(i), cached);
                if (context.debug()) {
                    context.ctx().logger().info(
                        "Putted in cache " + ids.get(i) + " " + cached);
                }
            }

            if (maxRevisions.get(0) > maxRevision) {
                maxRevision = maxRevisions.get(0);
            }

            sb.append(separator);
            sb.append(values.get(0));

            if (context.debug()) {
                context.ctx().logger().info(
                    ids.toString());
                context.ctx().logger().info(
                    values.toString());
                context.ctx().logger().info(
                    "Result " + sb.toString() + " " + maxRevision);
            }

            doc.setField(
                outFieldsIndexes[0],
                new YaField.StringYaField(
                    StringUtils.getUtf8Bytes(sb.toString())));
            doc.setField(
                outFieldsIndexes[1],
                new YaField.LongYaField(maxRevision));
            return true;
        }
    }

    /**
     * Cached path prefix and maximum revision for a node.
     */
    private static class DiskCacheItem {
        // sentinel for "no valid path from this node"; compared by identity
        private static final DiskCacheItem NULL = new DiskCacheItem(null, -1L);

        private final String value;
        private final long maxRevision;

        public DiskCacheItem(final String value, final long maxRevision) {
            this.value = value;
            this.maxRevision = maxRevision;
        }

        @Override
        public String toString() {
            return "DiskCacheItem{" +
                "value='" + value + '\'' +
                ", maxRevision=" + maxRevision +
                '}';
        }
    }
}
