package ru.yandex.mail.so.factors.fasttext;

import java.io.File;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;

import org.apache.http.concurrent.FutureCallback;

import ru.yandex.jni.fasttext.JniFastText;
import ru.yandex.jni.fasttext.JniFastTextException;
import ru.yandex.mail.so.factors.SoFactor;
import ru.yandex.mail.so.factors.SoFunctionInputs;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractor;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorContext;
import ru.yandex.mail.so.factors.extractors.SoFactorsExtractorsRegistry;
import ru.yandex.mail.so.factors.types.SoFactorType;
import ru.yandex.mail.so.factors.types.StringSoFactorType;
import ru.yandex.parser.config.ConfigException;
import ru.yandex.parser.config.IniConfig;

/**
 * Factors extractor which takes a single string input and produces a single
 * fastText embedding factor computed by a native {@link JniFastText} model.
 *
 * <p>The model is loaded once in the constructor and released in
 * {@link #close()}.
 */
public class FastTextEmbeddingExtractor implements SoFactorsExtractor {
    // Declared input signature: exactly one string factor.
    private static final List<SoFactorType<?>> INPUT_TYPES =
        Collections.singletonList(StringSoFactorType.STRING);
    // Declared output signature: exactly one embedding factor.
    private static final List<SoFactorType<?>> OUTPUT_TYPES =
        Collections.singletonList(
            FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING);

    private final JniFastText model;
    // Read from the model once at construction and attached to every
    // embedding produced by this extractor.
    private final int embeddingDimension;

    /**
     * Loads the fastText model described by the given configuration.
     *
     * @param name extractor name, used only in the error message
     * @param config configuration section to read the model file and the
     *     optional stop word list file from
     * @throws ConfigException if the configuration cannot be built or the
     *     native model fails to load
     */
    public FastTextEmbeddingExtractor(
        final String name,
        final IniConfig config)
        throws ConfigException
    {
        ImmutableFastTextExtractorConfig extractorConfig =
            new FastTextExtractorConfigBuilder(
                new FastTextExtractorConfigBuilder(config).build())
                .build();
        // The stop word list is optional, a missing file is passed to the
        // model as null.
        Path stopWordsPath = absolutePathOrNull(extractorConfig.stopWordList());
        try {
            Path modelPath =
                extractorConfig.model().getAbsoluteFile().toPath();
            model = new JniFastText(modelPath, stopWordsPath);
        } catch (JniFastTextException e) {
            throw new ConfigException(
                "Failed to construct fast text for <" + name + '>',
                e);
        }
        embeddingDimension = model.getDimension();
    }

    /**
     * Converts a possibly absent file into its absolute path.
     *
     * @param file file to convert, may be {@code null}
     * @return absolute path of the file, or {@code null} if no file is given
     */
    private static Path absolutePathOrNull(final File file) {
        return file == null ? null : file.getAbsoluteFile().toPath();
    }

    /**
     * Closes the underlying fastText model.
     */
    @Override
    public void close() {
        model.close();
    }

    /**
     * @return single-element list containing the string factor type
     */
    @Override
    public List<SoFactorType<?>> inputs() {
        return INPUT_TYPES;
    }

    /**
     * @return single-element list containing the fastText embedding factor
     *     type
     */
    @Override
    public List<SoFactorType<?>> outputs() {
        return OUTPUT_TYPES;
    }

    /**
     * Computes the embedding of the first input string and hands it to the
     * callback.
     *
     * <p>If the input string is {@code null} or empty, the callback is
     * completed with {@code NULL_RESULT} and the model is not invoked. If
     * the model throws {@link JniFastTextException}, the callback is failed
     * with that exception.
     *
     * @param context extraction context, not used by this extractor
     * @param inputs function inputs, the string is read from position 0
     * @param callback receiver of the single-element result list or of the
     *     failure
     */
    @Override
    public void extract(
        final SoFactorsExtractorContext context,
        final SoFunctionInputs inputs,
        final FutureCallback<? super List<SoFactor<?>>> callback)
    {
        String document = inputs.get(0, StringSoFactorType.STRING);
        boolean hasText = document != null && !document.isEmpty();
        if (!hasText) {
            callback.completed(NULL_RESULT);
            return;
        }

        final float[] vector;
        try {
            vector = model.createDoc(document);
        } catch (JniFastTextException e) {
            callback.failed(e);
            return;
        }

        // Completion is deliberately kept outside of the try block, so
        // that only model failures are routed to callback.failed().
        FastTextEmbedding embedding =
            new FastTextEmbedding(embeddingDimension, vector);
        List<SoFactor<?>> result =
            Collections.singletonList(
                FastTextEmbeddingSoFactorType.FAST_TEXT_EMBEDDING.createFactor(
                    embedding));
        callback.completed(result);
    }

    /**
     * Delegates registration of internal types to the extractor factory
     * singleton.
     *
     * @param registry registry to register internals in
     * @throws ConfigException if the factory fails to register its internals
     */
    @Override
    public void registerInternals(final SoFactorsExtractorsRegistry registry)
        throws ConfigException
    {
        FastTextEmbeddingExtractorFactory.INSTANCE.registerInternals(registry);
    }
}

