package ru.yandex.solomon.expression.expr;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.expression.ast.AstInterpolatedString;
import ru.yandex.solomon.expression.compile.CompileContext;
import ru.yandex.solomon.expression.compile.SelAssignment;
import ru.yandex.solomon.expression.compile.SelStatement;
import ru.yandex.solomon.expression.exceptions.PreparingException;
import ru.yandex.solomon.expression.type.SelType;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.value.SelValueString;
import ru.yandex.solomon.expression.version.SelVersion;
import ru.yandex.solomon.labels.InterpolatedString;

/**
 * Constant-folding pass over compiled SEL expressions.
 *
 * <p>Sub-expressions are folded before the node that contains them, and nodes whose value can be
 * determined are replaced with {@link SelExprValue}:
 * <ul>
 *     <li>a parameter with a known value is substituted by that value, after checking that the
 *     value's type matches the parameter's type;</li>
 *     <li>an operator call whose operands all folded to values is evaluated immediately;</li>
 *     <li>a call of a pure function whose arguments all folded to values is evaluated
 *     immediately;</li>
 *     <li>a ternary whose condition folded to a value is replaced by the selected branch;</li>
 *     <li>an interpolated string is partially evaluated with the known parameter values and
 *     becomes a string value when nothing is left to substitute.</li>
 * </ul>
 *
 * <p>{@link #foldBlock} applies the pass to a sequence of statements.
 *
 * @author Stepan Koltsov
 */
@ParametersAreNonnullByDefault
public class SelExprVisitorFoldConstants extends SelExprRecurseVisitor {

    /**
     * A named parameter, together with its value when that value is known at fold time.
     */
    static class ParamMaybeKnownValue {
        private final String name;
        /** Value of the parameter, or {@code null} when only its type is known. */
        @Nullable
        private final SelValue knownValue;

        /**
         * @param name       parameter name
         * @param type       declared parameter type
         * @param knownValue value of the parameter, or {@code null} if it is not known
         * @throws AssertionError if {@code knownValue} is present and its type differs from
         *                        {@code type}
         */
        ParamMaybeKnownValue(String name, SelType type, @Nullable SelValue knownValue) {
            if (knownValue != null && !type.equals(knownValue.type())) {
                throw new AssertionError("different types: " + type + "!=" + knownValue.type());
            }
            this.name = name;
            this.knownValue = knownValue;
        }

        /** Creates a parameter whose type is known but whose value is not. */
        ParamMaybeKnownValue(String name, SelType type) {
            this(name, type, null);
        }

        /** Creates a parameter with a known value; the type is taken from the value itself. */
        ParamMaybeKnownValue(String name, SelValue knownValue) {
            this(name, knownValue.type(), knownValue);
        }
    }

    /** Parameters visible to the expression being folded, keyed by name. */
    private final Map<String, ParamMaybeKnownValue> context;
    /** Language version handed to the evaluation context when a call is folded. */
    private final SelVersion version;

    /**
     * @param params  parameters visible to the expression; names must be unique, because
     *                {@link Collectors#toMap} rejects duplicate keys
     * @param version language version used when evaluating foldable calls
     */
    SelExprVisitorFoldConstants(List<ParamMaybeKnownValue> params, SelVersion version) {
        this.version = version;
        this.context = params.stream().collect(Collectors.toMap(p -> p.name, Function.identity()));
    }

    /**
     * Replaces a parameter reference with its known value, if there is one.
     *
     * @throws NullPointerException if the parameter is not present in the context
     * @throws PreparingException   if the known value's type differs from the parameter's type
     */
    @Override
    public SelExpr visitParam(SelExprParam param) {
        var paramName = param.getName();
        ParamMaybeKnownValue entry = Objects.requireNonNull(
            context.get(paramName),
            "unknown param: " + paramName);

        SelValue constant = entry.knownValue;
        if (constant == null) {
            // No fold-time value: defer to the base visitor.
            return super.visitParam(param);
        }

        SelType constantType = constant.type();
        if (!constantType.equals(param.type())) {
            throw new PreparingException(
                param.getRange(),
                "different types: " + constantType + "!=" + param.type());
        }
        return new SelExprValue(param.getSourceAst(), constant);
    }

    /**
     * Folds the operands first. If every operand became a {@link SelExprValue}, evaluates the
     * operator and returns the result as a value node; otherwise returns the call with its folded
     * operands.
     */
    @Override
    public SelExpr visitOp(SelExprOpCall fnCall) {
        SelExprOpCall folded = fnCall.mapParams(child -> child.visit(this));

        boolean allConstant = true;
        for (var operand : folded.getParams()) {
            if (!(operand instanceof SelExprValue)) {
                allConstant = false;
                break;
            }
        }
        if (!allConstant) {
            return folded;
        }

        // NOTE(review): unlike visitFn, there is no purity check here. Presumably every operator
        // is side-effect free; confirm.
        SelValue evaluated = folded.evalInternal(new EvalContextImpl(Map.of(), version));
        return new SelExprValue(folded.getSourceAst(), evaluated);
    }

