package ru.yandex.solomon.expression.expr;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.expression.analytics.GraphDataLoadRequest;
import ru.yandex.solomon.expression.ast.Ast;
import ru.yandex.solomon.expression.ast.AstCall;
import ru.yandex.solomon.expression.ast.AstIdent;
import ru.yandex.solomon.expression.ast.AstValueInterval;
import ru.yandex.solomon.expression.compile.SelAnonymous;
import ru.yandex.solomon.expression.compile.SelAssignment;
import ru.yandex.solomon.expression.compile.SelStatement;
import ru.yandex.solomon.expression.exceptions.InternalCompilerException;
import ru.yandex.solomon.expression.type.SelType;
import ru.yandex.solomon.expression.type.SelTypes;
import ru.yandex.solomon.expression.value.ArgsList;
import ru.yandex.solomon.expression.value.SelValueInterval;
import ru.yandex.solomon.expression.version.SelVersion;
import ru.yandex.solomon.util.time.Interval;

/**
 * Propagates load intervals through compiled SEL statements.
 *
 * <p>Statements are visited from last to first. Whenever a parameter of graph-data type (or a
 * vector of graph data) is referenced, the interval in effect at that point is merged (via
 * {@link Interval#convexHull}) into the load interval recorded for the parameter's name, and the
 * reference is wrapped into a {@code crop(param, interval)} call. Functions that declare an
 * interval handler may change the interval used while visiting their arguments; when it differs
 * from the caller's interval, the call's result is cropped back to the caller's interval.
 *
 * <p>The right-hand side of an assignment is visited with the interval accumulated for the assigned
 * variable by the (already visited) later statements, or the original interval if there is none;
 * anonymous statements are visited with the original interval.
 *
 * <p>Instances hold mutable traversal state and are only created internally by
 * {@link #fillInterval}.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class IntervalLoadVisitor extends SelExprRecurseVisitor {
    private final SelVersion version;
    private final Interval originalInterval;
    /** Load interval accumulated per identifier referenced as a timeseries parameter. */
    private final Map<String, Interval> varToLoadInterval = new HashMap<>();
    /** Interval in effect for the expression currently being visited. */
    private Interval currentInterval;

    private IntervalLoadVisitor(SelVersion version, Interval originalInterval) {
        this.version = version;
        this.originalInterval = originalInterval;
    }

    /** Returns true for graph data and for vectors whose element type is graph data. */
    private static boolean isTimeseries(SelType type) {
        if (type.isGraphData()) {
            return true;
        }

        // equals() rather than ==: falls back to identity if SelType does not override equals,
        // and stays correct if element types are not interned.
        return type.isVector() && SelTypes.GRAPH_DATA.equals(type.vector().elementType);
    }

    /**
     * Wraps {@code wrapped} into a synthetic {@code crop(wrapped, interval)} call.
     *
     * <p>Both the AST and the compiled call are built; all synthetic AST nodes reuse the source
     * range of {@code wrapped}.
     *
     * @param version  SEL version used to resolve the {@code crop} function
     * @param wrapped  expression to crop
     * @param interval interval to crop to
     * @return compiled {@code crop} call expression
     */
    public static SelExpr makeCrop(SelVersion version, SelExpr wrapped, Interval interval) {
        Ast wrappedAst = wrapped.getSourceAst();
        AstValueInterval astInterval = new AstValueInterval(wrappedAst.getRange(), interval);
        AstIdent crop = new AstIdent(wrappedAst.getRange(), "crop");
        AstCall call = new AstCall(crop.getRange(), crop, List.of(wrappedAst, astInterval));

        SelValueInterval selInterval = new SelValueInterval(astInterval.getInterval());
        List<SelExpr> cropParams = List.of(wrapped, new SelExprValue(astInterval, selInterval));
        return SelFunctions.exprCall(version, call, cropParams);
    }

    /**
     * Records the current interval for a timeseries parameter and crops the reference to it.
     * Parameters of other types are returned unchanged.
     */
    @Override
    public SelExpr visitParam(SelExprParam param) {
        if (!isTimeseries(param.type())) {
            return param;
        }

        // The same name may be referenced under different intervals: load the union of all of them.
        String ident = param.getName();
        Interval previousIdentityInterval = varToLoadInterval.getOrDefault(ident, currentInterval);
        varToLoadInterval.put(ident, Interval.convexHull(currentInterval, previousIdentityInterval));

        return makeCrop(version, param, currentInterval);
    }

    @Override
    public SelExpr visitOp(SelExprOpCall fnCall) {
        return visitNext(fnCall);
    }

    /**
     * Visits a function call, letting the function's interval handler (if any) decide which
     * interval its arguments are visited with.
     *
     * <p>If that interval differs from the caller's, the result is cropped back to the caller's
     * interval. The caller's interval is restored before returning, so the call never affects how
     * sibling expressions are visited.
     */
    @Override
    public SelExpr visitFn(SelExprFuncCall fnCall) {
        Interval callerInterval = currentInterval;
        try {
            var handler = fnCall.getFunc().getIntervalHandler();
            if (handler != null) {
                currentInterval = handler.apply(callerInterval, handlerArgs(fnCall));
            }

            SelExpr result = visitNext(fnCall);
            if (currentInterval.equals(callerInterval)) {
                return result;
            }
            return makeCrop(version, result, callerInterval);
        } finally {
            // Do not leak the handler's interval to whatever visits the next sibling.
            currentInterval = callerInterval;
        }
    }

    /**
     * Builds the argument list passed to an interval handler. Constant arguments
     * ({@link SelExprValue}) expose their value; every other argument is represented by
     * {@code null}.
     */
    private static ArgsList handlerArgs(SelExprFuncCall fnCall) {
        ArgsList values = new ArgsList(fnCall.getRange(), fnCall.getArgs().size());
        for (var argExpr : fnCall.getArgs()) {
            if (argExpr instanceof SelExprValue) {
                values.add(((SelExprValue) argExpr)::getValue, argExpr.getRange());
            } else {
                values.add(() -> null, argExpr.getRange());
            }
        }
        return values;
    }

    @Override
    public SelExpr visitCondition(SelExprTern expr) {
        return visitNext(expr);
    }

    @Override
    public SelExpr visitObject(SelExprObject object) {
        return visitNext(object);
    }

    /**
     * Visits every parameter of {@code root}, restoring the current interval after each one so
     * that siblings are all visited with the same interval.
     */
    private SelExpr visitNext(SelExpr root) {
        return root.mapParams(expr -> {
            Interval temp = currentInterval;
            try {
                return expr.visit(this);
            } finally {
                currentInterval = temp;
            }
        });
    }

    /**
     * Returns the load interval accumulated for {@code ident}, or the original interval if
     * {@code ident} has not been referenced as a timeseries parameter.
     */
    public Interval getInterval(String ident) {
        return varToLoadInterval.getOrDefault(ident, originalInterval);
    }

    /** Visits an assignment using the interval its variable is needed over by later statements. */
    private SelStatement visitAssignment(SelAssignment line) {
        String ident = line.getIdent();
        currentInterval = getInterval(ident);
        return line.visit(this);
    }

    /** Visits an anonymous statement using the original interval. */
    private SelStatement visitAnonymous(SelAnonymous line) {
        currentInterval = originalInterval;
        return line.visit(this);
    }

    /**
     * Dispatches a statement to the matching visit method.
     *
     * @throws InternalCompilerException if the statement kind is not supported
     */
    private SelStatement visitStatement(SelStatement line) {
        if (line instanceof SelAssignment) {
            return visitAssignment((SelAssignment) line);
        } else if (line instanceof SelAnonymous) {
            return visitAnonymous((SelAnonymous) line);
        }

        throw new InternalCompilerException("Don't know how to patch " + line + " for interval load visitor");
    }

    /**
     * Patches all statements and returns them in their original order.
     *
     * <p>Statements are visited last-to-first so that every use of a variable in later statements
     * has been recorded in {@link #varToLoadInterval} before the assignment defining it is visited.
     */
    private List<SelStatement> visitStatements(List<SelStatement> lines) {
        SelStatement[] patchedLines = new SelStatement[lines.size()];
        for (int index = lines.size() - 1; index >= 0; index--) {
            patchedLines[index] = visitStatement(lines.get(index));
        }
        return Arrays.asList(patchedLines);
    }

    /**
     * Computes load intervals for a program and returns its statements patched with crops.
     *
     * @param version          SEL version used to build {@code crop} calls
     * @param originalInterval interval requested for the program's output
     * @param extVarToRequest  load request builders of external variables; each one gets its
     *                         interval set to the interval accumulated for the variable's name
     * @param lines            compiled statements
     * @return patched statements, in the same order as {@code lines}
     */
    public static List<SelStatement> fillInterval(
            SelVersion version,
            Interval originalInterval,
            Map<SelExprParam, GraphDataLoadRequest.Builder> extVarToRequest,
            List<SelStatement> lines)
    {
        IntervalLoadVisitor visitor = new IntervalLoadVisitor(version, originalInterval);
        List<SelStatement> patchedLines = visitor.visitStatements(lines);
        for (Map.Entry<SelExprParam, GraphDataLoadRequest.Builder> entry : extVarToRequest.entrySet()) {
            entry.getValue().setInterval(visitor.getInterval(entry.getKey().getName()));
        }
        return patchedLines;
    }
}
