package ru.yandex.travel.task_processor;

import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.sentry.Sentry;
import io.sentry.event.Event;
import io.sentry.event.EventBuilder;
import io.sentry.event.interfaces.ExceptionInterface;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;
import org.springframework.transaction.TransactionDefinition;

import ru.yandex.travel.commons.logging.NestedMdc;
import ru.yandex.travel.commons.metrics.MetricsUtils;
import ru.yandex.travel.spring.tx.ForcedRollbackTxManagerWrapper;
import ru.yandex.travel.spring.tx.ForcedRollbackTxTemplate;
import ru.yandex.travel.tx.utils.TransactionMandatory;

@Slf4j
/**
 * Periodically polls a {@link TaskKeyProvider} for pending task keys and processes each key with a {@link Task}
 * on a fixed-size thread pool, inside a transaction created by {@link ForcedRollbackTxTemplate}.
 * <p>
 * Concurrency is bounded by a {@link PausableTerminationSemaphore} whose permit count equals the pool size: a
 * key is handed to the pool only after a permit and the provider's per-key lock have both been acquired, and
 * both are given back once processing of that key finishes (or hand-off fails).
 */
public class TaskProcessor<K> {

    // TODO (mbobrov): parameterize it
    private static final Duration SCHEDULER_SHUTDOWN_DURATION = Duration.ofSeconds(3);

    private final String name;

    private final PausableTerminationSemaphore taskSemaphore;
    private final ForcedRollbackTxTemplate transactionTemplate;

    private final ScheduledExecutorService scheduler;
    private final Map<Result, Timer> schedulerTimers;

    private final ExecutorService pool;
    private final Map<Result, Timer> poolTimers;

    private final TaskKeyProvider<K> taskKeyProvider;
    private final Task<K> task;

    private final boolean gracefulPoolShutdown;
    private final Duration poolShutdownWaitDuration;

    /**
     * Validates the configuration, creates the scheduler and worker pool, registers metrics and starts the
     * periodic scheduling loop.
     *
     * @param taskKeyProvider                source of pending task keys and per-key locks
     * @param task                           the work executed for each key
     * @param forcedRollbackTxManagerWrapper transaction manager used for polling and processing
     * @param transactionDefinition          transaction settings used for polling and processing
     * @param properties                     processor configuration; must be enabled and named
     * @throws IllegalArgumentException if the processor is not enabled
     * @throws NullPointerException     if a mandatory argument or property is missing
     */
    public TaskProcessor(TaskKeyProvider<K> taskKeyProvider, Task<K> task,
                         ForcedRollbackTxManagerWrapper forcedRollbackTxManagerWrapper,
                         TransactionDefinition transactionDefinition,
                         TaskProcessorProperties properties) {
        log.info("Initializing task processor {}", properties.getName());
        Preconditions.checkArgument(properties.isEnabled(), "Task processor must be enabled");
        Preconditions.checkNotNull(properties.getName(), "Task processor name must be provided");
        Preconditions.checkNotNull(taskKeyProvider, "Task provider must be provided");
        Preconditions.checkNotNull(task, "Task must be provided");
        Preconditions.checkNotNull(properties.getInitialStartDelay(), "Initial start delay must be provided");
        Preconditions.checkNotNull(properties.getScheduleRate(), "Schedule rate must be not null");
        if (!properties.isGracefulPoolShutdown()) {
            // the wait duration is only consulted by the non-graceful shutdown branch
            Preconditions.checkNotNull(properties.getPoolShutdownWaitDuration(),
                    "Pool shutdown wait duration must be not null");
        }

        var maxActiveTasksCount = properties.getPoolSize();
        boolean daemonThreads = properties.isDaemonPoolThreads();

        this.name = properties.getName();
        this.taskKeyProvider = taskKeyProvider;
        this.task = task;
        Duration initialStartDelay = properties.getInitialStartDelay();
        Duration scheduleRate = properties.getScheduleRate();
        this.gracefulPoolShutdown = properties.isGracefulPoolShutdown();
        this.poolShutdownWaitDuration = properties.getPoolShutdownWaitDuration();

        // setting maximum number of permits to be equal to pool size
        this.taskSemaphore = new PausableTerminationSemaphore("TaskSemaphore-" + name, maxActiveTasksCount);
        this.scheduler = Executors.newSingleThreadScheduledExecutor(
                new ThreadFactoryBuilder().setNameFormat("TaskScheduler-" + name + "-%s")
                        .setDaemon(daemonThreads).build());
        // The daemon flag must also reach the pool: fixed-pool core threads never time out, so non-daemon
        // pool threads would keep the JVM alive regardless of the scheduler's daemon setting.
        this.pool = Executors.newFixedThreadPool(maxActiveTasksCount,
                new ThreadFactoryBuilder().setNameFormat("TaskPool-" + name + "-%s")
                        .setDaemon(daemonThreads).build());
        this.transactionTemplate = new ForcedRollbackTxTemplate(forcedRollbackTxManagerWrapper, transactionDefinition);

        Gauge.builder("taskProcessor.activeTasks", taskSemaphore::getPermits)
                .tag("name", name).register(Metrics.globalRegistry);
        Gauge.builder("taskProcessor.maxActiveTasks", () -> maxActiveTasksCount)
                .tag("name", name).register(Metrics.globalRegistry);
        Gauge.builder("taskProcessor.isActive", () -> taskSemaphore.isActive() ? 1 : 0)
                .tag("name", name).register(Metrics.globalRegistry);

        this.schedulerTimers = createResultTimers("taskProcessor.schedulerTimer");
        this.poolTimers = createResultTimers("taskProcessor.poolTimer");

        // scheduled last so that every field used by scheduleTasks is already assigned
        this.scheduler.scheduleAtFixedRate(this::scheduleTasks,
                initialStartDelay.toMillis(), scheduleRate.toMillis(), TimeUnit.MILLISECONDS);
    }