    /**
     * Folds the arguments first. If the called function is pure and every argument became a
     * {@link SelExprValue}, evaluates the call and returns the result as a value node; otherwise
     * returns the call with its folded arguments.
     */
    @Override
    public SelExpr visitFn(SelExprFuncCall fnCall) {
        SelExprFuncCall folded = fnCall.mapParams(child -> child.visit(this));
        if (!folded.getFunc().isPure()) {
            // Impure functions are never folded, even with constant arguments.
            return folded;
        }
        for (var argument : folded.getArgs()) {
            if (!(argument instanceof SelExprValue)) {
                return folded;
            }
        }
        return new SelExprValue(
            folded.getSourceAst(),
            folded.evalInternal(new EvalContextImpl(Map.of(), version)));
    }

    /**
     * Folds the condition and both branches. When the condition became a {@link SelExprValue},
     * returns the already-folded branch it selects instead of the ternary.
     */
    @Override
    public SelExpr visitCondition(SelExprTern expr) {
        SelExprTern folded = expr.mapParams(child -> child.visit(this));
        if (!(folded.getCondition() instanceof SelExprValue constantCondition)) {
            return folded;
        }
        boolean takeTrueBranch = constantCondition.getValue().castToBoolean().getValue();
        return takeTrueBranch ? folded.getIfTrue() : folded.getIfFalse();
    }

    /**
     * Substitutes every known parameter value, converted to a string, into the interpolated
     * string. A fully resolved result becomes a string value node; otherwise a new interpolated
     * string node is built from the partially evaluated template.
     */
    @Override
    public SelExpr visitInterpolatedString(SelExprInterpolatedString interpolatedString) {
        var sourceAst = interpolatedString.getSourceAst();
        InterpolatedString partialEval = sourceAst.interpolatedString.partialEval(name -> {
            ParamMaybeKnownValue param = context.get(name);
            if (param != null && param.knownValue != null) {
                return Optional.of(param.knownValue.convertToString());
            }
            // Unknown name or unknown value: leave the placeholder in place.
            return Optional.empty();
        });

        return partialEval.constant()
            .<SelExpr>map(text -> new SelExprValue(sourceAst, new SelValueString(text)))
            .orElseGet(() -> new SelExprInterpolatedString(
                new AstInterpolatedString(interpolatedString.getRange(), partialEval)));
    }

    /** Rebuilds a selector from its folded key and folded value; the key is visited first. */
    @Override
    public SelExprSelector visitSelector(SelExprSelector selector) {
        return new SelExprSelector(
            selector.getSourceAst(),
            selector.getKey().visit(this),
            selector.getValue().visit(this),
            selector.getType());
    }

    /** Folds every selector of a selectors expression and keeps the name selector unchanged. */
    @Override
    public SelExpr visitSelectors(SelExprSelectors selectors) {
        List<SelExprSelector> foldedSelectors = new ArrayList<>();
        for (var selector : selectors.getSelectors()) {
            foldedSelectors.add(visitSelector(selector));
        }
        return new SelExprSelectors(
            selectors.getSourceAst(),
            selectors.type(),
            selectors.getNameSelector(),
            foldedSelectors);
    }

    /**
     * Folds constants in a block of statements, processing the statements in order.
     *
     * <p>Parameters from {@code context} start out with an unknown value.
     * {@code externalSelectors} are added afterwards as known string values, so they replace any
     * context parameter with the same name. Each statement is folded against the parameters
     * established by the statements before it. An assignment whose folded right-hand side is a
     * {@link SelExprValue} makes its identifier known. Any other assignment records the
     * identifier with only its type, replacing a previously known value.
     *
     * @param version           language version used when evaluating foldable calls
     * @param block             statements to fold
     * @param context           source of the declared parameters and their types
     * @param externalSelectors name/value pairs treated as known string parameters
     * @return the folded statements, one per input statement, in the same order
     */
    public static List<SelStatement> foldBlock(
            SelVersion version,
            List<SelStatement> block,
            CompileContext context,
            Map<String, String> externalSelectors)
    {
        List<SelStatement> result = new ArrayList<>(block.size());

        Map<String, ParamMaybeKnownValue> known = new HashMap<>();
        context.getParams().forEach((paramName, paramType) ->
            known.put(paramName, new ParamMaybeKnownValue(paramName, paramType)));
        for (Map.Entry<String, String> selector : externalSelectors.entrySet()) {
            String selectorName = selector.getKey();
            known.put(
                selectorName,
                new ParamMaybeKnownValue(selectorName, new SelValueString(selector.getValue())));
        }

        for (SelStatement statement : block) {
            // Use a fresh visitor over a snapshot of the parameters, so that it sees exactly the
            // parameters established by the preceding statements.
            var visitor = new SelExprVisitorFoldConstants(new ArrayList<>(known.values()), version);
            SelStatement patched = statement.visit(visitor);

            if (patched instanceof SelAssignment assignment) {
                String ident = assignment.getIdent();
                SelExpr rhs = assignment.getExpr();
                known.put(ident, rhs instanceof SelExprValue constant
                    ? new ParamMaybeKnownValue(ident, constant.getValue())
                    : new ParamMaybeKnownValue(ident, rhs.type()));
            }

            result.add(patched);
        }

        return result;
    }
}
