package ru.yandex.solomon.expression.expr.func.analytical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.ImmutableMap;

import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.labels.LabelsBuilder;
import ru.yandex.solomon.expression.PositionRange;
import ru.yandex.solomon.expression.expr.EvalContext;
import ru.yandex.solomon.expression.expr.EvalContextImpl;
import ru.yandex.solomon.expression.expr.func.SelFunc;
import ru.yandex.solomon.expression.expr.func.SelFuncCategory;
import ru.yandex.solomon.expression.expr.func.SelFuncProvider;
import ru.yandex.solomon.expression.expr.func.SelFuncRegistry;
import ru.yandex.solomon.expression.expr.func.SelFunctionWithContext;
import ru.yandex.solomon.expression.type.SelTypeLambda;
import ru.yandex.solomon.expression.type.SelTypes;
import ru.yandex.solomon.expression.value.ArgsList;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.value.SelValueGraphData;
import ru.yandex.solomon.expression.value.SelValueLambda;
import ru.yandex.solomon.expression.value.SelValueVector;
import ru.yandex.solomon.expression.version.SelVersion;

import static java.util.stream.Collectors.toList;

/**
 * <p>Groups time series by user-specified label names and applies a lambda function to each group.
 * <p>The label names may be passed either as a single string or as a vector of strings.
 * <p>Example usage:
 * <pre>{@code
 *  let dataSet = http_requests_total{host="*"};
 *  let result = group_by_labels(dataSet, as_vector('application', 'group'), v -> group_lines('max', v));
 * }</pre>
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class SelFnGroupByLabels implements SelFuncProvider {
    /**
     * Handler used when the lambda returns a single line per group.
     */
    private static SelValueVector evalScalar(EvalContext ctx, ArgsList args) {
        return eval(ctx, args, false);
    }

    /**
     * Handler used when the lambda returns a vector of lines per group.
     */
    private static SelValueVector evalVectored(EvalContext ctx, ArgsList args) {
        return eval(ctx, args, true);
    }

    /**
     * Groups the input lines by the requested label keys and evaluates the lambda on every group.
     *
     * <p>Groups are emitted in the order in which their first line appears in the input, so the
     * output order follows the input instead of depending on hash iteration order.
     *
     * @param ctx evaluation context; only its version is passed on to the lambda evaluation
     * @param args arguments: (0) vector of lines, (1) label name or vector of label names, (2) lambda
     * @param lambdaIsVectored {@code true} if the lambda returns a vector of lines whose elements are
     *     appended to the result, {@code false} if it returns a single line
     * @return the lambda results of all groups, concatenated into one vector
     */
    private static SelValueVector eval(EvalContext ctx, ArgsList args, boolean lambdaIsVectored) {
        // A set gives O(1) membership checks in filterLabels; duplicates in the argument are irrelevant.
        Set<String> labelsToGroup = new HashSet<>(getLabels(args));
        SelValueLambda lambda = args.get(2).castToLambda();
        PositionRange lambdaRange = args.getRange(2);
        SelVersion version = ctx.getVersion();

        List<SelValueGraphData> targets = Stream.of(args.get(0).castToVector().valueArray())
            .map(SelValue::castToGraphData)
            .collect(toList());

        // LinkedHashMap keeps groups in first-seen input order, making the output order deterministic.
        Map<Labels, List<SelValueGraphData>> grouped = new LinkedHashMap<>();
        for (SelValueGraphData gd : targets) {
            Labels groupKey = filterLabels(gd.getNamedGraphData().getLabels(), labelsToGroup);
            grouped.computeIfAbsent(groupKey, ignore -> new ArrayList<>()).add(gd);
        }

        List<SelValue> result = new ArrayList<>(grouped.size());
        for (List<SelValueGraphData> group : grouped.values()) {
            SelValueVector vector = new SelValueVector(SelTypes.GRAPH_DATA, group.toArray(new SelValueGraphData[0]));

            if (lambdaIsVectored) {
                SelValue[] lambdaResult = evalLambdaVectored(version, lambdaRange, lambda, vector).valueArray();
                result.addAll(Arrays.asList(lambdaResult));
            } else {
                result.add(evalLambda(version, lambdaRange, lambda, vector));
            }
        }

        return new SelValueVector(SelTypes.GRAPH_DATA, result.toArray(new SelValue[0]));
    }

    /**
     * Evaluates the lambda body with its single parameter bound to {@code arg} and casts the result
     * to a single line.
     *
     * <p>NOTE(review): the body is evaluated in a new {@link EvalContextImpl} built only from the
     * parameter binding and the version, so outer-scope variables are presumably not visible inside
     * the lambda — confirm this is intended.
     */
    private static SelValue evalLambda(SelVersion version, PositionRange lambdaRange, SelValueLambda lambda, SelValueVector arg) {
        return lambda.body.eval(new EvalContextImpl(ImmutableMap.of(lambda.singleParamName(lambdaRange), arg), version))
            .castToGraphData();
    }

    /**
     * Evaluates the lambda body with its single parameter bound to {@code arg} and casts the result
     * to a vector.
     */
    private static SelValueVector evalLambdaVectored(SelVersion version, PositionRange lambdaRange, SelValueLambda lambda, SelValueVector arg) {
        return lambda.body.eval(new EvalContextImpl(ImmutableMap.of(lambda.singleParamName(lambdaRange), arg), version))
            .castToVector();
    }

    /**
     * Extracts the label names to group by from argument 1, which is either a single string or a
     * vector of strings.
     */
    private static List<String> getLabels(ArgsList params) {
        SelValue labels = params.get(1);
        if (labels.type() == SelTypes.STRING) {
            return Collections.singletonList(labels.castToString().getValue());
        }

        return Stream.of(labels.castToVector().valueArray())
            .map(selValue -> selValue.castToString().getValue())
            .collect(toList());
    }

    /**
     * Returns the subset of {@code labels} whose keys are in {@code labelsToGroup}; this subset is
     * the group key. Lines missing some of the requested labels get a key without those labels.
     */
    private static Labels filterLabels(Labels labels, Set<String> labelsToGroup) {
        if (labelsToGroup.isEmpty()) {
            return Labels.empty();
        }
        LabelsBuilder builder = Labels.builder(labelsToGroup.size());
        labels.forEach(label -> {
            if (labelsToGroup.contains(label.getKey())) {
                builder.add(label);
            }
        });
        return builder.build();
    }

    /**
     * Registers the {@code group_by_labels} overloads.
     *
     * <p>Before {@link SelVersion#GROUP_LINES_RETURN_VECTOR_2}, the lambda must return a single line.
     * Since that version, it must return a vector of lines.
     */
    @Override
    public void provide(SelFuncRegistry registry) {
        SelTypeLambda scalarLambda = new SelTypeLambda(List.of(SelTypes.GRAPH_DATA_VECTOR), SelTypes.GRAPH_DATA);
        SelTypeLambda vectorLambda = new SelTypeLambda(List.of(SelTypes.GRAPH_DATA_VECTOR), SelTypes.GRAPH_DATA_VECTOR);

        provide(registry, SelVersion.GROUP_LINES_RETURN_VECTOR_2::before, scalarLambda, SelFnGroupByLabels::evalScalar);
        provide(registry, SelVersion.GROUP_LINES_RETURN_VECTOR_2::since, vectorLambda, SelFnGroupByLabels::evalVectored);
    }

    /**
     * Registers one overload per accepted label-argument type (string vector and single string). All
     * overloads share the given version predicate, lambda type and handler.
     */
    private void provide(
            SelFuncRegistry registry,
            Predicate<SelVersion> supportedVersions,
            SelTypeLambda acceptedLambda,
            SelFunctionWithContext handler)
    {
        for (var labelsType : List.of(SelTypes.STRING_VECTOR, SelTypes.STRING)) {
            registry.add(SelFunc.newBuilder()
                    .name("group_by_labels")
                    .help("Group time-series by user-specified label names and apply on each group lambda function")
                    .category(SelFuncCategory.COMBINE)
                    .args(SelTypes.GRAPH_DATA_VECTOR, labelsType, acceptedLambda)
                    .supportedVersions(supportedVersions)
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(handler)
                    .build());
        }
    }
}
