package ru.yandex.antifraud.lua_context_manager;

import java.util.Map;

import javax.annotation.Nonnull;

import core.org.luaj.vm2.LuaDouble;
import core.org.luaj.vm2.LuaError;
import core.org.luaj.vm2.LuaTable;
import core.org.luaj.vm2.LuaValue;
import core.org.luaj.vm2.lib.ThreeArgFunction;

import ru.yandex.jni.catboost.JniCatboostException;
import ru.yandex.jni.catboost.JniCatboostFeatures;
import ru.yandex.jni.catboost.JniCatboostModel;

/**
 * Singleton that makes CatBoost model evaluation callable from Lua scripts.
 */
public enum CatboostTuner {
    INSTANCE;

    /**
     * Installs the {@code calcModel} function into the given Lua context.
     *
     * @param context Lua value that receives the {@code calcModel} entry
     * @param models  models that {@code calcModel} looks up by the name passed from Lua
     */
    public void tuneContext(LuaValue context, @Nonnull Map<String, JniCatboostModel> models) {
        final CalcFunction calcModel = new CalcFunction(models);
        context.set("calcModel", calcModel);
    }

    /**
     * Lua-callable function {@code calcModel(modelName, numericFeatures, categoricalFeatures)}.
     * Both feature arguments are tables mapping a feature name to its value.
     */
    private static class CalcFunction extends ThreeArgFunction {
        @Nonnull
        private final Map<String, JniCatboostModel> models;

        private CalcFunction(@Nonnull Map<String, JniCatboostModel> models) {
            this.models = models;
        }

        /**
         * Evaluates the named model on the supplied features.
         *
         * @param modelName           name of the model to look up
         * @param numericFeatures     Lua table of feature name to numeric value
         * @param categoricalFeatures Lua table of feature name to string value
         * @return a new Lua table holding the values returned by the model, indexed from 1
         * @throws LuaError if the model name is not registered, or if feature building or
         *                  model evaluation raises a {@link JniCatboostException}
         */
        @Override
        public LuaValue call(LuaValue modelName, LuaValue numericFeatures, LuaValue categoricalFeatures) {
            // All three arguments are type-checked before the lookup result is examined,
            // so a malformed argument is reported ahead of an unknown model name.
            final String requestedName = modelName.checkjstring();
            final JniCatboostModel model = models.get(requestedName);
            final LuaTable numericTable = numericFeatures.checktable();
            final LuaTable categoricalTable = categoricalFeatures.checktable();

            if (model == null) {
                final String knownModels = String.join(",", models.keySet());
                throw new LuaError("there is no model " + modelName + " where models: " + knownModels);
            }

            try {
                final JniCatboostFeatures features = buildFeatures(numericTable, categoricalTable);
                return toLuaArray(model.calc(features));
            } catch (JniCatboostException e) {
                throw new LuaError(e);
            }
        }

        /** Copies every entry of both Lua tables into a fresh feature set, numeric entries first. */
        private static JniCatboostFeatures buildFeatures(LuaTable numericTable, LuaTable categoricalTable)
                throws JniCatboostException {
            final JniCatboostFeatures features = new JniCatboostFeatures();

            for (LuaValue name : numericTable.keys()) {
                // NOTE(review): tofloat() does not type-check the value; confirm that coercing
                // non-numeric entries (instead of rejecting them) is intended.
                final LuaValue value = numericTable.get(name);
                features.setNumericFeature(name.checkjstring(), value.tofloat());
            }

            for (LuaValue name : categoricalTable.keys()) {
                final LuaValue value = categoricalTable.get(name);
                features.setCategoricalFeature(name.checkjstring(), value.checkjstring());
            }

            return features;
        }

        /** Wraps the model output in a Lua table whose indices start at 1. */
        private static LuaValue toLuaArray(double[] scores) {
            final LuaValue luaScores = new LuaTable();

            int luaIndex = 1;
            for (double score : scores) {
                luaScores.set(luaIndex, LuaDouble.valueOf(score));
                luaIndex++;
            }

            return luaScores;
        }
    }
}
