package ru.yandex.solomon.expression.expr.func.analytical;

import java.util.Arrays;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;

import org.apache.commons.lang3.mutable.MutableInt;

import ru.yandex.monlib.metrics.MetricType;
import ru.yandex.monlib.metrics.labels.Label;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.labels.LabelsBuilder;
import ru.yandex.solomon.expression.NamedGraphData;
import ru.yandex.solomon.expression.PositionRange;
import ru.yandex.solomon.expression.exceptions.EvaluationException;
import ru.yandex.solomon.expression.exceptions.InternalCompilerException;
import ru.yandex.solomon.expression.expr.func.AggrSelFn;
import ru.yandex.solomon.expression.expr.func.LabelsUtil;
import ru.yandex.solomon.expression.expr.func.SelFunc;
import ru.yandex.solomon.expression.expr.func.SelFuncArgument;
import ru.yandex.solomon.expression.expr.func.SelFuncCategory;
import ru.yandex.solomon.expression.expr.func.SelFuncProvider;
import ru.yandex.solomon.expression.expr.func.SelFuncRegistry;
import ru.yandex.solomon.expression.type.SelTypes;
import ru.yandex.solomon.expression.value.ArgsList;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.value.SelValueGraphData;
import ru.yandex.solomon.expression.value.SelValueVector;
import ru.yandex.solomon.expression.value.SelValueWithRange;
import ru.yandex.solomon.expression.version.SelVersion;
import ru.yandex.solomon.math.operation.reduce.CombineIterator;
import ru.yandex.solomon.math.protobuf.Aggregation;
import ru.yandex.solomon.model.point.AggrPoint;
import ru.yandex.solomon.model.point.column.TsColumn;
import ru.yandex.solomon.model.protobuf.MetricTypeConverter;
import ru.yandex.solomon.model.timeseries.AggrGraphDataArrayList;
import ru.yandex.solomon.model.timeseries.AggrGraphDataListIterator;
import ru.yandex.solomon.model.timeseries.GraphData;
import ru.yandex.solomon.model.timeseries.MetricTypeCasts;
import ru.yandex.solomon.model.timeseries.MetricTypeTransfers;
import ru.yandex.solomon.model.timeseries.aggregation.collectors.PointValueCollector;
import ru.yandex.solomon.model.timeseries.aggregation.collectors.PointValueCollectors;
import ru.yandex.solomon.util.collection.array.DoubleArrayView;

import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static ru.yandex.solomon.expression.expr.func.SelFuncArgument.arg;

