package ru.yandex.solomon.expression.expr.func.analytical.rank;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Stream;

import com.google.common.collect.ImmutableSet;

import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.solomon.expression.exceptions.EvaluationException;
import ru.yandex.solomon.expression.exceptions.InternalCompilerException;
import ru.yandex.solomon.expression.expr.func.AggrSelFn;
import ru.yandex.solomon.expression.expr.func.SelFunc;
import ru.yandex.solomon.expression.expr.func.SelFuncCategory;
import ru.yandex.solomon.expression.expr.func.SelFuncProvider;
import ru.yandex.solomon.expression.expr.func.SelFuncRegistry;
import ru.yandex.solomon.expression.type.SelTypes;
import ru.yandex.solomon.expression.value.ArgsList;
import ru.yandex.solomon.expression.value.SelValue;
import ru.yandex.solomon.expression.value.SelValueVector;
import ru.yandex.solomon.expression.value.SelValueWithRange;

import static ru.yandex.solomon.expression.expr.func.SelFuncArgument.arg;

/**
 * @author Vladimir Gordiychuk
 */
public class SelFnRank implements SelFuncProvider {
    private static final List<String> AGGR_FOR_ALIAS = List.of(
            "max", "min", "avg", "sum", "count", "last"
    );

    private static final Set<String> ALIASES;

    static {
        HashSet<String> aliases = new HashSet<>();
        aliases.add("top");
        aliases.add("bottom");
        for (String aggr : AGGR_FOR_ALIAS) {
            aliases.add("top_" + aggr);
            aliases.add("bottom_" + aggr);
        }
        ALIASES = ImmutableSet.copyOf(aliases);
    }

    public static boolean matches(String funcName) {
        return ALIASES.contains(funcName);
    }

    public static SelValue calculate(RankType type, int limit, AggrSelFn aggrFn, SelValueVector vector) {
        if (vector.length() <= limit) {
            return vector;
        }

        List<Tuple2<SelValue, Double>> rankedList = new ArrayList<>(vector.length());
        for (SelValue item : vector.valueArray()) {
            double value = execAggr(aggrFn, item);
            rankedList.add(Tuple2.tuple(item, value));
        }
        SelValue[] result = rankedList.stream()
                .sorted(Comparator.comparing(Tuple2::get2, type.getComparator()))
                .map(Tuple2::get1)
                .limit(limit)
                .toArray(SelValue[]::new);

        return new SelValueVector(SelTypes.GRAPH_DATA, result);
    }

    public static SelValue calculate(RankType type, ArgsList params) {
        int limit = (int) params.get(0).castToScalar().getValue();
        AggrSelFn aggrFn = aggrFunction(params.getWithRange(1));
        SelValueVector vector = params.get(2).castToVector();

        return calculate(type, limit, aggrFn, vector);
    }

    public static SelValue calculate(RankType type, AggrSelFn aggrFn, ArgsList params) {
        int limit = (int) params.get(0).castToScalar().getValue();
        SelValueVector vector = params.get(1).castToVector();

        return calculate(type, limit, aggrFn, vector);
    }

    private static double execAggr(AggrSelFn aggrFn, SelValue arg) {
        return aggrFn.evalGraphData(arg.castToGraphData().getNamedGraphData());
    }

    private static AggrSelFn aggrFunction(SelValueWithRange function) {
        String functionName = function.getValue().castToString().getValue();
        return AggrSelFn.Type.byName(functionName)
            .map(AggrSelFn.Type::getFunc)
            .orElseThrow(() -> new EvaluationException(function.getRange(), "Aggregate function " + functionName + " does not exist"));
    }

    @Override
    public void provide(SelFuncRegistry registry) {
        String[] aggregations = Stream.of(AggrSelFn.Type.values())
            .map(type -> type.name().toLowerCase())
            .toArray(String[]::new);

        registry.add(SelFunc.newBuilder()
            .name("top")
            .help("Take limited number of timeseries with higher aggregation value")
            .args(
                arg("limit").type(SelTypes.DOUBLE).help("limit of time series"),
                arg("aggregation").type(SelTypes.STRING).help("function uses to calculate aggregated value for sort").availableValues(aggregations),
                arg("source").type(SelTypes.GRAPH_DATA_VECTOR))
            .category(SelFuncCategory.DEPRECATED)
            .returnType(SelTypes.GRAPH_DATA_VECTOR)
            .handler(args -> calculate(RankType.TOP, args))
            .build());

        registry.add(SelFunc.newBuilder()
            .name("bottom")
            .help("Take limited number of timeseries with lower aggregation value")
            .args(
                arg("limit").type(SelTypes.DOUBLE).help("limit of time series"),
                arg("aggregation").type(SelTypes.STRING).help("function uses to calculate aggregated value for sort").availableValues(aggregations),
                arg("source").type(SelTypes.GRAPH_DATA_VECTOR))
            .category(SelFuncCategory.DEPRECATED)
            .returnType(SelTypes.GRAPH_DATA_VECTOR)
            .handler(args -> calculate(RankType.BOTTOM, args))
            .build());

        for (String aggr : AGGR_FOR_ALIAS) {
            var func = AggrSelFn.Type.byName(aggr)
                    .map(AggrSelFn.Type::getFunc)
                    .orElseThrow(() -> new InternalCompilerException("Aggregate function " + aggr + " does not exist"));

            registry.add(SelFunc.newBuilder()
                    .name("top_" + aggr)
                    .help("Take limited number of timeseries with higher " + aggr + " value")
                    .args(
                            arg("limit").type(SelTypes.DOUBLE).help("limit of time series"),
                            arg("source").type(SelTypes.GRAPH_DATA_VECTOR))
                    .category(SelFuncCategory.RANK)
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(args -> calculate(RankType.TOP, func, args))
                    .build());

            registry.add(SelFunc.newBuilder()
                    .name("bottom_" + aggr)
                    .help("Take limited number of timeseries with lower " + aggr + " value")
                    .category(SelFuncCategory.RANK)
                    .args(
                            arg("limit").type(SelTypes.DOUBLE).help("limit of time series"),
                            arg("source").type(SelTypes.GRAPH_DATA_VECTOR))
                    .returnType(SelTypes.GRAPH_DATA_VECTOR)
                    .handler(args -> calculate(RankType.BOTTOM, func, args))
                    .build());
        }
    }
}
