package ru.yandex.infra.stage.protobuf;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.HashMap;
import java.util.Map;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static java.util.stream.Collectors.toMap;

/**
 * Copypasted from iss repository https://bb.yandex-team.ru/projects/SEARCH_INFRA/repos/iss/browse
 * Automatically handle inheritance in protobuf in from of
 * message Message {
 * oneof derived {
 * // Derived cases here
 * }
 * }
 * <p>
 * Each conversion requires one call via MethodHandle and one call via Method inside protobuf implementation.
 * In case of performance problems we can extract method accessors as MethodHandles instead.
 * <p>
 * Throws {@link IllegalArgumentException} if cannot restore object from proto because oneof is empty
 */
public class OneofDerivedConverter<JavaType, ProtoType extends Message> {

    private static final String FROM_PROTO_METHOD_NAME = "fromProto", TO_PROTO_METHOD_NAME = "toProto";

    private final Descriptors.OneofDescriptor oneofFieldDescriptor;

    private final Map<Class<? extends JavaType>, Class<? extends Message>> javaToProto;
    private final Map<Class<? extends Message>, Class<? extends JavaType>> protoToJava;

    public OneofDerivedConverter(Descriptors.OneofDescriptor oneofFieldDescriptor,
                                 Map<Class<? extends JavaType>, Class<? extends Message>> javaToProto) {
        this.oneofFieldDescriptor = oneofFieldDescriptor;
        this.javaToProto = javaToProto;
        this.protoToJava = javaToProto.entrySet().stream().collect(toMap(Map.Entry::getValue, Map.Entry::getKey));
    }

    public JavaType fromProto(ProtoType proto, Converter converter) {
        return fromProto(proto, converter, null);
    }

    public JavaType fromProto(ProtoType proto, Converter converter, JavaType defaultValue) {
        try {
            if (defaultValue != null && !proto.hasOneof(oneofFieldDescriptor)) return defaultValue;

            Descriptors.FieldDescriptor fieldDescriptor = proto.getOneofFieldDescriptor(oneofFieldDescriptor);
            if (fieldDescriptor == null) {
                String errorMessage = String.format("No subfield is set for %s", oneofFieldDescriptor.getFullName());
                throw new IllegalArgumentException(errorMessage);
            }
            Object protoValue = proto.getField(fieldDescriptor);
            Class<? extends JavaType> javaClass = protoToJava.get(protoValue.getClass());
            if (javaClass != null) {
                MethodHandle handle = getConvertHandle(
                        FROM_PROTO_METHOD_NAME,
                        protoValue.getClass(),
                        javaClass,
                        converter.getClass()
                );
                return javaClass.cast(handle.invoke(converter, protoValue));
            } else {
                String message = String.format("Cannot convert object of %s from protobuf", protoValue.getClass());
                throw new IllegalArgumentException(message);
            }
        } catch (Throwable t) {
            throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    public <ProtoBuilder extends Message.Builder> void setProtoField(ProtoBuilder builder, JavaType object,
                                                                     Converter converter) {
        try {
            Class<? extends Message> protoClass = javaToProto.get(object.getClass());
            if (protoClass == null) {
                String message = String.format("Cannot set object of %s as field of %s", object.getClass(),
                        builder.getClass());
                throw new IllegalArgumentException(message);
            }
            MethodHandle handle = getConvertHandle(
                    TO_PROTO_METHOD_NAME,
                    object.getClass(),
                    protoClass,
                    converter.getClass()
            );
            Message message = (Message) handle.invoke(converter, object);
            for (Descriptors.FieldDescriptor fieldDescriptor : oneofFieldDescriptor.getFields()) {
                if (message.getDescriptorForType() == fieldDescriptor.getMessageType()) {
                    builder.setField(fieldDescriptor, message);
                    return;
                }
            }
        } catch (Throwable t) {
            throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    private static MethodHandle getConvertHandle(String methodName,
                                          Class<?> fromClass,
                                          Class<?> toClass,
                                          Class<?> converterClass) {
        try {
            return MethodHandles.lookup().findVirtual(
                    converterClass,
                    methodName,
                    MethodType.methodType(toClass, fromClass));
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}
