package ru.yandex.mail.so.factors.hnsw;

import com.github.jelmerk.knn.DistanceFunction;

import ru.yandex.jni.fasttext.JniFastText;
import ru.yandex.jni.fasttext.JniFastTextException;
import ru.yandex.mail.so.factors.fasttext.FastTextEmbedding;

/**
 * WMD-based {@link DistanceFunction}s over {@link FastTextEmbedding}s.
 *
 * <p>Each constant delegates to a native routine in {@code JniFastText}:
 * {@link #RELAXED} calls {@code JniFastText.relaxedWmd} and {@link #GREED} calls
 * {@code JniFastText.greedWmd}. The dimension guard and the failure fallback are shared and
 * implemented once in {@link #distance(FastTextEmbedding, FastTextEmbedding)}.
 */
public enum WmdDistance implements DistanceFunction<FastTextEmbedding, Float> {
    /** Distance computed by the native {@code JniFastText.relaxedWmd}. */
    RELAXED {
        @Override
        float nativeDistance(FastTextEmbedding u, FastTextEmbedding v) throws JniFastTextException {
            return JniFastText.relaxedWmd(u.embedding(), v.embedding(), u.dimension());
        }
    },
    /** Distance computed by the native {@code JniFastText.greedWmd}. */
    GREED {
        @Override
        float nativeDistance(FastTextEmbedding u, FastTextEmbedding v) throws JniFastTextException {
            return JniFastText.greedWmd(u.embedding(), v.embedding(), u.dimension());
        }
    };

    /**
     * Distance reported when the native distance cannot be computed: either the embeddings
     * have different dimensions, or the native call threw {@link JniFastTextException}.
     *
     * <p>NOTE(review): confirm that 1 is an appropriate "far" value for the range of the
     * native WMD results.
     */
    private static final float FALLBACK_DISTANCE = 1f;

    /**
     * Computes the distance between two embeddings using this constant's native routine.
     *
     * @param u first embedding
     * @param v second embedding
     * @return the native distance, or {@code 1f} if the dimensions of {@code u} and {@code v}
     *     differ or the native call throws {@link JniFastTextException}
     */
    @Override
    public final Float distance(FastTextEmbedding u, FastTextEmbedding v) {
        if (u.dimension() != v.dimension()) {
            // The native routines receive a single dimension, so mismatched inputs are not
            // passed to them at all.
            return FALLBACK_DISTANCE;
        }
        try {
            return nativeDistance(u, v);
        } catch (JniFastTextException ignored) {
            // Best-effort: a native failure is mapped to the fallback distance instead of being
            // propagated (presumably because DistanceFunction.distance cannot declare this
            // exception).
            return FALLBACK_DISTANCE;
        }
    }

    /**
     * Invokes the constant-specific native WMD routine on two embeddings.
     *
     * <p>Called only after {@code u.dimension() == v.dimension()} has been checked.
     *
     * @param u first embedding
     * @param v second embedding
     * @return the raw native distance
     * @throws JniFastTextException if the native routine fails
     */
    abstract float nativeDistance(FastTextEmbedding u, FastTextEmbedding v)
        throws JniFastTextException;
}

