package ru.yandex.solomon.expression.compile;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

import ru.yandex.solomon.expression.SelParser;
import ru.yandex.solomon.expression.ast.Ast;
import ru.yandex.solomon.expression.ast.AstAnonymous;
import ru.yandex.solomon.expression.ast.AstAssignment;
import ru.yandex.solomon.expression.ast.AstBinOp;
import ru.yandex.solomon.expression.ast.AstCall;
import ru.yandex.solomon.expression.ast.AstIdent;
import ru.yandex.solomon.expression.ast.AstInterpolatedString;
import ru.yandex.solomon.expression.ast.AstLambda;
import ru.yandex.solomon.expression.ast.AstObject;
import ru.yandex.solomon.expression.ast.AstSelector;
import ru.yandex.solomon.expression.ast.AstSelectors;
import ru.yandex.solomon.expression.ast.AstStatement;
import ru.yandex.solomon.expression.ast.AstTernaryOp;
import ru.yandex.solomon.expression.ast.AstUnaryOp;
import ru.yandex.solomon.expression.ast.AstUse;
import ru.yandex.solomon.expression.ast.AstValue;
import ru.yandex.solomon.expression.ast.AstValueDouble;
import ru.yandex.solomon.expression.ast.AstValueDuration;
import ru.yandex.solomon.expression.ast.AstValueInterval;
import ru.yandex.solomon.expression.ast.AstValueString;
import ru.yandex.solomon.expression.ast.RecursiveAstVisitor;
import ru.yandex.solomon.expression.exceptions.CompilerException;
import ru.yandex.solomon.expression.exceptions.InternalCompilerException;
import ru.yandex.solomon.expression.expr.EvalContextImpl;
import ru.yandex.solomon.expression.expr.SelExpr;
import ru.yandex.solomon.expression.expr.SelExprFuncCall;
import ru.yandex.solomon.expression.expr.SelExprInterpolatedString;
import ru.yandex.solomon.expression.expr.SelExprObject;
import ru.yandex.solomon.expression.expr.SelExprOpCall;
import ru.yandex.solomon.expression.expr.SelExprParam;
import ru.yandex.solomon.expression.expr.SelExprSelector;
import ru.yandex.solomon.expression.expr.SelExprSelectors;
import ru.yandex.solomon.expression.expr.SelExprTern;
import ru.yandex.solomon.expression.expr.SelExprValue;
import ru.yandex.solomon.expression.expr.SelFunctions;
import ru.yandex.solomon.expression.expr.SelOperators;
import ru.yandex.solomon.expression.type.SelType;
import ru.yandex.solomon.expression.type.SelTypes;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.value.SelValueBoolean;
import ru.yandex.solomon.expression.value.SelValueDouble;
import ru.yandex.solomon.expression.value.SelValueDuration;
import ru.yandex.solomon.expression.value.SelValueInterval;
import ru.yandex.solomon.expression.value.SelValueLambda;
import ru.yandex.solomon.expression.value.SelValueString;
import ru.yandex.solomon.expression.version.SelVersion;

/**
 * Compiles SEL abstract syntax trees ({@link Ast}) into executable expression trees ({@link SelExpr}).
 *
 * <p>The compiler is bound to a single {@link SelVersion}, which is handed to the function and
 * operator registries ({@link SelFunctions}, {@link SelOperators}) when resolving calls. Each
 * {@code visit*} override handles one AST node kind; {@link #compileExpr(Ast, CompileContext)} is the
 * entry point for a single expression and {@link #compileBlock(List, CompileContext)} for a sequence
 * of statements.
 *
 * <p>Errors in the user's program are reported as {@link CompilerException} carrying the source range
 * of the offending node.
 *
 * @author Stepan Koltsov
 */
@ParametersAreNonnullByDefault
public class SelCompiler extends RecursiveAstVisitor<CompileContext, SelExpr> {
    private final SelVersion version;

    /**
     * @param version language version used when resolving functions and operators
     */
    public SelCompiler(SelVersion version) {
        this.version = version;
    }

    /** Returns the language version this compiler was created for. */
    public SelVersion getVersion() {
        return version;
    }

