package ru.yandex.solomon.expression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.ImmutableList;

import ru.yandex.solomon.expression.SelLexer.ScannedToken;
import ru.yandex.solomon.expression.SelLexer.Token;
import ru.yandex.solomon.expression.SelLexer.TokenType;
import ru.yandex.solomon.expression.analytics.ProgramWithReturn;
import ru.yandex.solomon.expression.ast.Ast;
import ru.yandex.solomon.expression.ast.AstAnonymous;
import ru.yandex.solomon.expression.ast.AstAssignment;
import ru.yandex.solomon.expression.ast.AstBinOp;
import ru.yandex.solomon.expression.ast.AstCall;
import ru.yandex.solomon.expression.ast.AstIdent;
import ru.yandex.solomon.expression.ast.AstInterpolatedString;
import ru.yandex.solomon.expression.ast.AstLambda;
import ru.yandex.solomon.expression.ast.AstObject;
import ru.yandex.solomon.expression.ast.AstOp;
import ru.yandex.solomon.expression.ast.AstSelector;
import ru.yandex.solomon.expression.ast.AstSelectors;
import ru.yandex.solomon.expression.ast.AstStatement;
import ru.yandex.solomon.expression.ast.AstTernaryOp;
import ru.yandex.solomon.expression.ast.AstUnaryOp;
import ru.yandex.solomon.expression.ast.AstUse;
import ru.yandex.solomon.expression.ast.AstValue;
import ru.yandex.solomon.expression.ast.AstValueDouble;
import ru.yandex.solomon.expression.ast.AstValueDuration;
import ru.yandex.solomon.expression.ast.AstValueString;
import ru.yandex.solomon.expression.exceptions.ParserException;
import ru.yandex.solomon.expression.parser.DurationParser;
import ru.yandex.solomon.labels.InterpolatedString;
import ru.yandex.solomon.labels.LabelValues;
import ru.yandex.solomon.labels.query.SelectorType;

import static ru.yandex.solomon.expression.SelLexer.TokenType.DOUBLE;
import static ru.yandex.solomon.expression.SelLexer.TokenType.DURATION;
import static ru.yandex.solomon.expression.SelLexer.TokenType.EOF;
import static ru.yandex.solomon.expression.SelLexer.TokenType.IDENT;
import static ru.yandex.solomon.expression.SelLexer.TokenType.IDENT_WITH_DOTS;
import static ru.yandex.solomon.expression.SelLexer.TokenType.KEYWORD;
import static ru.yandex.solomon.expression.SelLexer.TokenType.PUNCT;
import static ru.yandex.solomon.expression.SelLexer.TokenType.STRING;

/**
 * Recursive-descent parser that reads tokens from a {@link SelLexer} and builds AST nodes.
 *
 * <p>Binding strength, from loosest to tightest, follows the chain of {@code consume*} methods:
 * lambda, ternary {@code ?:}, {@code ||}, {@code &&}, prefix {@code !}, comparisons,
 * {@code + -}, {@code / *}, prefix {@code + -}, atom.
 *
 * @author Stepan Koltsov
 */
@ParametersAreNonnullByDefault
public class SelParser {
    // Grouping punctuation.
    private static final Token OPENING_BRACE = PUNCT.of("{");
    private static final Token CLOSING_BRACE = PUNCT.of("}");

    private static final Token OPENING_PAREN = PUNCT.of("(");
    private static final Token COMMA = PUNCT.of(",");
    private static final Token CLOSING_PAREN = PUNCT.of(")");

    private static final Token OPENING_BRACKET = PUNCT.of("[");
    private static final Token CLOSING_BRACKET = PUNCT.of("]");

    // Lambda arrow.
    private static final Token ARROW = PUNCT.of("->");

    // Arithmetic operators.
    private static final Token PLUS = PUNCT.of("+");
    private static final Token MINUS = PUNCT.of("-");

    private static final Token DIV = PUNCT.of("/");
    private static final Token MUL = PUNCT.of("*");

    // Logical operators.
    private static final Token NOT = PUNCT.of("!");

