package ru.yandex.solomon.expression.expr.func;

import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;

import ru.yandex.solomon.expression.PositionRange;
import ru.yandex.solomon.expression.ast.AstIdent;
import ru.yandex.solomon.expression.exceptions.CompilerException;
import ru.yandex.solomon.expression.exceptions.InternalCompilerException;
import ru.yandex.solomon.expression.type.SelType;
import ru.yandex.solomon.expression.version.SelVersion;

/**
 * Registry of expression-language functions, kept separately for every {@link SelVersion}.
 *
 * <p>For each version the registry holds a table {@code name -> argument types -> function},
 * so one function name may have several overloads that differ by their argument types.
 * A function is registered only in the versions accepted by {@link SelFunc#getSupportedVersions()}.
 *
 * <p>NOTE(review): the backing tables are plain Guava {@link HashBasedTable}s with no
 * synchronization; presumably the registry is filled once and then only read — confirm
 * before sharing an instance across threads while it is still being populated.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class SelFuncRegistry {
    /** One overload table per language version; every {@link SelVersion} constant has an entry. */
    private final EnumMap<SelVersion, Table<String, List<SelType>, SelFunc>> registered;

    /**
     * Creates a registry with an empty overload table for every {@link SelVersion}.
     */
    public SelFuncRegistry() {
        // EnumMap's Map-copy constructor rejects an empty non-EnumMap source, so this relies on
        // SelVersion declaring at least one constant.
        this.registered = new EnumMap<>(Arrays.stream(SelVersion.values())
                .collect(Collectors.toMap(Function.identity(), ignore -> HashBasedTable.create())));
    }

    /**
     * Builds the function described by {@code builder} and registers it like {@link #add(SelFunc)}.
     */
    public void add(SelFunc.Builder builder) {
        add(builder.build());
    }

    /**
     * Registers {@code func} in every version accepted by {@link SelFunc#getSupportedVersions()}.
     *
     * <p>Registration is all-or-nothing: every target version is checked for a clashing overload
     * (same name and argument types) before any table is modified, so a duplicate never leaves
     * the function registered in only some of its versions.
     *
     * @param func function to register
     * @throws InternalCompilerException if an overload with the same name and argument types is
     *         already registered in any of the target versions
     */
    public void add(SelFunc func) {
        List<Table<String, List<SelType>, SelFunc>> targets = Arrays.stream(SelVersion.values())
                .filter(func.getSupportedVersions())
                .map(registered::get)
                .collect(Collectors.toList());

        // Validate every version first: a clash in a later version must not leave earlier ones mutated.
        for (var versionedRegistry : targets) {
            ensureNotRegistered(versionedRegistry, func);
        }
        for (var versionedRegistry : targets) {
            add(versionedRegistry, func);
        }
    }

    /**
     * Puts {@code func} into a single version's overload table.
     *
     * @param versionedRegistry overload table ({@code name -> argument types -> function}) to modify
     * @param func function to register
     * @throws InternalCompilerException if the table already holds an overload with the same name
     *         and argument types
     */
    public void add(Table<String, List<SelType>, SelFunc> versionedRegistry, SelFunc func) {
        ensureNotRegistered(versionedRegistry, func);
        versionedRegistry.put(func.getName(), func.getArgsType(), func);
    }

    /**
     * Registers each of {@code funcs} in order via {@link #add(SelFunc)}.
     *
     * <p>Each function is registered atomically, but the batch is not: if one function clashes,
     * the functions before it stay registered.
     */
    public void add(SelFunc... funcs) {
        for (var func : funcs) {
            add(func);
        }
    }

    /**
     * Lets {@code provider} register its functions into this registry.
     */
    public void add(SelFuncProvider provider) {
        provider.provide(this);
    }

    /**
     * Checks that at least one overload named like {@code ident} exists for {@code version}.
     *
     * @throws CompilerException positioned at {@code ident} if no function with that name is known
     */
    public void ensureHasFunction(SelVersion version, AstIdent ident) {
        String name = ident.getIdent();
        if (!registered.get(version).containsRow(name)) {
            throw new CompilerException(ident.getRange(), "Unknown function: " + name);
        }
    }

    /**
     * Resolves the overload of function {@code ident} that accepts {@code args} in {@code version}.
     *
     * @param version language version whose table is searched
     * @param ident function name together with its source position (used in error messages)
     * @param args actual argument types of the call
     * @return the matching overload
     * @throws CompilerException if the name is unknown or no overload accepts {@code args}
     */
    public SelFunc get(SelVersion version, AstIdent ident, List<SelType> args) {
        String name = ident.getIdent();
        var overloads = registered.get(version).row(name);
        if (overloads.isEmpty()) {
            throw new CompilerException(ident.getRange(), "Unknown function: " + name);
        }

        return resolveFunction(ident, overloads, args);
    }

    /**
     * Test convenience for {@link #get(SelVersion, AstIdent, List)} with an unknown source position.
     */
    @VisibleForTesting
    public SelFunc get(SelVersion version, String name, List<SelType> args) {
        return get(version, new AstIdent(PositionRange.UNKNOWN, name), args);
    }

    /**
     * Throws if {@code versionedRegistry} already holds an overload with {@code func}'s name and
     * argument types.
     */
    private static void ensureNotRegistered(Table<String, List<SelType>, SelFunc> versionedRegistry, SelFunc func) {
        var prev = versionedRegistry.get(func.getName(), func.getArgsType());
        if (prev != null) {
            throw new InternalCompilerException(prev + " already registered");
        }
    }

    /**
     * Picks the overload matching {@code args}.
     *
     * <p>An exact match on the argument-type list wins. Otherwise, when every argument has the same
     * type {@code T} (and there is at least one argument), a single-argument overload {@code (T)}
     * is accepted only if it is declared vararg.
     *
     * @throws CompilerException listing the available overloads if nothing matches
     */
    private static SelFunc resolveFunction(AstIdent ident, Map<List<SelType>, SelFunc> overloads, List<SelType> args) {
        String name = ident.getIdent();
        var fn = overloads.get(args);
        // Exactly one distinct type means args is non-empty and homogeneous: try a vararg overload.
        if (fn == null && Set.copyOf(args).size() == 1) {
            var candidate = overloads.get(List.of(args.get(0)));
            if (candidate != null && candidate.isVarArg()) {
                fn = candidate;
            }
        }

        if (fn == null) {
            String was = name + "(" + Joiner.on(", ").join(args) + ")";
            String expected = Joiner.on(", \n").join(overloads.values());
            String message = "Not valid function arguments, was \n" + was + ",\n but expected one of:\n" + expected;
            throw new CompilerException(ident.getRange(), message);
        }

        return fn;
    }

    /**
     * Returns every function registered for {@code version}, across all names and overloads.
     */
    public Stream<SelFunc> stream(SelVersion version) {
        return registered.get(version).values().stream();
    }
}