    /**
     * Compiles a single expression node.
     *
     * @throws CompilerException if no visitor produced a result for {@code ast}; errors in nested
     *     nodes propagate from the individual {@code visit*} methods
     */
    public SelExpr compileExpr(Ast ast, CompileContext context) {
        @Nullable SelExpr result = visit(ast, context);
        if (result == null) {
            throw new CompilerException(ast.getRange(), "unknown ast: " + ast);
        }
        return result;
    }

    /** Parses {@code program} as a single expression and compiles it (test convenience). */
    @VisibleForTesting
    public SelExpr compileExpr(String program, CompileContext context) {
        Ast ast = new SelParser(program).parseExpr();
        return compileExpr(ast, context);
    }

    /**
     * Compiles a list of statements in order.
     *
     * <p>Right after each statement is compiled, its {@code changeCompileContext} is invoked with the
     * shared {@code context}, so whatever a statement contributes to the context (presumably assigned
     * variables or {@code use} selectors) is visible while compiling the statements after it.
     *
     * @throws InternalCompilerException for a statement kind this compiler does not handle
     */
    public List<SelStatement> compileBlock(List<AstStatement> stmts, CompileContext context) {
        List<SelStatement> compiledStatements = new ArrayList<>(stmts.size());
        for (AstStatement stmt : stmts) {
            SelStatement compiled = compileStatement(stmt, context);
            compiledStatements.add(compiled);
            compiled.changeCompileContext(context);
        }
        return compiledStatements;
    }

    /** Compiles one statement; does not modify {@code context} itself. */
    private SelStatement compileStatement(AstStatement stmt, CompileContext context) {
        if (stmt instanceof AstAssignment assignment) {
            return new SelAssignment(assignment.getRange(), assignment.getIdent(), compileExpr(assignment.getExpr(), context));
        }
        if (stmt instanceof AstAnonymous anon) {
            return new SelAnonymous(anon.getRange(), compileExpr(anon.getExpr(), context));
        }
        if (stmt instanceof AstUse use) {
            return new SelUse(use.getRange(), compileSelectors(use.getLabelSelectors(), context));
        }
        throw new InternalCompilerException("Don't know how to compile " + stmt);
    }

    /**
     * Wraps a literal (double, string, duration or interval) into a constant expression.
     *
     * @throws CompilerException for any other {@link AstValue} subtype
     */
    @Override
    protected SelExprValue visitValue(AstValue astValue, CompileContext context) {
        SelValue value;
        if (astValue instanceof AstValueDouble doubleValue) {
            value = new SelValueDouble(doubleValue.getValue());
        } else if (astValue instanceof AstValueString stringValue) {
            value = new SelValueString(stringValue.getValue());
        } else if (astValue instanceof AstValueDuration durationValue) {
            value = new SelValueDuration(durationValue.getValue());
        } else if (astValue instanceof AstValueInterval intervalValue) {
            value = new SelValueInterval(intervalValue.getInterval());
        } else {
            throw new CompilerException(astValue.getRange(), "Unknown ast value: " + astValue);
        }
        return new SelExprValue(astValue, value);
    }

    /** Compiles the key and value expression of every selector, preserving their order. */
    private List<SelExprSelector> compileSelectors(List<AstSelector> selectors, CompileContext context) {
        List<SelExprSelector> compiledSelectors = new ArrayList<>(selectors.size());
        for (AstSelector selector : selectors) {
            SelExpr key = compileExpr(selector.getKey(), context);
            SelExpr value = compileExpr(selector.getValue(), context);
            compiledSelectors.add(new SelExprSelector(selector, key, value, selector.getType()));
        }
        return compiledSelectors;
    }

    /** Compiles a selectors literal; the result is typed as {@link SelTypes#GRAPH_DATA_VECTOR}. */
    @Override
    protected SelExprSelectors visitSelectors(AstSelectors selectors, CompileContext context) {
        String nameSelector = selectors.getNameSelector();
        List<SelExprSelector> compiledSelectors = compileSelectors(selectors.getSelectors(), context);
        return new SelExprSelectors(selectors, SelTypes.GRAPH_DATA_VECTOR, nameSelector, compiledSelectors);
    }