/**
 * <p>Apply aggregate function for values at the same time point over all lines
 *
 * <p>Example usage {@code group_lines('sum', group_by_time(15s, 'max', graphDataVector))}
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class SelFnGroupLines implements SelFuncProvider {
    private static final List<String> AGGR_FOR_ALIASES = List.of(
            "max", "min", "avg", "sum"
    );

    private static SelValueGraphData groupSingle(
            List<NamedGraphData> values,
            PositionRange sourcesRange,
            AggrSelFn.Type aggregation,
            PositionRange aggrRange)
    {
        if (values.isEmpty()) {
            return new SelValueGraphData(GraphData.empty);
        }

        if (values.size() == 1 && isAbleSkip(aggregation)) {
            return new SelValueGraphData(values.get(0));
        }

        Labels labels = LabelsUtil.getCommonLabels(values);

        List<NamedGraphData> nonemptyValues = values.stream()
            .filter(ngd -> !ngd.getAggrGraphDataArrayList().isEmpty())
            .collect(toList());

        Set<MetricType> metricTypes = nonemptyValues.stream()
            .map(NamedGraphData::getDataType)
            .map(MetricTypeConverter::fromProto)
            .collect(Collectors.toSet());

        if (metricTypes.isEmpty()) {
            return new SelValueGraphData(NamedGraphData.newBuilder()
                    //.setType(???)
                    .setGraphData(AggrGraphDataArrayList.empty())
                    .setLabels(labels)
                    .build());
        }

        MetricType commonType = metricTypes.stream()
            .reduce(MetricTypeCasts::commonType)
            .get();

        if (commonType == MetricType.UNKNOWN) {
            throw new EvaluationException(sourcesRange, "Cannot group lines with different kinds: " + metricTypes);
        }

        if (!NATIVE_AGGREGATIONS.containsKey(aggregation)) {
            if (MetricTypeCasts.commonType(commonType, MetricType.DGAUGE) == MetricType.UNKNOWN) {
                throw new EvaluationException(aggrRange, "Aggregation " + aggregation +
                    " requires DGAUGE metric type, but got " + commonType);
            }
            commonType = MetricType.DGAUGE;
        }

        if (aggregation == AggrSelFn.Type.AVG) {
            // Force DGAUGE for AVG
            if (MetricTypeCasts.commonType(commonType, MetricType.DGAUGE) != MetricType.UNKNOWN) {
                commonType = MetricType.DGAUGE;
            }
        }

        final var commonMetricTypeProto = MetricTypeConverter.toNotNullProto(commonType);

        List<AggrGraphDataListIterator> iterators = nonemptyValues.stream()
            .map(ngd -> MetricTypeTransfers.of(
                    ngd.getDataType(),
                    commonMetricTypeProto,
                    ngd.getAggrGraphDataArrayList().iterator()))
            .collect(toList());

        int mask = iterators.stream()
            .map(AggrGraphDataListIterator::columnSetMask)
            .reduce(TsColumn.mask, (l, r) -> l | r);
        var it = CombineIterator.of(mask, iterators,
                getPointCollector(commonType, aggregation, aggrRange));
        AggrGraphDataArrayList result = AggrGraphDataArrayList.of(it);

        return new SelValueGraphData(NamedGraphData.newBuilder()
            .setType(commonType)
            .setGraphData(commonMetricTypeProto, result)
            .setLabels(labels)
            .build());
    }

    private static SelValueVector groupMany(
            List<NamedGraphData> values,
            PositionRange sourcesRange,
            AggrSelFn.Type aggregation,
            PositionRange aggrRange,
            Set<String> keys)
    {
        var grouped = values.stream()
            .collect(
                groupingBy(
                    source -> filterLabels(source.getLabels(), keys),
                    collectingAndThen(toList(), gds -> groupSingle(gds, sourcesRange, aggregation, aggrRange))))
            .values();

        HashMap<String, MutableInt> keyCounter = new HashMap<>();
        grouped.forEach(ngd -> {
            ngd.getNamedGraphData().getLabels().forEach(label -> {
                String key = label.getKey();
                if (!keys.contains(key)) {
                    keyCounter.computeIfAbsent(label.getKey(), ignore -> new MutableInt(0)).increment();
                }
            });
        });

        Set<String> spuriousKeys = keyCounter.entrySet().stream()
                .filter(e -> e.getValue().intValue() != grouped.size())
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());

        var groupedArray = grouped.stream()
                .map(gd -> dropKeys(gd, spuriousKeys))
                .toArray(SelValue[]::new);

        return new SelValueVector(SelTypes.GRAPH_DATA, groupedArray);
    }

    private static SelValueGraphData dropKeys(SelValueGraphData gd, Set<String> spuriousKeys) {
        var ngd = gd.getNamedGraphData();
        var filteredLabels = ngd.getLabels().stream()
                .filter(label -> !spuriousKeys.contains(label.getKey()))
                .toArray(Label[]::new);
        var labels = new LabelsBuilder(LabelsBuilder.SortState.SORTED, filteredLabels).build();
        return new SelValueGraphData(ngd.toBuilder().setLabels(labels).build());
    }

    private static boolean isAbleSkip(AggrSelFn.Type type) {
        return type != AggrSelFn.Type.COUNT && type != AggrSelFn.Type.RANDOM;
    }

    private static List<NamedGraphData> toSource(SelValue value) {
        return Stream.of(value.castToVector().valueArray())
            .map(v -> v.castToGraphData().getNamedGraphData())
            .collect(Collectors.toList());
    }

    private static final EnumMap<AggrSelFn.Type, Aggregation> NATIVE_AGGREGATIONS = new EnumMap<>(Map.of(
        AggrSelFn.Type.MAX, Aggregation.MAX,
        AggrSelFn.Type.MIN, Aggregation.MIN,
        AggrSelFn.Type.AVG, Aggregation.AVG,
        AggrSelFn.Type.SUM, Aggregation.SUM,
        AggrSelFn.Type.LAST, Aggregation.LAST,
        AggrSelFn.Type.COUNT, Aggregation.COUNT
    ));


    private static PointValueCollector getPointCollector(
            MetricType metricType,
            AggrSelFn.Type aggregation,
            PositionRange range)
    {
        Aggregation maybeNativeAggregation = NATIVE_AGGREGATIONS.get(aggregation);
        var metricTypeProto = MetricTypeConverter.toNotNullProto(metricType);
        if (maybeNativeAggregation != null) {
            return PointValueCollectors.of(metricTypeProto, maybeNativeAggregation);
        }
        // auxiliary collectors, DGAUGE only
        if (metricTypeProto != ru.yandex.solomon.model.protobuf.MetricType.DGAUGE) {
            throw new EvaluationException(range, "Aggregation " + aggregation +
                " is not supported for metric type " + metricTypeProto);
        }
        return new Collector(aggregation.getFunc());
    }

    private static AggrSelFn.Type toAggrType(SelValueWithRange valueWithRange) {
        SelValue value = valueWithRange.getValue();
        String fnName = value.castToString().getValue();
        return AggrSelFn.Type
            .byName(fnName)
            .orElseThrow(() -> new EvaluationException(valueWithRange.getRange(), "Unknown aggregation function: " + fnName));
    }

    private static Set<String> toLabelKeys(SelValue value) {
        if (value.type().isString()) {
            return Set.of(value.castToString().getValue());
        }

        return Stream.of(value.castToVector().valueArray())
            .map(v -> v.castToString().getValue())
            .collect(Collectors.toSet());
    }

    private static Labels filterLabels(Labels labels, Set<String> target) {
        LabelsBuilder builder = Labels.builder(target.size());
        labels.forEach(label -> {
            if (target.contains(label.getKey())) {
                builder.add(label);
            }
        });
        return builder.build();
    }

    @Override
    public void provide(SelFuncRegistry registry) {
        String[] availableAggr = Stream.of(AggrSelFn.Type.values())
            .map(type -> type.name().toLowerCase())
            .toArray(String[]::new);

        SelFuncArgument aggregation = arg("aggregation")
                .type(SelTypes.STRING)
                .help("function used to aggregate timeseries")
                .availableValues(availableAggr)
                .build();

        SelFuncArgument source = arg("source")
                .type(SelTypes.GRAPH_DATA_VECTOR)
                .build();

        SelFuncArgument keyArg = arg("key")
                .type(SelTypes.STRING)
                .help("label key used for grouping, for example 'host'")
                .build();

        SelFuncArgument keysArg = arg("keys")
                .type(SelTypes.STRING_VECTOR)
                .help("label keys used for grouping, for example ['host', 'status']")
                .build();

        registry.add(SelFunc.newBuilder()
            .name("group_lines")
            .help("Apply aggregate function for values at the same time point over all graphs")
            .category(SelFuncCategory.DEPRECATED)
            .args(aggregation, source)
            .supportedVersions(SelVersion.GROUP_LINES_RETURN_VECTOR_2::before)
            .returnType(SelTypes.GRAPH_DATA)
            .handler(SelFnGroupLines::groupAll)
            .build());

        registry.add(SelFunc.newBuilder()
            .name("group_lines")
            .help("Apply aggregate function for values at the same time point over all graphs")
            .category(SelFuncCategory.COMBINE)
            .args(aggregation, source)
            .supportedVersions(SelVersion.GROUP_LINES_RETURN_VECTOR_2::since)
            .returnType(SelTypes.GRAPH_DATA_VECTOR)
            .handler(SelFnGroupLines::groupAllVectored)
            .build());

        for (String aggr : AGGR_FOR_ALIASES) {
            var func = AggrSelFn.Type
                    .byName(aggr)
                    .orElseThrow(() -> new InternalCompilerException("Unknown aggregation function: " + aggr));
            registry.add(SelFunc.newBuilder()
                    .name("series_" + aggr)
                    .help("Apply aggregate function " + aggr + " for values at the same time point over all graphs")
                    .category(SelFuncCategory.COMBINE)
                    .args(source)
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(args -> {
                        var sources = toSource(args.get(0));
                        PositionRange sourceRange = args.getRange(0);

                        return groupSingle(sources, sourceRange, func, args.getCallRange()).asSingleElementVector();
                    })
                    .build());
        }

        registry.add(SelFunc.newBuilder()
            .name("group_lines")
            .help("Apply aggregate function for values at the same time over graphs grouped by specified label")
            .category(SelFuncCategory.DEPRECATED)
            .args(aggregation, keyArg, source)
            .returnType(SelTypes.GRAPH_DATA_VECTOR)
            .handler(SelFnGroupLines::groupByLabels)
            .build());

        registry.add(SelFunc.newBuilder()
            .name("group_lines")
            .help("Apply aggregate function for values at the same time over graphs grouped by specified labels")
            .category(SelFuncCategory.DEPRECATED)
            .args(aggregation, keysArg, source)
            .returnType(SelTypes.GRAPH_DATA_VECTOR)
            .handler(SelFnGroupLines::groupByLabels)
            .build());

        for (String aggr : AGGR_FOR_ALIASES) {
            var func = AggrSelFn.Type
                    .byName(aggr)
                    .orElseThrow(() -> new InternalCompilerException("Unknown aggregation function: " + aggr));
            registry.add(SelFunc.newBuilder()
                    .name("series_" + aggr)
                    .help("Apply aggregate function " + aggr + "for values at the same time over graphs grouped by specified label")
                    .category(SelFuncCategory.COMBINE)
                    .args(keyArg, source)
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(args -> {
                        var keys = toLabelKeys(args.get(0));
                        var sources = toSource(args.get(1));
                        PositionRange sourcesRange = args.getRange(1);
                        return groupMany(sources, sourcesRange, func, args.getCallRange(), keys);
                    })
                    .build());

            registry.add(SelFunc.newBuilder()
                    .name("series_" + aggr)
                    .help("Apply aggregate function " + aggr + "for values at the same time over graphs grouped by specified label")
                    .category(SelFuncCategory.COMBINE)
                    .args(keysArg, source)
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(args -> {
                        var keys = toLabelKeys(args.get(0));
                        var sources = toSource(args.get(1));
                        PositionRange sourcesRange = args.getRange(1);
                        return groupMany(sources, sourcesRange, func, args.getCallRange(), keys);
                    })
                    .build());
        }
    }

    private static SelValueGraphData groupAll(ArgsList args) {
        var aggr = toAggrType(args.getWithRange(0));
        PositionRange aggrRange = args.getRange(0);
        var sources = toSource(args.get(1));
        PositionRange sourceRange = args.getRange(1);

        return groupSingle(sources, sourceRange, aggr, aggrRange);
    }

    private static SelValueVector groupAllVectored(ArgsList args) {
        var aggr = toAggrType(args.getWithRange(0));
        PositionRange aggrRange = args.getRange(0);
        var sources = toSource(args.get(1));
        PositionRange sourceRange = args.getRange(1);

        return groupSingle(sources, sourceRange, aggr, aggrRange).asSingleElementVector();
    }

    private static SelValueVector groupByLabels(ArgsList args) {
        var aggr = toAggrType(args.getWithRange(0));
        PositionRange aggrRange = args.getRange(0);
        var keys = toLabelKeys(args.get(1));
        var sources = toSource(args.get(2));
        PositionRange sourcesRange = args.getRange(2);
        return groupMany(sources, sourcesRange, aggr, aggrRange, keys);
    }

    private static class Collector implements PointValueCollector {
        private final AggrSelFn fn;
        private double[] values;
        private int size;

        public Collector(AggrSelFn fn) {
            this.fn = fn;
            this.size = 0;
            this.values = new double[100];
        }

        @Override
        public void reset() {
            size = 0;
        }

        @Override
        public void append(AggrPoint point) {
            if (size == values.length) {
                values = Arrays.copyOf(values, size * 2);
            }

            values[size++] = point.getValueDivided();
        }

        @Override
        public boolean compute(AggrPoint point) {
            point.setValue(fn.evalArrayView(new DoubleArrayView(values, 0, size)));
            return true;
        }
    }
}
