package ru.yandex.travel.workflow.base;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.Message;
import com.google.protobuf.ProtocolMessageEnum;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.misc.ExceptionUtils;
import ru.yandex.travel.workflow.MessagingUtils;
import ru.yandex.travel.workflow.StateContext;
import ru.yandex.travel.workflow.StatefulWorkflowEventHandler;
import ru.yandex.travel.workflow.entities.WorkflowEntity;

/**
 * Base {@link StatefulWorkflowEventHandler} that dispatches incoming events to annotated handler methods.
 *
 * <p>Subclasses declare public (possibly inherited) methods annotated with {@link HandleEvent}. Each takes exactly
 * two parameters: the event {@link Message} and a {@link StateContext}. The event type a method handles is
 * {@link HandleEvent#value()} when set explicitly, otherwise the declared type of the method's first parameter.
 * Event types listed in a class-level {@link IgnoreEvents} annotation are logged and dropped, even if a handler exists
 * for them. Any other event without a handler goes to {@link #handleDefault(Message, StateContext)}.
 *
 * <p>Handlers are discovered reflectively once, in the constructor. A misconfigured handler makes the constructor
 * throw {@link IllegalArgumentException}. Dispatch looks up the exact runtime class of the event; supertypes are not
 * consulted.
 *
 * @param <S> workflow state enum type
 * @param <E> workflow entity type
 */