    /** Compiles every field value of an object literal, keeping the original keys. */
    @Override
    protected SelExprObject visitObject(AstObject object, CompileContext context) {
        Map<String, Ast> astByKey = object.getObject();

        Map<String, SelExpr> result = new HashMap<>(astByKey.size());
        for (Map.Entry<String, Ast> entry : astByKey.entrySet()) {
            result.put(entry.getKey(), compileExpr(entry.getValue(), context));
        }

        return new SelExprObject(object, result);
    }

    /**
     * Compiles a function call.
     *
     * <p>The function must exist in {@link SelFunctions#REGISTRY} for this compiler's version.
     * {@code load1}/{@code load} are rejected when the context's deprecation options say so.
     * {@code map} and {@code group_by_labels} take a lambda argument and are compiled specially;
     * every other call compiles its arguments eagerly and delegates to {@link SelFunctions}.
     *
     * @throws CompilerException for deprecated functions and malformed special-form calls
     */
    @Override
    protected SelExprFuncCall visitCall(AstCall call, CompileContext context) {
        String name = call.getFunc().getIdent();
        SelFunctions.REGISTRY.ensureHasFunction(version, call.getFunc());
        if (name.equals("load1") && context.getDeprOpts().isDropLoad1()) {
            throw new CompilerException(call.getRange(), "function \"load1\" is deprecated, use new syntax: https://nda.ya.ru/3UW4MW");
        }

        if (name.equals("load") && context.getDeprOpts().isDropLoad()) {
            throw new CompilerException(call.getRange(), "function \"load\" is deprecated, use new syntax: https://nda.ya.ru/3UW4MW");
        }

        if (name.equals("map")) {
            return compileMap(call, context);
        }

        if (name.equals("group_by_labels")) {
            return compileGroupByLabels(call, context);
        }

        List<SelExpr> params = new ArrayList<>(call.args.size());
        for (Ast ast : call.args) {
            params.add(compileExpr(ast, context));
        }
        return SelFunctions.exprCall(version, call, params);
    }

    /**
     * Compiles {@code map(vector, lambda)}. The lambda gets exactly one parameter, typed as the element
     * type of the first argument.
     *
     * @throws CompilerException on a wrong argument count, a non-vector first argument, or a malformed
     *     lambda
     */
    private SelExprFuncCall compileMap(AstCall call, CompileContext context) {
        if (call.args.size() != 2) {
            // Reported as CompilerException (with source range), like compileGroupByLabels, so that
            // errors in the user's program are surfaced uniformly instead of as IllegalArgumentException.
            throw new CompilerException(call.getRange(), "wrong number of params for map, expected 2, got: " + call.args);
        }

        Ast vectorArg = call.args.get(0);
        SelExpr vector = compileExpr(vectorArg, context);
        if (!vector.type().isVector()) {
            throw new CompilerException(vectorArg.getRange(), "map p0 must be vector, got: " + vector);
        }

        SelType elementType = vector.type().vector().elementType;
        AstLambda lambda = call.args.get(1).lambda();
        SelValueLambda compiledLambda = compileLambda(lambda, List.of(elementType));

        List<SelExpr> list = List.of(vector, new SelExprValue(lambda, compiledLambda));
        return SelFunctions.exprCall(version, call, list);
    }

    /**
     * Compiles {@code group_by_labels(dataSet, labels, lambda)}. The lambda gets exactly one
     * parameter typed as {@link SelTypes#GRAPH_DATA_VECTOR}.
     *
     * @throws CompilerException on a wrong argument count or a malformed lambda
     */
    private SelExprFuncCall compileGroupByLabels(AstCall call, CompileContext context) {
        if (call.args.size() != 3) {
            throw new CompilerException(call.getRange(), "wrong number of params for group_by_labels, expected 3, got: " + call.args);
        }

        SelExpr dataSet = compileExpr(call.args.get(0), context);
        SelExpr labels = compileExpr(call.args.get(1), context);
        AstLambda lambda = call.args.get(2).lambda();
        SelValueLambda lambdaParam = compileLambda(lambda, ImmutableList.of(SelTypes.GRAPH_DATA_VECTOR));
        return SelFunctions.exprCall(version, call, List.of(dataSet, labels, new SelExprValue(lambda, lambdaParam)));
    }