    /**
     * Registers one timer per {@link Result} under the given metric name, tagged with this processor's name.
     */
    private Map<Result, Timer> createResultTimers(String timerName) {
        ImmutableMap.Builder<Result, Timer> builder = new ImmutableMap.Builder<>();
        for (Result r : Result.values()) {
            builder.put(r, Timer.builder(timerName)
                    .serviceLevelObjectives(MetricsUtils.mediumDurationSla())
                    .publishPercentiles(MetricsUtils.higherPercentiles())
                    .tag("name", name).tag("result", r.getValue()).register(Metrics.globalRegistry));
        }
        return builder.build();
    }

    /**
     * Pauses the task semaphore; blocks until the semaphore's pause completes.
     *
     * @throws RuntimeException wrapping {@link InterruptedException} if the calling thread is interrupted
     *                          (the interrupt flag is restored)
     */
    public void pause() {
        try {
            // TODO (mbobrov): think of async pause (must be implemented on semaphore level)
            taskSemaphore.pause();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Resumes the task semaphore after {@link #pause()}.
     */
    public void resume() {
        taskSemaphore.resume();
    }

    /**
     * Used for monitoring via GaugeService.
     *
     * @return 0 while the semaphore is inactive, otherwise the provider's pending task count
     */
    @TransactionMandatory
    public long getPendingTasksCount() {
        if (!taskSemaphore.isActive()) {
            return 0;
        }
        return taskKeyProvider.getPendingTasksCount();
    }

    /**
     * Stops scheduling new work, stops the scheduler and then shuts down the worker pool, either waiting for
     * the semaphore to terminate (graceful mode) or for at most the configured pool shutdown wait duration.
     */
    public void shutdown() throws Exception {
        // first shutdown the taskSemaphore to stop any new tasks scheduling
        taskSemaphore.shutdown();
        // shutting down the scheduler
        MoreExecutors.shutdownAndAwaitTermination(scheduler, SCHEDULER_SHUTDOWN_DURATION.toMillis(),
                TimeUnit.MILLISECONDS);
        if (gracefulPoolShutdown) {
            pool.shutdown();
            taskSemaphore.awaitTermination();
        } else {
            MoreExecutors.shutdownAndAwaitTermination(pool, poolShutdownWaitDuration.toMillis(), TimeUnit.MILLISECONDS);
        }
    }

    public String getName() {
        return name;
    }

    /**
     * One tick of the periodic loop: fetches pending keys and tries to hand each one to the pool.
     * <p>
     * Being extra cautious here: an exception escaping a {@code scheduleAtFixedRate} task suppresses all
     * further executions, so failures are caught and logged per key and for the whole tick.
     */
    private void scheduleTasks() {
        // default to FAILURE so that an Error escaping this method is not recorded as a success
        Result scheduleResult = Result.FAILURE;
        long startedAt = System.nanoTime();
        try {
            getPendingTaskKeys().forEach(taskKey -> {
                try {
                    scheduleOneTaskProcessing(taskKey);
                } catch (Exception e) {
                    log.error("Failed to submit task {} for key {} to pool", name, taskKey, e);
                }
            });
            scheduleResult = Result.SUCCESS;
        } catch (Exception e) {
            log.error("Error occurred in task processor {} while scheduling tasks: ", name, e);
        } finally {
            schedulerTimers.get(scheduleResult).record(System.nanoTime() - startedAt, TimeUnit.NANOSECONDS);
        }
    }

    /**
     * Acquires a semaphore permit and the provider's lock for {@code taskKey}, then submits the key to the pool.
     * The permit and the lock are released by the pool task when processing ends; if acquiring the lock or
     * submitting fails (e.g. the pool rejects the task), they are released here before the exception propagates.
     */
    // similar to WorkflowProcessServiceV2.scheduleDrainEventQueue
    private void scheduleOneTaskProcessing(K taskKey) {
        if (!taskSemaphore.acquire()) {
            log.debug("skipping task {}, could not acquire semaphore permit", taskKey);
            return;
        }
        boolean lockAcquired = false;
        boolean handedOffToPool = false;
        try {
            lockAcquired = taskKeyProvider.acquireTaskLock(taskKey);
            if (lockAcquired) {
                pool.submit(() -> {
                    try {
                        doTaskProcessing(taskKey);
                    } finally {
                        // releaseTaskLockQuietly never throws, so the permit is never lost
                        releaseTaskLockQuietly(taskKey);
                        taskSemaphore.release();
                    }
                });
                // from here on the pool task owns the lock and the permit
                handedOffToPool = true;
            }
        } finally {
            if (!handedOffToPool) {
                if (lockAcquired) {
                    releaseTaskLockQuietly(taskKey);
                }
                taskSemaphore.release();
            }
        }
    }

    /**
     * Releases the provider's lock for {@code taskKey}, logging (not propagating) any failure.
     */
    private void releaseTaskLockQuietly(K taskKey) {
        try {
            taskKeyProvider.releaseTaskLock(taskKey);
        } catch (Exception e) {
            log.error("Error releasing task {} lock for key {}", name, taskKey, e);
        }
    }

    /**
     * Runs {@link Task#process} for {@code taskKey} inside a transaction with the key and processor name in
     * the MDC; failures are logged, reported to Sentry and counted in the pool timer.
     */
    private void doTaskProcessing(K taskKey) {
        // default to FAILURE so that an Error escaping the task is not recorded as a success
        Result result = Result.FAILURE;
        long startedAt = System.nanoTime();
        try (var mdc = NestedMdc.nestedMdc(Map.of("TaskKey", taskKey.toString(), "TaskProcessorName", name))) {
            try {
                transactionTemplate.execute((ignored) -> {
                    task.process(taskKey);
                    return null;
                });
                result = Result.SUCCESS;
            } catch (Exception e) {
                // log first: the failure must be visible even if Sentry reporting itself fails
                log.error("Failed to process task {} for key {}", name, taskKey, e);
                try {
                    reportCrashExceptionToSentry(e, "TaskProcessor");
                } catch (Exception reportingError) {
                    log.warn("Failed to report task {} failure for key {} to Sentry", name, taskKey,
                            reportingError);
                }
            } finally {
                poolTimers.get(result).record(System.nanoTime() - startedAt, TimeUnit.NANOSECONDS);
            }
        }
    }

    /**
     * Sends {@code ex} to Sentry as an ERROR event, tagged with every entry of the current MDC plus
     * {@code Scope=scope}.
     *
     * @param ex    the exception to report
     * @param scope value for the {@code Scope} tag
     */
    public void reportCrashExceptionToSentry(Exception ex, String scope) {
        EventBuilder eventBuilder = new EventBuilder().withMessage(ex.getMessage())
                .withLevel(Event.Level.ERROR)
                .withSentryInterface(new ExceptionInterface(ex));
        // MDC.getCopyOfContextMap() may return null when no MDC values are set
        Map<String, String> mdcContext = MDC.getCopyOfContextMap();
        if (mdcContext != null) {
            for (Map.Entry<String, String> kv : mdcContext.entrySet()) {
                eventBuilder.withTag(kv.getKey(), kv.getValue());
            }
        }
        eventBuilder.withTag("Scope", scope);
        Sentry.getStoredClient().sendEvent(eventBuilder);
    }

    /**
     * Fetches at most as many pending keys as there are free permits, in a transaction.
     *
     * @return the pending keys, or an empty collection when the semaphore is inactive, no permits are free,
     * or the provider returned null
     */
    private Collection<K> getPendingTaskKeys() {
        if (!taskSemaphore.isActive()) {
            return Collections.emptyList();
        }
        int availablePermits = taskSemaphore.getAvailablePermits();
        if (availablePermits <= 0) {
            return Collections.emptyList();
        }
        Collection<K> keys = transactionTemplate.execute((ignored) ->
                taskKeyProvider.getPendingTaskKeys(availablePermits));
        if (keys == null) {
            return Collections.emptyList();
        }
        return keys;
    }

    /**
     * Outcome tag for the scheduler and pool timers.
     */
    @RequiredArgsConstructor
    @Getter
    private enum Result {
        SUCCESS("success"), FAILURE("failure");

        private final String value;
    }
}