@Slf4j
public class AnnotatedStatefulWorkflowEventHandler<S extends ProtocolMessageEnum, E extends WorkflowEntity<S>>
        implements StatefulWorkflowEventHandler<S, E> {
    /** Handler invokers keyed by the exact event class they accept; immutable after construction. */
    private final Map<Class<? extends Message>, BiConsumer<Message, StateContext<S, E>>> handlerMethods;
    /** Event classes that are logged and dropped without reaching any handler. */
    private final Set<Class<? extends Message>> ignoredEvents;

    public AnnotatedStatefulWorkflowEventHandler() {
        handlerMethods = collectHandlerMethods();
        ignoredEvents = collectIgnoredEvents();
    }

    /**
     * Reads the event types listed in this class's {@link IgnoreEvents} annotation.
     *
     * @return an immutable set of ignored event classes; empty when the annotation is absent
     */
    private Set<Class<? extends Message>> collectIgnoredEvents() {
        IgnoreEvents annotation = getClass().getAnnotation(IgnoreEvents.class);
        if (annotation == null) {
            return ImmutableSet.of();
        }
        return ImmutableSet.copyOf(annotation.types());
    }

    /**
     * Discovers the {@link HandleEvent}-annotated public methods of the runtime class and wraps each in an invoker.
     *
     * @return an immutable map from handled event class to the invoker for its handler method
     * @throws IllegalArgumentException if a handler has an invalid signature, if its explicit event type is not
     *         assignable to its first parameter, or if two handlers claim the same event type
     */
    private Map<Class<? extends Message>, BiConsumer<Message, StateContext<S, E>>> collectHandlerMethods() {
        Map<Class<? extends Message>, Method> handlerByType = new HashMap<>();
        ImmutableMap.Builder<Class<? extends Message>, BiConsumer<Message, StateContext<S, E>>> builder =
                ImmutableMap.builder();
        for (Method method : getClass().getMethods()) {
            // Bridge methods are compiler-generated delegates to the real method and can carry copies of its
            // annotations; registering them would add spurious or duplicate handlers.
            if (method.isBridge() || !method.isAnnotationPresent(HandleEvent.class)) {
                continue;
            }
            validateHandlerParams(method);
            Class<? extends Message> messageType = resolveMessageType(method);
            Method previous = handlerByType.putIfAbsent(messageType, method);
            Preconditions.checkArgument(previous == null,
                    "Multiple handlers for event type %s: %s and %s", messageType.getName(), previous, method);
            builder.put(messageType, createInvoker(method));
        }
        return builder.build();
    }

    /**
     * Determines the event type handled by {@code method}.
     *
     * <p>Must only be called after {@link #validateHandlerParams(Method)} has accepted {@code method}.
     *
     * @param method an annotated, already-validated handler method
     * @return the explicit {@link HandleEvent#value()} when set, otherwise the declared type of the first parameter
     * @throws IllegalArgumentException if the explicit type cannot be passed as the method's first argument
     */
    @SuppressWarnings("unchecked")
    private static Class<? extends Message> resolveMessageType(Method method) {
        Class<?> parameterType = method.getParameterTypes()[0];
        Class<? extends Message> messageType = method.getAnnotation(HandleEvent.class).value();
        if (messageType.equals(MessageTypeNotSet.class)) {
            // No explicit type: derive it from the parameter list. The cast is safe because
            // validateHandlerParams has already checked that the first parameter is a Message subtype.
            return (Class<? extends Message>) parameterType;
        }
        // Fail at construction rather than with an "argument type mismatch" on the first dispatched event.
        Preconditions.checkArgument(parameterType.isAssignableFrom(messageType),
                "Handler %s declares event type %s which is not assignable to its first parameter type %s",
                method, messageType.getName(), parameterType.getName());
        return messageType;
    }

    /**
     * Wraps a validated handler method in a {@link BiConsumer} that invokes it reflectively on this instance.
     *
     * @param method the handler method to invoke
     * @return an invoker that rethrows whatever the handler itself throws
     */
    private BiConsumer<Message, StateContext<S, E>> createInvoker(Method method) {
        return (event, stateContext) -> {
            try {
                method.invoke(this, event, stateContext);
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Unable to access handler method", e);
            } catch (InvocationTargetException e) {
                // Surface the handler's own exception instead of the reflective wrapper.
                throw ExceptionUtils.throwException(e.getCause() != null ? e.getCause() : e);
            }
        };
    }

    /**
     * Checks that {@code handlerMethod} has the shape {@code (Message subtype, StateContext subtype)}.
     *
     * @param handlerMethod the method to validate
     * @throws IllegalArgumentException if the parameter count or parameter types are wrong
     */
    static void validateHandlerParams(Method handlerMethod) {
        Class<?>[] parameterTypes = handlerMethod.getParameterTypes();
        Preconditions.checkArgument(parameterTypes.length == 2,
                "Exactly 2 handler method parameters are expected but got %s", parameterTypes.length);

        Class<?> messageType = parameterTypes[0];
        Preconditions.checkArgument(Message.class.isAssignableFrom(messageType),
                "The first handler parameter must be an instance of %s. Current type is %s",
                Message.class.getName(), messageType.getName());

        Class<?> stateContextType = parameterTypes[1];
        Preconditions.checkArgument(StateContext.class.isAssignableFrom(stateContextType),
                "The second handler parameter must be an instance of %s. Current type is %s",
                StateContext.class.getName(), stateContextType.getName());
    }

    /**
     * Handles an event that is neither ignored nor matched by any registered handler.
     *
     * <p>Delegates to {@code MessagingUtils.throwOnUnmatchedEvent}. Subclasses may override this to treat
     * unmatched events differently.
     *
     * @param event the unmatched event
     * @param stateContext the current workflow state context
     */
    protected void handleDefault(Message event, StateContext<S, E> stateContext) {
        MessagingUtils.throwOnUnmatchedEvent(event, stateContext);
    }

    /**
     * Dispatches {@code event} by its exact runtime class.
     *
     * <p>Ignored types are logged and dropped first. Otherwise the registered handler runs, and when there is none
     * the event goes to {@link #handleDefault(Message, StateContext)}.
     *
     * @param event the incoming event
     * @param stateContext the current workflow state context
     */
    @Override
    public void handleEvent(Message event, StateContext<S, E> stateContext) {
        if (ignoredEvents.contains(event.getClass())) {
            log.info("Ignoring message {} for class {}", event.getClass(), this.getClass());
            return;
        }
        BiConsumer<Message, StateContext<S, E>> handler = handlerMethods.get(event.getClass());
        if (handler != null) {
            handler.accept(event, stateContext);
        } else {
            handleDefault(event, stateContext);
        }
    }
}