    /**
     * Compiles a lambda body with its parameters bound to {@code parameterTypes} (positionally).
     *
     * @throws CompilerException if the parameter count does not match {@code parameterTypes} or if
     *     parameter names are not unique
     */
    private SelValueLambda compileLambda(AstLambda lambda, List<SelType> parameterTypes) {
        if (lambda.paramNames.size() != parameterTypes.size()) {
            throw new CompilerException(lambda.getRange(), "different lambda param name sizes: " + lambda.paramNames.size() + "!=" + parameterTypes.size());
        }

        HashMap<String, SelType> params = new HashMap<>();
        for (int i = 0; i < lambda.paramNames.size(); ++i) {
            params.put(lambda.paramNames.get(i), parameterTypes.get(i));
        }

        // A repeated name collapses into a single map entry, so a size mismatch means duplicates.
        if (params.size() != lambda.paramNames.size()) {
            throw new CompilerException(lambda.getRange(), "non-unique param names: " + lambda.paramNames);
        }

        // NOTE(review): the body is compiled in a fresh CompileContext built only from the lambda's own
        // parameters, so outer variables and the caller's deprecation options are not consulted inside
        // the lambda — confirm this is intended.
        SelExpr body = compileExpr(lambda.body, new CompileContext(params));
        return new SelValueLambda(parameterTypes, lambda.paramNames, body);
    }

    /**
     * Compiles an identifier: {@code true}/{@code false} become boolean constants, anything else a
     * parameter reference typed via {@link CompileContext#getType}.
     */
    @Override
    protected SelExpr visitIdent(AstIdent ident, CompileContext context) {
        if (ident.getIdent().equals("true")) {
            return new SelExprValue(ident, SelValueBoolean.TRUE);
        } else if (ident.getIdent().equals("false")) {
            return new SelExprValue(ident, SelValueBoolean.FALSE);
        }

        return new SelExprParam(ident, context.getType(ident.getIdent()));
    }

    /** Builds an interpolated-string expression directly from the AST; no sub-expressions are compiled here. */
    @Override
    protected SelExprInterpolatedString visitInterpolatedString(AstInterpolatedString ast, CompileContext context) {
        return new SelExprInterpolatedString(ast);
    }

    /** Compiles both operands (left first) and resolves the binary operator for this version. */
    @Override
    protected SelExprOpCall visitBinOp(AstBinOp binOp, CompileContext context) {
        SelExpr left = compileExpr(binOp.getLeft(), context);
        SelExpr right = compileExpr(binOp.getRight(), context);
        return SelOperators.exprCall(version, binOp, binOp.getOp(), List.of(left, right));
    }

    /** Compiles condition, then the left and right branches, into a ternary expression. */
    @Override
    protected SelExprTern visitTernaryOp(AstTernaryOp ternary, CompileContext context) {
        SelExpr condition = compileExpr(ternary.getCondition(), context);
        SelExpr left = compileExpr(ternary.getLeft(), context);
        SelExpr right = compileExpr(ternary.getRight(), context);

        return new SelExprTern(ternary, condition, left, right);
    }

    /** Compiles the operand and resolves the unary operator for this version. */
    @Override
    protected SelExprOpCall visitUnaryOp(AstUnaryOp unaryOp, CompileContext context) {
        SelExpr param = compileExpr(unaryOp.getParam(), context);
        return SelOperators.exprCall(version, unaryOp, unaryOp.getOp(), List.of(param));
    }

    /**
     * Parses, compiles and evaluates a single expression against {@code params}.
     *
     * @param program SEL expression source
     * @param params values bound to identifiers, used both for compile-time typing and evaluation
     */
    public SelValue eval(String program, Map<String, SelValue> params) {
        EvalContextImpl context = new EvalContextImpl(params, version);
        SelExpr expr = compileExpr(new SelParser(program).parseExpr(), context.compileContext());
        return expr.eval(context);
    }
}