    private static final Token AND = PUNCT.of("&&");
    private static final Token OR = PUNCT.of("||");

    // Comparison operators.
    private static final Token LT = PUNCT.of("<");
    private static final Token GT = PUNCT.of(">");
    private static final Token LE = PUNCT.of("<=");
    private static final Token GE = PUNCT.of(">=");
    private static final Token EQ = PUNCT.of("==");
    private static final Token NE = PUNCT.of("!=");

    // Ternary operator parts; the colon also separates object keys from values.
    private static final Token QUESTION = PUNCT.of("?");
    private static final Token COLON = PUNCT.of(":");

    // Statement-level punctuation.
    private static final Token ASSIGNMENT = PUNCT.of("=");
    private static final Token SEMICOLON = PUNCT.of(";");

    // Selector value accepted in place of an identifier or string, e.g. host=-
    private static final Token LABEL_ABSENT = PUNCT.of(LabelValues.ABSENT);

    private static final Token RETURN = KEYWORD.of("return");
    private static final Token BY = KEYWORD.of("by");
    private static final Token LET = KEYWORD.of("let");

    // use is not a keyword since legacy code `let use = {...}` should still compile
    private static final Token USE = IDENT.of("use");

    // Token types accepted for selector keys/values and for `by` labels.
    private static final Set<TokenType> IDENT_OR_STRING = Set.of(IDENT, STRING);

    private final SelLexer lexer;
    private final boolean astRangesEnabled;

    private SelParser(SelLexer lexer, boolean astRangesEnabled) {
        this.lexer = lexer;
        this.astRangesEnabled = astRangesEnabled;
    }

    /**
     * Creates a parser over the given text with AST ranges enabled.
     *
     * @param string program or expression text to parse
     */
    public SelParser(String string) {
        this(string, true);
    }

    /**
     * Creates a parser over the given text.
     *
     * @param string program or expression text to parse
     * @param astRangesEnabled flag forwarded to every {@code getRange} call on scanned tokens
     *                         and to created {@link ParserException}s
     */
    public SelParser(String string, boolean astRangesEnabled) {
        this(new SelLexer(string), astRangesEnabled);
    }

    /**
     * Reads the next token and requires it to be of the given type.
     *
     * @param tokenType required token type
     * @return the token that was read
     * @throws ParserException if the token read does not match {@code tokenType}
     */
    public ScannedToken consume(TokenType tokenType) {
        ScannedToken actual = lexer.nextToken();
        if (actual.matches(tokenType)) {
            return actual;
        }
        throw error(actual, "Expected " + tokenType + " token but " + actual.shortString() + " found");
    }

    /** Reads the next token and requires its type to be a member of {@code tokenTypes}. */
    private ScannedToken consumeOneOf(Set<TokenType> tokenTypes) {
        ScannedToken actual = lexer.nextToken();
        if (tokenTypes.contains(actual.getTokenType())) {
            return actual;
        }
        throw error(actual, "Expected one of " + tokenTypes + " token but " + actual.shortString() + " found");
    }

    /**
     * Reads the next token and requires it to match the given token.
     *
     * @param expectedToken required token
     * @return the token that was read
     * @throws ParserException if the token read does not match {@code expectedToken}
     */
    public ScannedToken consume(Token expectedToken) {
        ScannedToken actual = lexer.nextToken();
        if (actual.matches(expectedToken)) {
            return actual;
        }
        throw error(actual, "Expected " + expectedToken + " token but " + actual.shortString() + " found");
    }

    /** Shorthand for the range of a scanned token under the parser's range setting. */
    private PositionRange rangeOf(ScannedToken token) {
        return token.getRange(astRangesEnabled);
    }

    private boolean lookaheadIs(Token token) {
        ScannedToken upcoming = lexer.lookaheadToken();
        return upcoming.matches(token);
    }

    private boolean lookaheadIs(TokenType tokenType) {
        ScannedToken upcoming = lexer.lookaheadToken();
        return upcoming.matches(tokenType);
    }

