package ru.yandex.direct.ydb.builder.valuecreator;

import java.time.Duration;
import java.time.Instant;
import java.util.EnumMap;
import java.util.Map;
import java.util.function.Function;

import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.OptionalType;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.Type;
import com.yandex.ydb.table.values.Value;

import static com.yandex.ydb.table.values.PrimitiveType.bool;
import static com.yandex.ydb.table.values.PrimitiveType.datetime;
import static com.yandex.ydb.table.values.PrimitiveType.int16;
import static com.yandex.ydb.table.values.PrimitiveType.int32;
import static com.yandex.ydb.table.values.PrimitiveType.int64;
import static com.yandex.ydb.table.values.PrimitiveType.int8;
import static com.yandex.ydb.table.values.PrimitiveType.interval;
import static com.yandex.ydb.table.values.PrimitiveType.json;
import static com.yandex.ydb.table.values.PrimitiveType.jsonDocument;
import static com.yandex.ydb.table.values.PrimitiveType.string;
import static com.yandex.ydb.table.values.PrimitiveType.uint16;
import static com.yandex.ydb.table.values.PrimitiveType.uint32;
import static com.yandex.ydb.table.values.PrimitiveType.uint64;
import static com.yandex.ydb.table.values.PrimitiveType.uint8;
import static com.yandex.ydb.table.values.PrimitiveType.utf8;
import static com.yandex.ydb.table.values.Type.Kind.LIST;
import static com.yandex.ydb.table.values.Type.Kind.OPTIONAL;

/**
 * Resolves YDB {@link Type}s to the {@link ValueCreator}s that build YDB {@link Value}s for them.
 *
 * <p>Supported kinds are {@code PRIMITIVE} (only for the primitive ids registered in the
 * primitive creator table below), {@code LIST} and {@code OPTIONAL}. Any other kind, or a
 * primitive id with no registered creator, is rejected with an {@link IllegalStateException}.
 */
public class TypeValueMapper {

    /** Creators for the supported primitive ids; value types differ per id, hence the wildcard. */
    private static final Map<PrimitiveType.Id, PrimitiveValueCreator<?>> primitiveValueCreatorMap =
            getPrimitiveValueMap();
    /** Factories building a creator for a concrete container type, keyed by its kind. */
    private static final Map<Type.Kind, Function<Type, ValueCreator>> complexTypesValueCreatorMap =
            getComplexTypeCreator();

    /**
     * Returns a value creator for {@code type}.
     *
     * <p>Primitive types are served from the shared primitive table. {@code LIST} and
     * {@code OPTIONAL} types get a new {@link ListValueCreator} / {@link OptionalValueCreator}
     * built for the given type.
     *
     * @param type YDB type to create values for
     * @return creator for {@code type}, never {@code null}
     * @throws IllegalStateException if the type kind, or the primitive id, is not supported
     */
    public static ValueCreator getValueCreator(Type type) {
        switch (type.getKind()) {
            case PRIMITIVE:
                return findPrimitiveCreator((PrimitiveType) type);
            case LIST:
            case OPTIONAL:
                return complexTypesValueCreatorMap.get(type.getKind()).apply(type);
            default:
                throw new IllegalStateException("Type " + type + " not yet supported");
        }
    }

    /**
     * Returns the primitive value creator registered for {@code type}.
     *
     * @param type YDB type; must be of kind {@code PRIMITIVE}
     * @param <T>  Java value type accepted by the creator; the caller must request the type that
     *             matches {@code type}, because this cannot be checked at runtime
     * @return creator for {@code type}, never {@code null}
     * @throws IllegalStateException if {@code type} is not primitive or its id is not supported
     */
    public static <T> PrimitiveValueCreator<T> getPrimitiveCreator(Type type) {
        if (type.getKind() != Type.Kind.PRIMITIVE) {
            throw new IllegalStateException("Type " + type + " is not primitive");
        }
        // The table holds creators for heterogeneous Java types, so the element type is erased;
        // trusting the caller's T is the same contract the original raw-typed lookup had.
        @SuppressWarnings("unchecked")
        PrimitiveValueCreator<T> creator = (PrimitiveValueCreator<T>) findPrimitiveCreator((PrimitiveType) type);
        return creator;
    }

