package ru.yandex.mail.so.factors.fasttext;

import ru.yandex.json.dom.BasicContainerFactory;
import ru.yandex.json.dom.JsonDouble;
import ru.yandex.json.dom.JsonList;
import ru.yandex.json.dom.JsonLong;
import ru.yandex.json.dom.JsonMap;
import ru.yandex.mail.so.factors.FactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.SoFactorFieldAccessorBase;
import ru.yandex.mail.so.factors.types.JsonMapSoFactorType;
import ru.yandex.mail.so.factors.types.JsonObjectSoFactorType;
import ru.yandex.mail.so.factors.types.SoFactorType;

/**
 * Factor type descriptor for {@link FastTextEmbedding} values.
 *
 * <p>Implemented as a single-constant enum, so
 * {@link #FAST_TEXT_EMBEDDING} is the only instance. Besides casting, the
 * type exposes two pseudo-fields, {@code __json_map__} and
 * {@code __json_object__}, both of which render the embedding as a JSON map
 * (see {@link JsonAccessor}).
 */
public enum FastTextEmbeddingSoFactorType
    implements SoFactorType<FastTextEmbedding>
{
    FAST_TEXT_EMBEDDING;

    /**
     * Narrows an arbitrary object to {@link FastTextEmbedding}.
     *
     * @param value candidate object, may be {@code null}
     * @return {@code value} itself when it is a {@link FastTextEmbedding},
     *     {@code null} otherwise (including for {@code null} input)
     */
    @Override
    public FastTextEmbedding cast(final Object value) {
        return value instanceof FastTextEmbedding
            ? (FastTextEmbedding) value
            : null;
    }

    /**
     * Looks up the accessor for one of the supported pseudo-fields.
     *
     * <p>The accessor's string value is the field name prefixed with a dot.
     *
     * @param fieldName requested field name; must not be {@code null},
     *     since switching on a {@code null} string throws
     *     {@link NullPointerException}
     * @return an accessor typed as JSON map for {@code __json_map__}, an
     *     accessor typed as JSON object for {@code __json_object__}, or
     *     {@code null} for any other name
     */
    @Override
    public SoFactorFieldAccessorBase fieldAccessorFor(final String fieldName) {
        // Only the declared field type differs between the two supported
        // names; the extraction logic is shared.
        final SoFactorType<?> fieldType;
        switch (fieldName) {
            case "__json_map__":
                fieldType = JsonMapSoFactorType.JSON_MAP;
                break;
            case "__json_object__":
                fieldType = JsonObjectSoFactorType.JSON_OBJECT;
                break;
            default:
                return null;
        }
        return new JsonAccessor(this, fieldType, '.' + fieldName);
    }

    /**
     * Field accessor that converts a {@link FastTextEmbedding} into a
     * {@link JsonMap} with two entries: {@code dimension} and
     * {@code embedding}.
     */
    public static class JsonAccessor extends SoFactorFieldAccessorBase {
        /**
         * Passes all arguments through to the base accessor unchanged.
         *
         * @param variableType type of the value the field is read from
         * @param fieldType type reported for the extracted field
         * @param stringValue textual form of this accessor
         */
        public JsonAccessor(
            final SoFactorType<?> variableType,
            final SoFactorType<?> fieldType,
            final String stringValue)
        {
            super(variableType, fieldType, stringValue);
        }

        /**
         * Builds the JSON representation of an embedding.
         *
         * @param value object to read from
         * @param accessViolationHandler not consulted by this implementation
         * @return a map holding {@code dimension} (as {@link JsonLong}) and
         *     {@code embedding} (a {@link JsonList} with one
         *     {@link JsonDouble} per component, in array order), or
         *     {@code null} when {@code value} is not a
         *     {@link FastTextEmbedding}
         */
        @Override
        public JsonMap extractField(
            final Object value,
            final FactorsAccessViolationHandler accessViolationHandler)
        {
            if (!(value instanceof FastTextEmbedding)) {
                // NOTE(review): the violation handler is not notified here;
                // confirm that silently returning null is the intended
                // contract for a type mismatch.
                return null;
            }

            FastTextEmbedding source = (FastTextEmbedding) value;
            float[] components = source.embedding();

            // Second constructor argument is presumably a capacity hint --
            // TODO confirm against JsonList.
            JsonList values =
                new JsonList(BasicContainerFactory.INSTANCE, components.length);
            for (int i = 0; i < components.length; ++i) {
                values.add(new JsonDouble(components[i]));
            }

            // Size argument is 3 although only two entries are stored.
            JsonMap result = new JsonMap(BasicContainerFactory.INSTANCE, 3);
            result.put("dimension", new JsonLong(source.dimension()));
            result.put("embedding", values);
            return result;
        }
    }
}