    /**
     * Takes the next token only when it matches {@code token}; otherwise leaves the lexer
     * untouched and returns a placeholder built from {@code SelScanner.FAILED}
     * (callers distinguish the two outcomes with {@code isPresent()}).
     */
    private ScannedToken consumeIf(Token token) {
        if (!lookaheadIs(token)) {
            return token.getTokenType().of(SelScanner.FAILED);
        }
        return lexer.nextToken();
    }

    private static boolean matchesAnyType(ScannedToken token, TokenType[] candidates) {
        for (TokenType candidate : candidates) {
            if (token.matches(candidate)) {
                return true;
            }
        }
        return false;
    }

    private static boolean matchesAnyToken(ScannedToken token, Token[] candidates) {
        for (Token candidate : candidates) {
            if (token.matches(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Like {@link #consumeIf(Token)}, but accepts any of the listed token types. */
    private ScannedToken consumeIfOneOf(TokenType... tokenTypes) {
        ScannedToken upcoming = lexer.lookaheadToken();
        if (matchesAnyType(upcoming, tokenTypes)) {
            return lexer.nextToken();
        }
        return upcoming.getTokenType().of(SelScanner.FAILED);
    }

    /** Like {@link #consumeIf(Token)}, but accepts any of the listed tokens. */
    private ScannedToken consumeIfOneOf(Token... tokens) {
        ScannedToken upcoming = lexer.lookaheadToken();
        if (matchesAnyToken(upcoming, tokens)) {
            return lexer.nextToken();
        }
        return upcoming.getTokenType().of(SelScanner.FAILED);
    }

    /**
     * Probes a copy of the lexer to decide whether a lambda starts here.
     *
     * <p>Recognized shapes:
     * <pre>
     * (a) -> a
     * a -> a
     * (a, b) -> a || b
     * </pre>
     */
    private boolean lookaheadIsLambda() {
        SelLexer probe = lexer.copy();

        ScannedToken first = probe.nextToken();

        // a -> a
        if (first.matches(IDENT)) {
            return probe.nextToken().matches(ARROW);
        }

        // (a,...,aN) -> a
        if (!first.matches(OPENING_PAREN)) {
            return false;
        }

        ScannedToken separator;
        do {
            if (!probe.nextToken().matches(IDENT)) {
                return false;
            }
            separator = probe.nextToken();
        } while (separator.matches(COMMA));

        // the parameter list must be closed and immediately followed by the arrow
        return separator.matches(CLOSING_PAREN) && probe.nextToken().matches(ARROW);
    }

    /**
     * Probes a copy of the lexer to decide whether selectors start here: an optional
     * name (identifier, dotted identifier or string) followed by an opening brace.
     *
     * <p>Recognized shapes:
     * <pre>
     * response_time{}
     * 'response_time'{}
     * "response_time"{}
     * response_time{bin='*', user='test'}
     * response_time{project='yt', bin='*', user='test'}
     * {sensor='response_time', project='yt', bin='*', user='test'}
     * {"sensor"="response_time", "project"="yt", "bin"="*", "user"="test"}
     * </pre>
     */
    private boolean lookaheadIsSelectors() {
        SelLexer probe = lexer.copy();

        ScannedToken candidate = probe.nextToken();

        boolean hasName = candidate.matches(IDENT)
                || candidate.matches(IDENT_WITH_DOTS)
                || candidate.matches(STRING);
        if (hasName) {
            candidate = probe.nextToken();
        }

        return candidate.matches(OPENING_BRACE);
    }

    /**
     * Probes a copy of the lexer to decide whether an object literal starts here:
     * either {@code {}} or an opening brace followed by {@code identifier :}.
     */
    private boolean lookaheadIsObject() {
        SelLexer probe = lexer.copy();

        if (!probe.nextToken().matches(OPENING_BRACE)) {
            return false;
        }

        ScannedToken afterBrace = probe.nextToken();
        if (afterBrace.matches(CLOSING_BRACE)) {
            return true;
        }

        return afterBrace.matches(IDENT) && probe.nextToken().matches(COLON);
    }

    /**
     * Converts the text of a DOUBLE token to a number. When the last character is alphabetic
     * it is looked up through {@code MetricPrefix.byLetter} and the numeric part is scaled
     * by the multiplier of that prefix.
     */
    private static double parseSelDouble(String s) {
        int lastIndex = s.length() - 1;
        if (!Character.isAlphabetic(s.charAt(lastIndex))) {
            return Double.parseDouble(s);
        }
        long multiplier = MetricPrefix.byLetter(s.substring(lastIndex)).getMultiplier();
        String digits = s.substring(0, lastIndex);
        return Double.parseDouble(digits) * multiplier;
    }

    /**
     * Parses a primary expression. The probes run in a fixed order: vector, object,
     * selectors, duration, double, string, lambda; anything else is handled as an
     * identifier, call or aggregation.
     */
    private Ast consumeIdentOrCallOrNumberOrLambdaOrDuration() {
        if (lookaheadIs(OPENING_BRACKET)) {
            return consumeVector();
        } else if (lookaheadIsObject()) {
            return consumeObject();
        } else if (lookaheadIsSelectors()) {
            return consumeSelectors();
        } else if (lookaheadIs(DURATION)) {
            return consumeDuration();
        } else if (lookaheadIs(DOUBLE)) {
            return consumeDouble();
        } else if (lookaheadIs(STRING)) {
            return consumeString();
        } else if (lookaheadIsLambda()) {
            return consumeLambda();
        } else {
            return consumeIdentOrCallOrAggregation();
        }
    }

    private AstValueDuration consumeDuration() {
        ScannedToken durationToken = consume(DURATION);
        PositionRange range = rangeOf(durationToken);
        return new AstValueDuration(range, DurationParser.parseDuration(durationToken.getValue()));
    }

    private AstValueDouble consumeDouble() {
        ScannedToken numberToken = consume(DOUBLE);
        PositionRange range = rangeOf(numberToken);
        return new AstValueDouble(range, parseSelDouble(numberToken.getValue()));
    }

    private AstValue consumeString() {
        ScannedToken stringToken = consume(STRING);
        return parseLiteral(stringToken);
    }

    /** Builds a string-like AST value from the value and range of a scanned token. */
    private AstValue parseLiteral(ScannedToken token) {
        String text = token.getValue();
        return parseLiteral(rangeOf(token), text);
    }

    /**
     * Builds an {@link AstInterpolatedString} when {@link InterpolatedString} reports the
     * literal as interpolated, and a plain {@link AstValueString} otherwise.
     */
    private AstValue parseLiteral(PositionRange range, String literal) {
        if (!InterpolatedString.isInterpolatedString(literal)) {
            return new AstValueString(range, literal);
        }
        return new AstInterpolatedString(range, InterpolatedString.parse(literal));
    }

    /** Parses {@code a -> body} or {@code (a, ..., aN) -> body}. */
    private AstLambda consumeLambda() {
        ScannedToken opening = consumeIf(OPENING_PAREN);

        if (!opening.isPresent()) {
            // a -> a
            AstIdent param = consumeIdent();
            consume(ARROW);
            Ast singleParamBody = consumeExpr();
            PositionRange range = PositionRange.convexHull(param.getRange(), singleParamBody.getRange());
            return new AstLambda(range, Collections.singletonList(param.getIdent()), singleParamBody);
        }

        // (a,...,aN) -> a
        List<String> params = new ArrayList<>(2);
        params.add(consumeIdent().getIdent());
        while (consumeIf(COMMA).isPresent()) {
            params.add(consumeIdent().getIdent());
        }
        consume(CLOSING_PAREN);
        consume(ARROW);
        Ast body = consumeExpr();
        PositionRange range = PositionRange.convexHull(rangeOf(opening), body.getRange());
        return new AstLambda(range, params, body);
    }

    /**
     * Parses one or more comma-separated {@code key <operator> value} selectors and appends
     * them to {@code labelSelectors}. The value may be the label-absent token instead of
     * an identifier or string.
     */
    private void consumeLabelSelectorList(List<AstSelector> labelSelectors) {
        do {
            ScannedToken keyToken = consumeOneOf(IDENT_OR_STRING);
            ScannedToken operatorToken = consume(PUNCT);
            SelectorType type = SelectorType.forOperator(operatorToken.getValue());

            // host=-
            ScannedToken valueToken = consumeIf(LABEL_ABSENT);
            if (!valueToken.isPresent()) {
                valueToken = consumeOneOf(IDENT_OR_STRING);
            }

            // literals are built only after the whole selector has been read
            AstValue key = parseLiteral(keyToken);
            AstValue value = parseLiteral(valueToken);
            PositionRange range = PositionRange.convexHull(key.getRange(), value.getRange());
            labelSelectors.add(new AstSelector(range, key, value, type));
        } while (consumeIf(COMMA).isPresent());
    }

    /**
     * Parses a possibly empty comma-separated list of expressions. The {@code closing}
     * token is only peeked at to detect the empty list; the caller consumes it.
     */
    private List<Ast> consumeSequenceUntil(Token closing) {
        List<Ast> items = new ArrayList<>();
        if (lookaheadIs(closing)) {
            return items;
        }
        do {
            items.add(consumeExpr());
        } while (consumeIf(COMMA).isPresent());
        return items;
    }

    /** Parses {@code [e1, ..., eN]} into a call of {@code as_vector} with the elements as arguments. */
    private Ast consumeVector() {
        ScannedToken opening = consume(OPENING_BRACKET);
        List<Ast> elements = consumeSequenceUntil(CLOSING_BRACKET);
        ScannedToken closing = consume(CLOSING_BRACKET);

        AstIdent func = new AstIdent(rangeOf(opening), "as_vector");
        PositionRange range = PositionRange.convexHull(rangeOf(opening), rangeOf(closing));
        return new AstCall(range, func, elements);
    }

    /** Parses selectors with an optional leading name; a missing name is stored as an empty string. */
    private AstSelectors consumeSelectors() {
        List<AstSelector> labelSelectors = new ArrayList<>(3);
        ScannedToken nameToken = consumeIfOneOf(STRING, IDENT_WITH_DOTS, IDENT);

        ScannedToken opening = consume(OPENING_BRACE);
        if (!lookaheadIs(CLOSING_BRACE)) {
            consumeLabelSelectorList(labelSelectors);
        }
        ScannedToken closing = consume(CLOSING_BRACE);

        // the range starts at the name when there is one, at the opening brace otherwise
        boolean named = nameToken.isPresent();
        ScannedToken first = named ? nameToken : opening;
        String name = named ? nameToken.getValue() : "";
        PositionRange pos = PositionRange.convexHull(rangeOf(first), rangeOf(closing));
        return new AstSelectors(pos, name, labelSelectors);
    }

    /**
     * Parses {@code {key: expr, ...}}.
     *
     * @throws ParserException if a key occurs more than once
     */
    private AstObject consumeObject() {
        Map<String, Ast> fields = new HashMap<>(3);

        ScannedToken opening = consume(OPENING_BRACE);
        boolean moreFields = !lookaheadIs(CLOSING_BRACE);
        while (moreFields) {
            ScannedToken keyToken = consume(IDENT);
            String key = keyToken.getValue();

            if (fields.containsKey(key)) {
                throw error(keyToken, "Duplicate key " + keyToken + " in the object");
            }

            consume(COLON);
            fields.put(key, consumeExpr());

            moreFields = consumeIf(COMMA).isPresent();
        }
        ScannedToken closing = consume(CLOSING_BRACE);

        PositionRange range = PositionRange.convexHull(rangeOf(opening), rangeOf(closing));
        return new AstObject(range, fields);
    }

    /**
     * Parses a bare identifier, a call {@code f(args)}, or a call followed by the
     * {@code by} keyword, which is handed over to the aggregation parsing.
     */
    private Ast consumeIdentOrCallOrAggregation() {
        AstIdent ident = consumeIdent();
        if (!consumeIf(OPENING_PAREN).isPresent()) {
            return ident;
        }

        List<Ast> callArgs = consumeSequenceUntil(CLOSING_PAREN);
        ScannedToken closing = consume(CLOSING_PAREN);
        PositionRange callRange = PositionRange.convexHull(ident.getRange(), rangeOf(closing));
        AstCall call = new AstCall(callRange, ident, callArgs);

        ScannedToken by = consumeIf(BY);
        return by.isPresent() ? consumeTimeOrLabelAggregation(by, call) : call;
    }

    /**
     * Dispatches {@code aggr(...) by ...}: a duration after {@code by} means grouping by time,
     * anything else is parsed as grouping by label.
     *
     * @throws ParserException if the call does not have exactly one argument
     */
    private Ast consumeTimeOrLabelAggregation(ScannedToken by, AstCall aggr) {
        if (aggr.args.size() != 1) {
            throw error(by, "`by` keyword available only for aggregation function " +
                    "with single argument, applied to `" + aggr.getFunc().getIdent() + "`");
        }

        return lookaheadIs(DURATION)
                ? consumeGroupByTime(by, aggr)
                : consumeGroupByLabel(by, aggr);
    }

    /** Turns an identifier into a string value with the same range and text. */
    private static AstValueString identToStr(AstIdent ident) {
        return new AstValueString(ident.getRange(), ident.getIdent());
    }

    /** Rewrites {@code aggr(x) by <duration>} as {@code group_by_time(<duration>, "aggr", x)}. */
    private Ast consumeGroupByTime(ScannedToken by, AstCall aggr) {
        Ast duration = consumeDuration();

        AstIdent groupByTime = new AstIdent(rangeOf(by), "group_by_time");
        List<Ast> args = ImmutableList.<Ast>builder()
                .add(duration)
                .add(identToStr(aggr.func))
                .addAll(aggr.args)
                .build();

        PositionRange range = PositionRange.convexHull(aggr.getRange(), duration.getRange());
        return new AstCall(range, groupByTime, args);
    }

    /**
     * Rewrites {@code aggr(x) by label} or {@code aggr(x) by (l1, ..., lN)} as
     * {@code group_lines("aggr", labels, x)}. A parenthesized list with a single label is
     * passed as that label; any other size is wrapped into an {@code as_vector} call.
     */
    private Ast consumeGroupByLabel(ScannedToken by, AstCall aggr) {
        ScannedToken opening = consumeIf(OPENING_PAREN);

        Ast argLabels;
        if (!opening.isPresent()) {
            // by a
            argLabels = parseLiteral(consumeOneOf(IDENT_OR_STRING));
        } else {
            // by (a,...,aN)
            List<Ast> labels = new ArrayList<>(2);
            labels.add(parseLiteral(consumeOneOf(IDENT_OR_STRING)));
            while (consumeIf(COMMA).isPresent()) {
                labels.add(parseLiteral(consumeOneOf(IDENT_OR_STRING)));
            }
            ScannedToken closing = consume(CLOSING_PAREN);

            if (labels.size() != 1) {
                AstIdent asVector = new AstIdent(rangeOf(opening), "as_vector");
                PositionRange vectorRange = PositionRange.convexHull(rangeOf(opening), rangeOf(closing));
                argLabels = new AstCall(vectorRange, asVector, labels);
            } else {
                argLabels = labels.get(0);
            }
        }

        AstIdent groupLines = new AstIdent(rangeOf(by), "group_lines");
        Ast argMetrics = aggr.args.get(0);
        PositionRange range = PositionRange.convexHull(aggr.getRange(), argLabels.getRange());
        return new AstCall(range, groupLines, ImmutableList.of(identToStr(aggr.func), argLabels, argMetrics));
    }

    private AstIdent consumeIdent() {
        ScannedToken identToken = consume(IDENT);
        return new AstIdent(rangeOf(identToken), identToken.getValue());
    }

    /**
     * Parses a parenthesized expression or a primary expression. For the parenthesized
     * form the inner node is returned with its range widened to include the parentheses.
     */
    private Ast consumeAtom() {
        ScannedToken opening = consumeIf(OPENING_PAREN);
        if (!opening.isPresent()) {
            return consumeIdentOrCallOrNumberOrLambdaOrDuration();
        }

        Ast inner = consumeExpr();
        ScannedToken closing = consume(CLOSING_PAREN);
        return inner.withRange(PositionRange.convexHull(rangeOf(opening), rangeOf(closing)));
    }

    /** Atom with an optional {@code +} or {@code -} prefix. */
    private Ast consumeUnaryOp() {
        return consumeUnaryOp(this::consumeAtom, PLUS, MINUS);
    }

    /** Supplier of AST nodes; lets grammar methods be passed around as method references. */
    private interface AstSupplier extends Supplier<Ast> {
    }

    /**
     * Parses {@code operand (op operand)*} and folds the operands into left-nested
     * {@link AstBinOp} nodes.
     *
     * @param down parser of a single operand (the next tighter grammar level)
     * @param ops operators accepted at this level
     */
    private Ast consumeBinOp(AstSupplier down, Token... ops) {
        Ast left = down.get();
        for (ScannedToken opToken = consumeIfOneOf(ops); opToken.isPresent(); opToken = consumeIfOneOf(ops)) {
            Ast right = down.get();
            PositionRange range = PositionRange.convexHull(left.getRange(), right.getRange());
            AstOp op = new AstOp(rangeOf(opToken), opToken.getValue());
            left = new AstBinOp(range, left, right, op);
        }
        return left;
    }

    /**
     * Parses an operand preceded by at most one of the given prefix operators.
     *
     * @param down parser of the operand
     * @param ops prefix operators accepted at this level
     */
    private Ast consumeUnaryOp(AstSupplier down, Token... ops) {
        ScannedToken prefix = consumeIfOneOf(ops);
        if (!prefix.isPresent()) {
            return down.get();
        }

        Ast operand = down.get();
        AstOp op = new AstOp(rangeOf(prefix), prefix.getValue());
        return new AstUnaryOp(PositionRange.convexHull(op.getRange(), operand.getRange()), operand, op);
    }

    /** Level of {@code /} and {@code *}. */
    private Ast consumeTerm() {
        return consumeBinOp(this::consumeUnaryOp, DIV, MUL);
    }

    /** Level of binary {@code +} and {@code -}. */
    private Ast consumeArith() {
        return consumeBinOp(this::consumeTerm, PLUS, MINUS);
    }

    /** Level of the comparison operators. */
    private Ast consumeComparison() {
        return consumeBinOp(this::consumeArith, LT, GT, LE, GE, EQ, NE);
    }

    /** Comparison with an optional {@code !} prefix. */
    private Ast consumeNot() {
        return consumeUnaryOp(this::consumeComparison, NOT);
    }

    /** Level of {@code &&}. */
    private Ast consumeAnd() {
        return consumeBinOp(this::consumeNot, AND);
    }

    /** Level of {@code ||}. */
    private Ast consumeOr() {
        return consumeBinOp(this::consumeAnd, OR);
    }

    /**
     * Parses {@code cond} or {@code cond ? ifTrue : ifFalse}; all three parts are read
     * with the same {@code down} parser.
     */
    private Ast consumeTernaryOp(AstSupplier down) {
        Ast condition = down.get();
        if (!consumeIf(QUESTION).isPresent()) {
            return condition;
        }

        Ast ifTrue = down.get();
        consume(COLON);
        Ast ifFalse = down.get();
        PositionRange range = PositionRange.convexHull(condition.getRange(), ifFalse.getRange());
        return new AstTernaryOp(range, condition, ifTrue, ifFalse);
    }

    /** Parses a full expression: a lambda when one starts here, a ternary-level expression otherwise. */
    private Ast consumeExpr() {
        return lookaheadIsLambda()
                ? consumeLambda()
                : consumeTernaryOp(this::consumeOr);
    }

    /** Parses {@code let name = expr;}. */
    private AstAssignment consumeAssignment() {
        ScannedToken letToken = consume(LET);
        String varName = consume(IDENT).getValue();
        consume(ASSIGNMENT);
        Ast expr = consumeExpr();
        ScannedToken semicolon = consume(SEMICOLON);

        PositionRange range = PositionRange.convexHull(rangeOf(letToken), rangeOf(semicolon));
        return new AstAssignment(range, varName, expr);
    }

    /** Parses {@code expr;} as a statement that is not bound to a name. */
    private AstAnonymous consumeAnonymous() {
        Ast expr = consumeExpr();
        ScannedToken semicolon = consume(SEMICOLON);
        PositionRange range = PositionRange.convexHull(expr.getRange(), rangeOf(semicolon));
        return new AstAnonymous(range, expr);
    }

    /** Parses {@code use {selectors};}; the selector list must contain at least one selector. */
    private AstUse consumeUse() {
        ScannedToken useToken = consume(USE);

        List<AstSelector> labelSelectors = new ArrayList<>(3);
        consume(OPENING_BRACE);
        consumeLabelSelectorList(labelSelectors);
        consume(CLOSING_BRACE);

        ScannedToken semicolon = consume(SEMICOLON);

        PositionRange range = PositionRange.convexHull(rangeOf(useToken), rangeOf(semicolon));
        return new AstUse(range, labelSelectors);
    }

    /**
     * Parses a {@code let} or {@code use} statement when {@code lookahead} starts one and
     * appends it to {@code statements}.
     *
     * @return {@code true} if a statement was parsed, {@code false} if nothing was consumed
     * @throws ParserException if a {@code use} statement follows any other statement
     */
    private boolean consumeDeclarationIfAny(ScannedToken lookahead, List<AstStatement> statements) {
        if (lookahead.matches(LET)) {
            statements.add(consumeAssignment());
            return true;
        }
        if (lookahead.matches(USE)) {
            if (!statements.isEmpty()) {
                throw error(lookahead, "use statement can only be placed at the beginning of the program");
            }
            statements.add(consumeUse());
            return true;
        }
        return false;
    }

    /**
     * Parses statements until the end of input. Each statement is a {@code let} assignment,
     * a {@code use} statement (only allowed as the very first statement) or an expression
     * terminated by a semicolon.
     *
     * @return the parsed statements in source order
     * @throws ParserException on a syntax error
     */
    public List<AstStatement> parseBlock() {
        List<AstStatement> statements = new ArrayList<>();

        ScannedToken next = lexer.lookaheadToken();
        while (!next.matches(EOF)) {
            if (!consumeDeclarationIfAny(next, statements)) {
                statements.add(consumeAnonymous());
            }
            next = lexer.lookaheadToken();
        }

        return statements;
    }

    /**
     * Parses a single expression that must be followed by the end of input.
     *
     * @return the parsed expression
     * @throws ParserException on a syntax error or trailing tokens
     */
    public Ast parseExpr() {
        Ast expr = consumeExpr();
        consume(EOF);
        return expr;
    }

    /**
     * Parses leading {@code let}/{@code use} statements followed by the resulting expression,
     * written either as {@code return <expression>;} or as a bare {@code <expression>}.
     *
     * @return the statements together with the resulting expression
     * @throws ParserException on a syntax error, or when the input ends before the
     *                         resulting expression
     */
    public ProgramWithReturn parseProgramWithReturn() {
        List<AstStatement> statements = new ArrayList<>();

        boolean consumedDeclaration;
        do {
            consumedDeclaration = consumeDeclarationIfAny(lexer.lookaheadToken(), statements);
        } while (consumedDeclaration);

        if (consumeIf(RETURN).isPresent()) {
            Ast returned = consumeExpr();
            consume(SEMICOLON);
            consume(EOF);
            return new ProgramWithReturn(statements, returned);
        }

        if (lookaheadIs(EOF)) {
            ScannedToken eof = consume(EOF);
            throw error(eof, "Expecting `return <expression>;` or just `<expression>` at the end of the program");
        }

        return new ProgramWithReturn(statements, parseExpr());
    }

    /** Creates (without throwing) a parser exception located at the given token. */
    private ParserException error(ScannedToken parsed, String msg) {
        return new ParserException(astRangesEnabled, parsed, msg);
    }

    @Override
    public String toString() {
        return "SelParser{lexer=" + lexer + '}';
    }
}
