package ru.yandex.mail.so.factors.dssm;

import ru.yandex.json.dom.BasicContainerFactory;
import ru.yandex.json.dom.JsonDouble;
import ru.yandex.json.dom.JsonList;
import ru.yandex.mail.so.factors.FactorsAccessViolationHandler;
import ru.yandex.mail.so.factors.SoFactorFieldAccessorBase;
import ru.yandex.mail.so.factors.types.JsonListSoFactorType;
import ru.yandex.mail.so.factors.types.JsonObjectSoFactorType;
import ru.yandex.mail.so.factors.types.SoFactorType;

/**
 * Factor type whose values are carried as {@code float[]} arrays.
 *
 * <p>Two pseudo-fields, {@code __json_list__} and {@code __json_object__},
 * are exposed; both are served by {@link JsonAccessor}, which turns the
 * array into a {@link JsonList}.
 */
public enum DssmEmbeddingSoFactorType implements SoFactorType<float[]> {
    DSSM_EMBEDDING;

    /**
     * Narrows an arbitrary value to {@code float[]}.
     *
     * @param value candidate value, may be {@code null}
     * @return the very same array instance if {@code value} is a
     *     {@code float[]}, {@code null} otherwise
     */
    @Override
    public float[] cast(final Object value) {
        return value instanceof float[] ? (float[]) value : null;
    }

    /**
     * Resolves the accessor for one of the supported pseudo-fields.
     *
     * @param fieldName name of the requested field
     * @return a new {@link JsonAccessor} for {@code __json_list__} or
     *     {@code __json_object__}, {@code null} for any other name
     */
    @Override
    public SoFactorFieldAccessorBase fieldAccessorFor(final String fieldName) {
        // Only the declared field type differs between the two supported
        // names, so pick it first and build the accessor in one place.
        final SoFactorType<?> fieldType;
        switch (fieldName) {
            case "__json_list__":
                fieldType = JsonListSoFactorType.JSON_LIST;
                break;
            case "__json_object__":
                fieldType = JsonObjectSoFactorType.JSON_OBJECT;
                break;
            default:
                return null;
        }
        return new JsonAccessor(this, fieldType, '.' + fieldName);
    }

    /**
     * Field accessor that presents a {@code float[]} as a {@link JsonList}
     * of {@link JsonDouble} elements.
     *
     * <p>NOTE(review): the same conversion is used when the declared field
     * type is {@code JSON_OBJECT}; presumably a {@code JsonList} is accepted
     * there as well — confirm against {@code JsonObjectSoFactorType}.
     */
    public static class JsonAccessor extends SoFactorFieldAccessorBase {
        /**
         * Passes all arguments through to the base accessor unchanged.
         *
         * @param variableType type of the variable being accessed
         * @param fieldType declared type of the extracted field
         * @param stringValue textual form handed to the base class
         */
        public JsonAccessor(
            final SoFactorType<?> variableType,
            final SoFactorType<?> fieldType,
            final String stringValue)
        {
            super(variableType, fieldType, stringValue);
        }

        /**
         * Converts the supplied array into a JSON list.
         *
         * @param value expected to be a {@code float[]}
         * @param accessViolationHandler not consulted by this accessor
         * @return a fresh list holding one {@link JsonDouble} per array
         *     element in the original order, or {@code null} when
         *     {@code value} is not a {@code float[]}
         */
        @Override
        public JsonList extractField(
            final Object value,
            final FactorsAccessViolationHandler accessViolationHandler)
        {
            if (!(value instanceof float[])) {
                return null;
            }
            final float[] values = (float[]) value;
            final int size = values.length;
            // Capacity is known in advance, so the list is presized.
            final JsonList result =
                new JsonList(BasicContainerFactory.INSTANCE, size);
            for (int i = 0; i < size; ++i) {
                // Each float is widened to double by the JsonDouble argument.
                result.add(new JsonDouble(values[i]));
            }
            return result;
        }
    }
}