    /**
     * Looks up the creator registered for a primitive type.
     *
     * @throws IllegalStateException if no creator is registered for the type's id (instead of
     *                               returning {@code null} and failing later with an NPE)
     */
    private static PrimitiveValueCreator<?> findPrimitiveCreator(PrimitiveType type) {
        PrimitiveValueCreator<?> creator = primitiveValueCreatorMap.get(type.getId());
        if (creator == null) {
            throw new IllegalStateException("Type " + type + " not yet supported");
        }
        return creator;
    }

    /** Builds the immutable kind-to-factory table for container types. */
    private static Map<Type.Kind, Function<Type, ValueCreator>> getComplexTypeCreator() {
        return Map.of(
                LIST, type -> new ListValueCreator((ListType) type),
                OPTIONAL, type -> new OptionalValueCreator((OptionalType) type)
        );
    }

    /**
     * Builds the table of creators for every supported primitive id.
     *
     * <p>Each creator maps a {@code null} input to an empty optional of its primitive type (see
     * {@link #getOptionalValueCreator}). Explicit type arguments are kept where the
     * {@code PrimitiveValue} factory reference needs disambiguation.
     */
    private static Map<PrimitiveType.Id, PrimitiveValueCreator<?>> getPrimitiveValueMap() {
        Map<PrimitiveType.Id, PrimitiveValueCreator<?>> map = new EnumMap<>(PrimitiveType.Id.class);
        map.put(bool().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(bool(), PrimitiveValue::bool)));
        map.put(jsonDocument().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(jsonDocument(),
                PrimitiveValue::jsonDocument)));
        map.put(json().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(json(),
                PrimitiveValue::json)));
        map.put(string().getId(), new PrimitiveValueCreator<byte[]>(getOptionalValueCreator(string(),
                PrimitiveValue::string)));
        map.put(utf8().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(utf8(), PrimitiveValue::utf8)));
        map.put(uint8().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(uint8(), PrimitiveValue::uint8)));
        map.put(int8().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(int8(), PrimitiveValue::int8)));
        map.put(uint16().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(uint16(),
                PrimitiveValue::uint16)));
        map.put(int16().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(int16(),
                PrimitiveValue::int16)));
        map.put(uint32().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(uint32(),
                PrimitiveValue::uint32)));
        map.put(int32().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(int32(),
                PrimitiveValue::int32)));
        map.put(int64().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(int64(),
                PrimitiveValue::int64)));
        map.put(uint64().getId(), new PrimitiveValueCreator<>(getOptionalValueCreator(uint64(),
                PrimitiveValue::uint64)));
        map.put(datetime().getId(), new PrimitiveValueCreator<Instant>(getOptionalValueCreator(datetime(),
                PrimitiveValue::datetime)));
        map.put(interval().getId(), new PrimitiveValueCreator<Duration>(getOptionalValueCreator(interval(),
                PrimitiveValue::interval)));
        return map;
    }

    /**
     * Wraps a non-null value factory so that a {@code null} input yields an empty value of
     * {@code Optional<primitiveType>} instead of being passed to the factory.
     *
     * @param primitiveType           type whose optional form is used for the empty value
     * @param notOptionalValueCreator factory applied to non-null inputs
     * @param <T>                     Java input type of the factory
     * @return null-tolerant value factory
     */
    private static <T> Function<T, Value> getOptionalValueCreator(PrimitiveType primitiveType,
                                                                  Function<T, Value> notOptionalValueCreator) {
        return value -> {
            if (value == null) {
                return primitiveType.makeOptional().emptyValue();
            } else {
                return notOptionalValueCreator.apply(value);
            }
        };
    }
}
