package ru.yandex.mail.micronaut.common.context;

import io.micronaut.context.annotation.Context;
import io.micronaut.context.annotation.Requires;
import io.micronaut.context.event.BeanInitializedEventListener;
import io.micronaut.context.event.BeanInitializingEvent;
import io.micronaut.scheduling.instrument.InstrumentedExecutorService;
import io.micronaut.scheduling.instrument.InstrumentedScheduledExecutorService;
import lombok.val;
import org.apache.logging.log4j.ThreadContext;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;

import static ru.yandex.mail.micronaut.common.context.ContextManager.clearMdcContext;
import static ru.yandex.mail.micronaut.common.context.ContextManager.switchMdcContext;

/**
 * Propagates the Log4j {@link ThreadContext} (MDC) from the thread that hands a task to an
 * executor onto the thread that runs it.
 *
 * <p>Every {@link ExecutorService} bean is replaced by an instrumented wrapper whose tasks are
 * decorated by {@link #instrument(Callable)} / {@link #instrument(Runnable)}. Beans that are
 * {@link ScheduledExecutorService}s are wrapped as an {@link InstrumentedScheduledExecutorService},
 * so the replacement still implements {@code ScheduledExecutorService}.
 *
 * <p>Only active when Log4j's {@code ThreadContext} class is present.
 */
@Context
@Requires(classes = ThreadContext.class)
class ContextInstrumenter implements BeanInitializedEventListener<ExecutorService> {
    /**
     * Binds {@code callable} to the MDC of the calling thread.
     *
     * @param callable task to decorate
     * @param <T>      task result type
     * @return {@code callable} itself when the calling thread's MDC is empty; otherwise a wrapper
     *     that installs the captured context via {@code switchMdcContext}, runs the task and calls
     *     {@code clearMdcContext()} once it finishes, normally or exceptionally
     */
    private static <T> Callable<T> instrument(Callable<T> callable) {
        // Snapshot taken now, on the submitting thread; the lambda below runs later on a worker.
        val context = ThreadContext.getImmutableContext();
        if (context.isEmpty()) {
            return callable;
        }
        return () -> {
            switchMdcContext(context);
            try {
                return callable.call();
            } finally {
                clearMdcContext();
            }
        };
    }

    /**
     * Binds {@code runnable} to the MDC of the calling thread.
     *
     * @param runnable task to decorate
     * @return {@code runnable} itself when the calling thread's MDC is empty; otherwise a wrapper
     *     that installs the captured context via {@code switchMdcContext}, runs the task and calls
     *     {@code clearMdcContext()} once it finishes, normally or exceptionally
     */
    private static Runnable instrument(Runnable runnable) {
        // Snapshot taken now, on the submitting thread; the lambda below runs later on a worker.
        val context = ThreadContext.getImmutableContext();
        if (context.isEmpty()) {
            return runnable;
        }
        return () -> {
            switchMdcContext(context);
            try {
                runnable.run();
            } finally {
                // Mirror the Callable path: without this the captured MDC stays on the (possibly
                // pooled) worker thread and leaks into whatever that thread runs next.
                clearMdcContext();
            }
        };
    }

    /**
     * Replaces a freshly initialized executor bean with a context-propagating wrapper.
     *
     * @param event initialization event carrying the original executor bean
     * @return an instrumented wrapper delegating to the original executor; it implements
     *     {@link ScheduledExecutorService} whenever the original bean does
     */
    @Override
    public ExecutorService onInitialized(BeanInitializingEvent<ExecutorService> event) {
        val executor = event.getBean();
        if (executor instanceof ScheduledExecutorService) {
            return new InstrumentedScheduledExecutorService() {
                @Override
                public ScheduledExecutorService getTarget() {
                    return (ScheduledExecutorService) executor;
                }

                // The qualified calls are required: an unqualified instrument(...) here would
                // resolve to this anonymous class's own override and recurse forever.
                @Override
                public <T> Callable<T> instrument(Callable<T> task) {
                    return ContextInstrumenter.instrument(task);
                }

                @Override
                public Runnable instrument(Runnable command) {
                    return ContextInstrumenter.instrument(command);
                }
            };
        }
        return new InstrumentedExecutorService() {
            @Override
            public ExecutorService getTarget() {
                return executor;
            }

            // The qualified calls are required: an unqualified instrument(...) here would
            // resolve to this anonymous class's own override and recurse forever.
            @Override
            public <T> Callable<T> instrument(Callable<T> task) {
                return ContextInstrumenter.instrument(task);
            }

            @Override
            public Runnable instrument(Runnable command) {
                return ContextInstrumenter.instrument(command);
            }
        };
    }
}
