package ru.yandex.qloud.kikimr.lucene;

import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import org.apache.lucene.document.DateTools;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.RegexpQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TermRangeQuery;
import org.apache.lucene.search.WildcardQuery;
import org.springframework.stereotype.Component;
import ru.yandex.qloud.kikimr.search.QueryWhereConverter;
import ru.yandex.qloud.kikimr.utils.Sanitizer;

import java.util.Arrays;
import java.util.EnumMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toSet;

@Component
public class ESQueryConverter implements QueryWhereConverter {
    private final Map<String, SchemaFieldType> schema;

    private static final Set<String> FIELDS_TO_SEARCH_BY_SUBSTRING = Sets.newHashSet("message", "request", "stackTrace");

    private final Map<Class<? extends Query>, Function<Query, String>> queryClassToConverter =
        ImmutableMap.<Class<? extends Query>, Function<Query, String>>builder()
            .put(BooleanQuery.class, query -> getBooleanClause((BooleanQuery) query))
            .put(TermQuery.class, termQuery -> getTermClause(((TermQuery) termQuery).getTerm()))
            .put(WildcardQuery.class, wildcardQuery -> getWildcardClause(((WildcardQuery) wildcardQuery).getTerm()))
            .put(PrefixQuery.class, prefixQuery -> getPrefixClause(((PrefixQuery) prefixQuery).getPrefix()))
            .put(RegexpQuery.class, query -> getRegexpClause((RegexpQuery) query))
            .put(TermRangeQuery.class, query -> getTermRangeClause((TermRangeQuery) query))
            .put(PhraseQuery.class, query -> getPhraseClause((PhraseQuery) query))
            .build();


    public ESQueryConverter() {
        this(SchemaProvider.getDefaultSchema());
    }

    public ESQueryConverter(Map<String, SchemaFieldType> schema) {
        this.schema = schema;
    }

    @Override
    public String convertToYqlWhereCondition(String esQuery, boolean accessLog) {
        if (esQuery == null || esQuery.isEmpty()) {
            return "TRUE";
        }

        QueryParser parser = new QueryParser(accessLog ? "request" : "message", new SchemaAwareAnalyzer(schema));
        parser.setDateResolution(DateTools.Resolution.MILLISECOND);
        parser.setEnablePositionIncrements(true);
        parser.setLowercaseExpandedTerms(false);

        try {
            return serializeWhereCondition(parser.parse(esQuery));
        } catch (Exception e) {
            throw new ESQueryParserException(esQuery, e);
        }
    }

    private String serializeWhereCondition(Query query) {
        Function<Query, String> conversionFunction = queryClassToConverter.get(query.getClass());
        if (conversionFunction == null) {
            throw new IllegalArgumentException(String.format(
                        "{ UNKNOWN TYPE %s : %s }",
                        query.getClass().toString(),
                        query.toString()
                ));
        }
        return conversionFunction.apply(query);
    }

    private String getBooleanClause(BooleanQuery query) {
        StringBuilder sb = new StringBuilder("(TRUE");
        query.clauses().stream().collect(Collectors.groupingBy(
                BooleanClause::getOccur,
                toSet()
        )).entrySet().stream().map((entry) -> {
            Set<String> value = entry.getValue().stream().map(c -> serializeWhereCondition(c.getQuery())).collect(toSet());
            switch (entry.getKey()) {
                case MUST:
                case FILTER:
                    return Joiner.on(" AND ").join(value);
                case SHOULD:
                    return Joiner.on(" OR ").join(value);
                case MUST_NOT:
                    return String.format("NOT (%s)", Joiner.on(" OR ").join(value));
            }
            throw new IllegalStateException("");
        }).collect(Collectors.toList()).forEach(s -> {
            sb.append(" AND (").append(s).append(")");
        });
        return sb.append(")").toString();
    }

    private String getTermRangeClause(TermRangeQuery query) {
        StringBuilder sb = new StringBuilder("(TRUE");
        String fieldName = Sanitizer.sanitizeFieldName(query.getField());
        boolean fromJson = Strings.nullToEmpty(fieldName).contains("Json::GetField");
        if (query.getLowerTerm() != null) {
            sb.append(String.format(
                    " AND (%s >" + (query.includesLower() ? "=" : "") + " %s)",
                    Sanitizer.sanitizeFieldName(query.getField()),
                    getTextValueMayBeFromJson(query.getLowerTerm().utf8ToString(), fromJson)
            ));
        }
        if (query.getUpperTerm() != null) {
            sb.append(String.format(
                    " AND (%s <" + (query.includesUpper() ? "=" : "") + " %s)",
                    Sanitizer.sanitizeFieldName(query.getField()),
                    getTextValueMayBeFromJson(query.getUpperTerm().utf8ToString(), fromJson)
            ));
        }
        return sb.append(")").toString();
    }

    private String getTextValueMayBeFromJson(String value, boolean fromJson) {
        String quote = fromJson ? "'" : "";
        return quote + Sanitizer.sanitizeTextValue(value) + quote;
    }

    private String getPhraseClause(PhraseQuery query) {
        return String.format(
                "(%s)",
                Joiner.on(" OR ").join(Arrays.stream(query.getTerms()).map(this::getTermClause).iterator())
        );
    }

    private String getRegexpClause(RegexpQuery query) {
        return String.format(
                "(%s REGEXP '%s')",
                Sanitizer.sanitizeFieldName(query.getField()),
                Sanitizer.sanitizeTextValue(query.getRegexp().text())
        );
    }

    private String getTermClause(Term t) {
        String fieldName = Sanitizer.sanitizeFieldName(t.field());
        SchemaFieldType type = schema.getOrDefault(fieldName, SchemaFieldType.UTF8);
        switch (type) {
            case UTF8:
                return String.format(
                        "(%s " + (FIELDS_TO_SEARCH_BY_SUBSTRING.contains(fieldName) ? "ILIKE '%%%s%%" : "LIKE '%s") + "')",
                        fieldName, Sanitizer.sanitizeTextValue(t.text())
                );
            case INT64: case INT32: case FLOAT:
                return String.format("(%s = %s)", fieldName, t.text());
            default:
                throw new RuntimeException(String.format("Unknown field type %s for field %s", type, fieldName));
        }
    }

    private String getWildcardClause(Term t) {
        String baseClause = getTermClause(t);
        return baseClause.replaceAll("\\.\\*", "%").replaceAll("(?<!\\\\)\\*", "%");
    }

    private String getPrefixClause(Term prefix) {
        return String.format("(%s ILIKE '%s%%')", Sanitizer.sanitizeFieldName(prefix.field()), Sanitizer
                .sanitizeTextValue(prefix.text()));
    }
}
