package ru.yandex.direct.ydb.builder.expression;

import java.util.ArrayList;
import java.util.List;

import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.Type;

import ru.yandex.direct.ydb.builder.YqlWithParams;
import ru.yandex.direct.ydb.builder.valuecreator.ValueCreator;

/**
 * An {@link Expression} that renders a call to an aggregate function.
 *
 * <p>The call is stored as a list of {@link YqlWithParams} fragments: either
 * {@code NAME(}, the fragments of the argument expression and {@code )}, or a
 * single ready-made fragment for the argument-less {@code COUNT(*)} form.
 * Instances are created only through the static factory methods.
 *
 * @param <T> Java type associated with the aggregate result
 */
public class AggregateExpression<T> implements Expression<T> {

    private static final String MAX = "MAX";
    private static final String MIN = "MIN";
    private static final String AVG = "AVG";
    private static final String SUM = "SUM";
    private static final String COUNT = "COUNT";
    private static final String COUNT_ALL = "COUNT(*)";

    private final Type type;
    private final ValueCreator<T> valueCreator;
    private final List<YqlWithParams> yqlWithParamsList;

    /**
     * Builds {@code functionName(<argument fragments>)}.
     *
     * @param functionName name of the aggregate function, placed before the opening parenthesis
     * @param argument     expression whose fragments are placed between the parentheses
     * @param resultType   type reported by {@link #getType()}
     * @param creator      value creator reported by {@link #getValueCreator()}
     */
    private AggregateExpression(String functionName, Expression<?> argument, Type resultType,
                                ValueCreator<T> creator) {
        List<YqlWithParams> fragments = new ArrayList<>();
        fragments.add(new YqlWithParams(functionName + "("));
        fragments.addAll(argument.getYqlWithParamsList());
        fragments.add(new YqlWithParams(")"));

        this.type = resultType;
        this.valueCreator = creator;
        this.yqlWithParamsList = fragments;
    }

    /**
     * Builds an expression consisting of a single, already complete fragment.
     *
     * @param completeCall text used verbatim as the only fragment
     * @param resultType   type reported by {@link #getType()}
     * @param creator      value creator reported by {@link #getValueCreator()}
     */
    private AggregateExpression(String completeCall, Type resultType, ValueCreator<T> creator) {
        List<YqlWithParams> fragments = new ArrayList<>();
        fragments.add(new YqlWithParams(completeCall));

        this.type = resultType;
        this.valueCreator = creator;
        this.yqlWithParamsList = fragments;
    }

    /** Creates {@code functionName(argument)} declared with the uint64 type and value creator. */
    private static AggregateExpression<Long> uint64Aggregate(String functionName, Expression<?> argument) {
        return new AggregateExpression<>(functionName, argument, PrimitiveType.uint64(), PrimitiveValue::uint64);
    }

    /** Creates {@code functionName(argument)} declared with the float64 type and value creator. */
    private static AggregateExpression<Double> float64Aggregate(String functionName, Expression<?> argument) {
        return new AggregateExpression<>(functionName, argument, PrimitiveType.float64(), PrimitiveValue::float64);
    }

    /**
     * Creates {@code AVG(expression)}; the result is declared as float64.
     *
     * @param expression numeric expression to average
     * @param <T>        numeric Java type of the argument
     * @return the aggregate expression
     */
    public static <T extends Number> AggregateExpression<Double> avg(Expression<T> expression) {
        return float64Aggregate(AVG, expression);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as uint64.
     *
     * <p>NOTE(review): uint64 is declared regardless of the signedness of the argument —
     * confirm this matches the columns the method is used with. The same applies to
     * {@link #sumShort}, {@link #sumInt} and {@link #sumByte}.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Long> sumLong(Expression<Long> expression) {
        return uint64Aggregate(SUM, expression);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as uint64.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Long> sumShort(Expression<Short> expression) {
        return uint64Aggregate(SUM, expression);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as uint64.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Long> sumInt(Expression<Integer> expression) {
        return uint64Aggregate(SUM, expression);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as uint64.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Long> sumByte(Expression<Byte> expression) {
        return uint64Aggregate(SUM, expression);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as float32.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Float> sumFloat(Expression<Float> expression) {
        return new AggregateExpression<>(SUM, expression, PrimitiveType.float32(), PrimitiveValue::float32);
    }

    /**
     * Creates {@code SUM(expression)}; the result is declared as float64.
     *
     * @param expression expression to sum
     * @return the aggregate expression
     */
    public static AggregateExpression<Double> sumDouble(Expression<Double> expression) {
        return float64Aggregate(SUM, expression);
    }

    /**
     * Creates {@code MIN(expression)}, reusing the argument's type and value creator.
     *
     * @param expression expression to take the minimum of
     * @param <T>        Java type of the argument and of the result
     * @return the aggregate expression
     */
    public static <T> AggregateExpression<T> min(Expression<T> expression) {
        Type argumentType = expression.getType();
        ValueCreator<T> argumentCreator = expression.getValueCreator();
        return new AggregateExpression<>(MIN, expression, argumentType, argumentCreator);
    }

    /**
     * Creates {@code MAX(expression)}, reusing the argument's type and value creator.
     *
     * @param expression expression to take the maximum of
     * @param <T>        Java type of the argument and of the result
     * @return the aggregate expression
     */
    public static <T> AggregateExpression<T> max(Expression<T> expression) {
        Type argumentType = expression.getType();
        ValueCreator<T> argumentCreator = expression.getValueCreator();
        return new AggregateExpression<>(MAX, expression, argumentType, argumentCreator);
    }

    /**
     * Creates {@code COUNT(expression)}; the result is declared as uint64.
     *
     * @param expression expression to count
     * @param <T>        Java type of the argument
     * @return the aggregate expression
     */
    public static <T> AggregateExpression<Long> count(Expression<T> expression) {
        return uint64Aggregate(COUNT, expression);
    }

    /**
     * Creates the argument-less {@code COUNT(*)}; the result is declared as uint64.
     *
     * @return the aggregate expression
     */
    public static AggregateExpression<Long> count() {
        return new AggregateExpression<>(COUNT_ALL, PrimitiveType.uint64(), PrimitiveValue::uint64);
    }

    /** Returns the type this aggregate was declared with. */
    @Override
    public Type getType() {
        return type;
    }

    /** Returns the value creator this aggregate was declared with. */
    @Override
    public ValueCreator<T> getValueCreator() {
        return valueCreator;
    }

    /** Returns the fragments of the call; this is the internal list itself, not a copy. */
    @Override
    public List<YqlWithParams> getYqlWithParamsList() {
        return yqlWithParamsList;
    }
}
