package ru.yandex.msearch.collector.docprocessor;

import java.io.IOException;
import java.text.ParseException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.document.MapFieldSelector;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.StringHelper;
import ru.yandex.http.util.BadRequestException;
import ru.yandex.msearch.ProcessorRequestContext;
import ru.yandex.msearch.collector.YaDoc3;
import ru.yandex.msearch.collector.YaField;
import ru.yandex.parser.string.CollectionParser;
import ru.yandex.parser.string.NonEmptyValidator;
import ru.yandex.util.string.StringUtils;

/**
 * Resolves a tree structure by following parent links up towards the root.
 * The typical example is a small file system, where we want to reach the
 * root folder or resolve a node's full path.
 */
public class TreeRootDocProcessor
    extends AbstractJoinDocProcessor
    implements DocProcessor
{
    // Upper bound on parent-link hops per document; bounds the walk even
    // if the data contains a cycle.
    private static final int MAX_TREE_HEIGHT = 1000;
    private static final int DEFAULT_CACHE_SIZE = 50000;
    private static final String FORMAT =
        "tree_root(parent_id_field,id_field,extra_condition,"
            + "aggr_func,get_field+out_field)";

    private final String parentIdField;
    private final int parentIdFieldIndex;
    private final String idField;
    private final String extraCondition;
    private final String getField;
    private final int getFieldIndex;
    private final String outField;
    private final int outFieldIndex;
    private final AggFunc aggFunc;
    private final FieldSelector fieldSelector;
    // When true, documents for which no value could be computed are kept
    // (without out_field set) instead of being filtered out.
    private final boolean keepOnNull;
    // Accumulated wall-clock time spent in processWithFilter, in ms.
    private long totalTime = 0;

    private final Set<String> loadFields;

    // Parsed lazily from extraCondition on first use.
    private Query extraQuery;

    private enum ParseState {
        PARENT_ID_FIELD,
        ID_FIELD,
        EXTRA_CONDITION,
        AGGREGATE_FUNC,
        AGGREGATE_FUNC_ARG,
        GET_FIELD,
        OUT_FIELD
    }

    /**
     * Parses the processor arguments, laid out as
     * {@code parent_id_field,id_field,extra_condition,aggr_func,get_field out_field}
     * (see {@link #FORMAT}).
     *
     * <p>{@code aggr_func} is {@code concat} or {@code disk_concat}
     * (case-insensitive). It may be followed by a separator in parentheses,
     * e.g. {@code concat(/)}. Without parentheses the separator is empty.
     * {@code get_field} and {@code out_field} are separated by a space
     * (presumably written as '+' in the request URL — TODO confirm).
     * A backslash escapes the next character, so delimiters can appear
     * inside arguments. The backslash itself is kept in the extracted value.
     *
     * @param args comma-separated processor argument list
     * @param context request context used for field indexes, params and
     *     query parsing
     * @throws ParseException if the arguments do not match the format,
     *     the aggregate function is unknown or request params are invalid
     */
    public TreeRootDocProcessor(
        final String args,
        final ProcessorRequestContext context)
        throws ParseException
    {
        super(context);

        int argStart = 0;
        boolean escape = false;

        ParseState state = ParseState.PARENT_ID_FIELD;

        String parentIdField = null;
        String idField = null;
        String extraCondition = null;
        String getField = null;
        String aggFunc = null;
        String aggFuncArgs = null;
        for (int i = 0; i < args.length(); i++) {
            char c = args.charAt(i);
            // An unescaped backslash escapes the next character. An escaped
            // backslash ("\\") is an ordinary character and must not leave
            // the escape flag set for the character that follows it.
            if (c == '\\' && !escape) {
                escape = true;
                continue;
            }

            if (state == ParseState.AGGREGATE_FUNC && c == '(' && !escape) {
                aggFunc = args.substring(argStart, i);
                argStart = i + 1;
                state = ParseState.AGGREGATE_FUNC_ARG;
            }

            if (c == ' '
                && !escape
                && state == ParseState.GET_FIELD)
            {
                getField = StringHelper.intern(args.substring(argStart, i));

                state = ParseState.OUT_FIELD;
                argStart = i + 1;
            } else if (c == ',' && !escape) {
                switch (state) {
                    case PARENT_ID_FIELD:
                        parentIdField = args.substring(argStart, i);
                        state = ParseState.ID_FIELD;
                        break;
                    case ID_FIELD:
                        idField = args.substring(argStart, i);
                        state = ParseState.EXTRA_CONDITION;
                        break;
                    case EXTRA_CONDITION:
                        extraCondition = args.substring(argStart, i);
                        state = ParseState.AGGREGATE_FUNC;
                        break;
                    case AGGREGATE_FUNC:
                        aggFunc = args.substring(argStart, i);
                        state = ParseState.GET_FIELD;
                        break;
                    case AGGREGATE_FUNC_ARG:
                        // The argument list must end with ')' right before
                        // this comma; otherwise the comma belongs inside the
                        // (unescaped) argument and the input is malformed.
                        if (i - 1 < argStart || args.charAt(i - 1) != ')') {
                            throw new ParseException(
                                "Unterminated aggregate function arguments,"
                                    + " format is " + FORMAT,
                                i);
                        }
                        aggFuncArgs = args.substring(argStart, i - 1);
                        state = ParseState.GET_FIELD;
                        break;
                    default:
                        throw new ParseException(
                            "Invalid state " + state + " format is " + FORMAT,
                            i);
                }

                argStart = i + 1;
            }

            escape = false;
        }

        if (state != ParseState.OUT_FIELD) {
            throw new ParseException(
                "End of args but state is " + state + " valid format is "
                    + FORMAT,
                args.length());
        }

        this.outField = StringHelper.intern(args.substring(argStart));
        this.outFieldIndex = context.fieldToIndex().indexFor(outField);
        this.parentIdField = StringHelper.intern(parentIdField);
        this.parentIdFieldIndex =
            context.fieldToIndex().indexFor(this.parentIdField);
        this.idField = StringHelper.intern(idField);
        this.extraCondition = extraCondition;
        this.getField = StringHelper.intern(getField);
        this.getFieldIndex = context.fieldToIndex().indexFor(this.getField);

        // Enum.valueOf throws an unchecked IllegalArgumentException for
        // unknown names; report it through the declared ParseException.
        AggFuncFactory factory;
        try {
            factory =
                AggFuncFactory.valueOf(aggFunc.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException iae) {
            throw parseException(
                "Unknown aggregate function " + aggFunc
                    + " valid format is " + FORMAT,
                0,
                iae);
        }
        this.aggFunc = factory.create(aggFuncArgs, context);

        this.loadFields = new LinkedHashSet<>();
        loadFields.add(this.parentIdField);
        loadFields.add(this.getField);

        this.fieldSelector = new MapFieldSelector(loadFields);
        try {
            this.keepOnNull =
                context.params().getBoolean(
                    "tree-root-keep-null",
                    false);
        } catch (BadRequestException bre) {
            throw parseException(bre.toString(), 0, bre);
        }
    }

    /**
     * Creates a {@link ParseException} that keeps the original cause, which
     * the {@code (String, int)} constructor alone would drop.
     */
    private static ParseException parseException(
        final String message,
        final int offset,
        final Throwable cause)
    {
        ParseException e = new ParseException(message, offset);
        e.initCause(cause);
        return e;
    }

    /**
     * @return the stored fields this processor reads: parent id field and
     *     get field
     */
    @Override
    public Set<String> loadFields() {
        return loadFields;
    }

    /**
     * Logs the accumulated processing time and sub-query count after the
     * base class cleanup.
     */
    @Override
    public void after() {
        super.after();
        StringBuilder sb = new StringBuilder("TreeRoot totalProcessTime: ");
        sb.append(totalTime);
        sb.append(" SubQueries: ");
        sb.append(subQueries);
        context.ctx().logger().info(sb.toString());
    }

    /**
     * Computes the aggregated value for {@code doc} and stores it in the
     * output field.
     *
     * @param doc document to process
     * @return {@code true} to keep the document; when no value could be
     *     computed, the value of the {@code tree-root-keep-null} parameter
     * @throws IOException on index access failures or if
     *     {@code extra_condition} is not a valid query
     */
    @Override
    public boolean processWithFilter(final YaDoc3 doc) throws IOException {
        boolean keepDoc = true;
        long start = System.currentTimeMillis();
        try {
            YaField result = extract(doc);
            if (result != null) {
                doc.setField(outFieldIndex, result);
            } else {
                keepDoc = keepOnNull;
            }
        } catch (org.apache.lucene.queryParser.ParseException pe) {
            throw new IOException("Bad query " + extraCondition, pe);
        }

        totalTime += System.currentTimeMillis() - start;
        return keepDoc;
    }

    /**
     * Walks up the tree from {@code doc}, feeding every visited node's
     * get-field value to the aggregate function.
     *
     * <p>Each step looks up the document whose {@code id_field} equals the
     * current parent id, optionally restricted by {@code extra_condition}.
     * The walk stops when no such document is found, or when it has no parent
     * id. It also stops when the aggregate function reports (e.g. via a cache
     * hit) that it needs no more input, or after {@link #MAX_TREE_HEIGHT}
     * steps.
     *
     * @param doc starting document
     * @return aggregated value, or {@code null} if {@code doc} lacks the
     *     parent id or get field, or if the aggregate function yields
     *     {@code null}
     */
    protected YaField extract(final YaDoc3 doc)
        throws IOException,
        org.apache.lucene.queryParser.ParseException
    {
        aggFunc.clear();

        YaField field = doc.getField(parentIdFieldIndex);
        if (field == null) {
            if (context.debug()) {
                context.ctx().logger().info(
                    "Empty parent id for doc " + doc);
            }
            return null;
        }

        YaField selfValueField = doc.getField(getFieldIndex);
        if (selfValueField == null) {
            // Without the document's own value there is nothing to
            // aggregate; treat it like a missing parent id instead of NPE.
            if (context.debug()) {
                context.ctx().logger().info(
                    "Empty " + getField + " for doc " + doc);
            }
            return null;
        }

        String parentId = field.toString();
        String selfValue = selfValueField.toString();

        if (context.debug()) {
            context.ctx().logger().info(
                "Processing " + selfValueField + " " + parentId);
        }

        aggFunc.apply(parentId, selfValue);
        if (!aggFunc.checkContinue(parentId)) {
            return aggFunc.get();
        }

        if (!extraCondition.isEmpty()) {
            if (extraQuery == null) {
                extraQuery =
                    context.queryParser().parse(extraCondition);
            }
        }

        String stepParentId = parentId;
        for (int i = 0; i < MAX_TREE_HEIGHT; i++) {
            Query joinQuery = buildJoinQuery(stepParentId, idField);
            if (extraQuery != null) {
                BooleanQuery bq = new BooleanQuery();
                bq.add(joinQuery, BooleanClause.Occur.MUST);
                bq.add(extraQuery, BooleanClause.Occur.MUST);
                joinQuery = bq;
            }

            Map<String, YaField> document =
                extractDoc(joinQuery, fieldSelector);
            YaField yaField = document.get(getField);
            if (yaField == null) {
                break;
            }
            String getValue = yaField.toString();
            aggFunc.apply(stepParentId, getValue);

            yaField = document.get(parentIdField);
            if (yaField != null) {
                stepParentId = yaField.toString();
            } else {
                stepParentId = null;
            }

            if (context.debug()) {
                StringBuilder logStr = new StringBuilder();
                if (stepParentId == null) {
                    logStr.append("Parent fid not found");
                } else {
                    logStr.append("Extracted value ");
                    logStr.append(document.get(getField));
                    logStr.append(" new parent ");
                    logStr.append(document.get(parentIdField));
                }

                logStr.append(" JoinQuery ");
                logStr.append(joinQuery.toString());
                context.ctx().logger().info(logStr.toString());
            }

            if (stepParentId == null || !aggFunc.checkContinue(stepParentId)) {
                break;
            }
        }

        return aggFunc.get();
    }

    @Override
    public void apply(final ModuleFieldsAggregator aggregator) {
        aggregator.addWithSingleOut(outField, loadFields);
    }

    /** Maps {@code aggr_func} names (upper-cased) to implementations. */
    private enum AggFuncFactory {
        CONCAT {
            @Override
            AggFunc create(
                final String args,
                final ProcessorRequestContext context)
                throws ParseException
            {
                return new ConcatFunc(args, context);
            }
        },
        DISK_CONCAT {
            @Override
            AggFunc create(
                final String args,
                final ProcessorRequestContext context)
                throws ParseException
            {
                return new DiskConcatFunc(args, context);
            }
        };

        /**
         * @param args text inside the function's parentheses, or
         *     {@code null} if none were given
         */
        abstract AggFunc create(
            final String args,
            final ProcessorRequestContext context)
            throws ParseException;
    }

    /** Stateful per-document aggregation over the visited tree nodes. */
    private interface AggFunc {
        /** Records a visited node: the id it was looked up by and its value. */
        void apply(final String id, final String value);

        /** Resets per-document state before processing a new document. */
        void clear();

        /** @return the aggregated result, or {@code null} for none */
        YaField get();

        /**
         * @return {@code false} if the walk may stop before looking up
         *     {@code parentId}, e.g. because its result is already cached
         */
        boolean checkContinue(final String parentId);
    }

    /**
     * Concatenates visited values root-first into a separator-prefixed path
     * and caches the path of every visited ancestor by its id.
     */
    private abstract static class AbstractConcatFunc implements AggFunc {
        protected final DocProcessorQueryCache<String> cache;

        protected final StringBuilder sb;
        protected final String separator;
        protected final List<String> values;
        protected final List<String> ids;
        protected final ProcessorRequestContext context;

        /**
         * @param separator string put before every value; {@code null}
         *     (no parentheses after the function name) means empty
         * @param context request context; {@code join-cache-size} sets the
         *     maximum cache size
         * @throws ParseException if {@code join-cache-size} is invalid
         */
        public AbstractConcatFunc(
            final String separator,
            final ProcessorRequestContext context)
            throws ParseException
        {
            // StringBuilder.append(null) would insert the text "null".
            if (separator == null) {
                this.separator = "";
            } else {
                this.separator = separator;
            }
            this.context = context;
            this.sb = new StringBuilder();
            this.ids = new ArrayList<>(MAX_TREE_HEIGHT);
            this.values = new ArrayList<>(MAX_TREE_HEIGHT);

            int maxCacheSize;
            try {
                maxCacheSize =
                    context.params().getInt(
                        "join-cache-size",
                        DEFAULT_CACHE_SIZE);
            } catch (BadRequestException bre) {
                throw parseException(bre.toString(), 0, bre);
            }
            this.cache = new DocProcessorQueryCache<>(maxCacheSize);
        }

        /** Stores the path computed for node {@code id}. */
        protected abstract void cache(final String id, final String value);

        @Override
        public void apply(final String id, final String value) {
            ids.add(id);
            values.add(value);
        }

        @Override
        public void clear() {
            sb.setLength(0);
            ids.clear();
            values.clear();
        }

        /**
         * Appends values from the topmost visited node down to the
         * document's own value (index 0). Any cached prefix already in
         * {@code sb} comes first. Each intermediate path is cached under
         * the id its node was looked up by.
         */
        @Override
        public YaField get() {
            for (int i = ids.size() - 1; i >= 1; i--) {
                sb.append(separator);
                sb.append(values.get(i));
                String cached = sb.toString();
                cache(ids.get(i), cached);
                if (context.debug()) {
                    context.ctx().logger().info(
                        "Putted in cache " + ids.get(i) + " " + cached);
                }
            }

            // Index 0 is the processed document itself; its path is
            // returned, not cached.
            sb.append(separator);
            sb.append(values.get(0));

            if (context.debug()) {
                context.ctx().logger().info(
                    ids.toString());
                context.ctx().logger().info(
                    values.toString());
                context.ctx().logger().info("Result " + sb.toString());
            }

            return new YaField.StringYaField(
                StringUtils.getUtf8Bytes(sb.toString()));
        }
    }

    /** Plain path concatenation with caching of resolved ancestors. */
    private static class ConcatFunc extends AbstractConcatFunc {
        public ConcatFunc(
            final String separator,
            final ProcessorRequestContext context) throws ParseException
        {
            super(separator, context);
        }

        @Override
        public boolean checkContinue(final String parentId) {
            String cached = cache.get(parentId);
            if (cached != null) {
                if (context.debug()) {
                    context.ctx().logger().info(
                        "Cache hit for " + parentId + " " + cached);
                }
                // The cached path already covers parentId and everything
                // above it.
                sb.append(cached);

                return false;
            } else {
                if (context.debug()) {
                    context.ctx().logger().info(
                        "Cache miss for " + parentId);
                }
            }

            return true;
        }

        @Override
        protected void cache(final String id, final String value) {
            cache.put(id, value);
        }
    }

    /**
     * Path concatenation that yields {@code null} unless the topmost
     * reached value is listed in the {@code disk-allowed-roots} request
     * parameter. Ancestors of rejected paths are cached as a sentinel, so
     * later documents under them are rejected without further lookups.
     */
    private static class DiskConcatFunc extends AbstractConcatFunc {
        private static final CollectionParser<String, Set<String>, Exception>
            ALLOWED_ROOTS_PARSER =
            new CollectionParser<>(
                NonEmptyValidator.TRIMMED, LinkedHashSet::new);
        // Sentinel compared by reference (==), so a fresh String instance is
        // used. Assumes the cache returns the same instance it stored —
        // TODO confirm for DocProcessorQueryCache.
        private static final String BAD_ROOT = new String("null");

        private final Set<String> allowedRoots;
        private boolean badRoot = false;
        private boolean cacheHit = false;

        public DiskConcatFunc(
            final String separator,
            final ProcessorRequestContext context)
            throws ParseException
        {
            super(separator, context);

            try {
                this.allowedRoots =
                    context.params().get(
                        "disk-allowed-roots",
                        Collections.emptySet(),
                        ALLOWED_ROOTS_PARSER);
            } catch (BadRequestException bre) {
                throw parseException(bre.toString(), 0, bre);
            }
        }

        @Override
        public boolean checkContinue(final String parentId) {
            String cached = cache.get(parentId);
            if (cached != null) {
                cacheHit = true;
                if (cached == BAD_ROOT) {
                    badRoot = true;
                } else {
                    sb.append(cached);
                }

                return false;
            }

            return true;
        }

        @Override
        public void clear() {
            super.clear();
            this.badRoot = false;
            this.cacheHit = false;
        }

        @Override
        protected void cache(final String id, final String value) {
            if (badRoot) {
                cache.put(id, BAD_ROOT);
            } else {
                cache.put(id, value);
            }
        }

        @Override
        public YaField get() {
            // A cache hit means the root was already validated (a rejected
            // root would have been cached as BAD_ROOT), so only check the
            // root when the walk reached the top by itself.
            if (!cacheHit) {
                String root = values.get(values.size() - 1);
                if (!allowedRoots.contains(root)) {
                    badRoot = true;
                    if (context.debug()) {
                        context.ctx().logger().info(
                            "BadRoot " + root);
                    }
                }
            }

            if (badRoot) {
                for (int i = ids.size() - 1; i >= 1; i--) {
                    cache(ids.get(i), BAD_ROOT);
                }

                return null;
            }

            return super.get();
        }

        public boolean cacheHit() {
            return cacheHit;
        }
    }
}
