package ru.yandex.solomon.yasm.expression.grammar.functions;

import java.util.List;
import java.util.Map;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.solomon.expression.PositionRange;
import ru.yandex.solomon.expression.ast.AstCall;
import ru.yandex.solomon.expression.ast.AstIdent;
import ru.yandex.solomon.expression.ast.AstValueDouble;
import ru.yandex.solomon.expression.ast.AstValueString;
import ru.yandex.solomon.yasm.expression.ast.YasmAst;
import ru.yandex.solomon.yasm.expression.ast.YasmAstIdent;
import ru.yandex.solomon.yasm.expression.ast.YasmAstNumber;
import ru.yandex.solomon.yasm.expression.grammar.ExpressionWithConstants;
import ru.yandex.solomon.yasm.expression.grammar.FunctionRenderer;
import ru.yandex.solomon.yasm.expression.grammar.YasmSelRenderer;

import static ru.yandex.solomon.yasm.expression.grammar.functions.Const.makeLineArg;

/**
 * Renders the yasm {@code quant(signal[, level[, fallback]])} function.
 *
 * <p>The signal is rendered as {@code histogram_percentile(level, "", signal)}. When a third
 * argument is present, the result is wrapped as {@code fallback(result, fallbackArg)}.
 *
 * @author Ivan Tsybulin
 */
@ParametersAreNonnullByDefault
public class Quant implements FunctionRenderer {

    /** Percentile used when {@code quant} is called with the signal only. */
    private static final double DEFAULT_PERCENTILE_LEVEL = 50d;

    /** Symbolic level names accepted as the second argument, mapped to percentiles. */
    private static final Map<String, Double> NAMED_QUANTILES = Map.of(
            "max", 100d,
            "min", 0d,
            "med", 50d
    );

    @Override
    public String name() {
        return "quant";
    }

    /**
     * Renders a {@code quant} call.
     *
     * @param renderer renderer used to translate nested yasm expressions
     * @param args     call arguments: signal, optional level, optional fallback line
     * @return the rendered percentile expression, possibly wrapped in {@code fallback(...)}
     * @throws IllegalArgumentException if the argument count is not 1..3, or the level
     *                                  argument is unsupported (see {@link #parsePercentileLevel})
     */
    @Override
    public ExpressionWithConstants render(YasmSelRenderer renderer, List<YasmAst> args) {
        // Validate arity before touching args: otherwise an empty call fails with an
        // IndexOutOfBoundsException from args.get(0) instead of this descriptive error.
        if (args.size() < 1 || args.size() > 3) {
            throw new IllegalArgumentException("Expected from 1 to 3 arguments for quant call, got: " + args.size());
        }

        ExpressionWithConstants signal = renderer.visit(args.get(0));
        if (args.size() == 1) {
            return quant(signal, DEFAULT_PERCENTILE_LEVEL);
        }

        double percentileLevel = parsePercentileLevel(args.get(1));
        ExpressionWithConstants result = quant(signal, percentileLevel);
        if (args.size() == 2) {
            return result;
        }

        ExpressionWithConstants fallback = makeLineArg(renderer, args.get(2));
        return ExpressionWithConstants.series(
                new AstCall(PositionRange.UNKNOWN,
                        new AstIdent(PositionRange.UNKNOWN, "fallback"),
                        List.of(result.expression, fallback.expression)
                ),
                List.of(result.constants, fallback.constants)
        );
    }

    /**
     * Maps the rendered signal to {@code histogram_percentile(level, "", signal)}.
     *
     * @param signal rendered signal expression
     * @param level  percentile level, expected in [0, 100]
     */
    private static ExpressionWithConstants quant(ExpressionWithConstants signal, double level) {
        return signal.mapToSeries(expression -> new AstCall(PositionRange.UNKNOWN,
                        new AstIdent(PositionRange.UNKNOWN, "histogram_percentile"),
                        List.of(
                                new AstValueDouble(PositionRange.UNKNOWN, level),
                                new AstValueString(PositionRange.UNKNOWN, ""),
                                expression
                        )
                ));
    }

    /**
     * Parses the level argument of {@code quant} into a percentile clamped to [0, 100].
     *
     * <ul>
     *   <li>A number token without a decimal point is read as the digits of a fraction:
     *       {@code 99 -> 99}, {@code 999 -> 99.9}, {@code 5 -> 50}.</li>
     *   <li>A number token with a decimal point is a fraction of one: {@code 0.95 -> 95}.</li>
     *   <li>An identifier must be one of {@code max}, {@code min}, {@code med}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for an unknown identifier or an unsupported AST node
     */
    private static double parsePercentileLevel(YasmAst ast) {
        if (ast instanceof YasmAstNumber) {
            YasmAstNumber astNumber = (YasmAstNumber) ast;
            String token = astNumber.getToken();
            double fraction = token.indexOf('.') == -1
                    // Digits without a point are the fractional digits of the level.
                    ? Double.parseDouble("0." + token)
                    : astNumber.getValue();
            return clampPercent(100d * fraction);
        }
        if (ast instanceof YasmAstIdent) {
            YasmAstIdent astIdent = (YasmAstIdent) ast;
            Double level = NAMED_QUANTILES.get(astIdent.getIdent());
            if (level == null) {
                throw new IllegalArgumentException("Unknown named level in quant: " + astIdent.getIdent());
            }
            return level;
        }
        throw new IllegalArgumentException("Unsupported ast type for level in quant: " + ast.getType());
    }

    /** Clamps a percentile to the valid [0, 100] range. */
    private static double clampPercent(double percent) {
        return Math.max(0d, Math.min(100d, percent));
    }
}
