package ru.yandex.infra.controller;

import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.helpers.NOPLogger;

import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.util.ExitUtils;

/**
 * Repeatedly runs an asynchronous task on a {@link ScheduledExecutorService}, one iteration at a time.
 *
 * <p>An iteration calls the task supplier and waits for the returned future. When that future completes
 * (normally, exceptionally, or by exceeding {@code singleIterationTimeout}) the next iteration is scheduled so that
 * it starts roughly {@code updateInterval} after the previous one started, or immediately if the iteration took
 * longer than that. If scheduling the next iteration fails, the error is logged and
 * {@code ExitUtils.gracefulExit(ExitUtils.EXECUTOR_SCHEDULING_FAILURE)} is called.
 *
 * <p>{@link #start()} and {@link #stop()} may be called from any thread. Every successful {@code start()} begins a
 * new iteration chain identified by a fresh token. An iteration from an older chain that is still queued is skipped,
 * and one that is already running finishes but does not schedule a successor, so a {@code stop()}/{@code start()}
 * pair never leaves two chains running side by side.
 *
 * <p>If a {@link GaugeRegistry} is supplied, the {@code METRIC_*} gauges are registered in it.
 */
public class RepeatedTask {
    public static final String METRIC_FAILED_ITERATIONS_COUNT = "failed_iterations";
    public static final String METRIC_FINISHED_ITERATIONS_COUNT = "finished_iterations";
    public static final String METRIC_HUNG_ITERATIONS_COUNT = "hung_iteration";
    public static final String METRIC_INTERVAL_BETWEEN_ITERATIONS_TIME = "interval_between_iterations_ms";
    public static final String METRIC_LAST_ITERATION_DURATION_TIME = "last_iteration_duration_ms";
    public static final String METRIC_CURRENT_ITERATION_DURATION_TIME = "current_iteration_duration_ms";

    /** Value of {@link #activeRunToken} while the task is stopped; real run tokens start at 1. */
    private static final long NOT_RUNNING = 0;

    private final AtomicLong metricFailedIterationsCount = new AtomicLong();
    private final AtomicLong metricFinishedIterationsCount = new AtomicLong();
    private final AtomicLong metricHungIterationsCount = new AtomicLong();

    // Written by whichever thread runs or completes an iteration and read by the metric gauges from another
    // thread, hence volatile. Start times come from System.nanoTime() so that wall-clock adjustments cannot produce
    // negative durations or stretch the delay before the next iteration. null means "not measured yet".
    private volatile Long lastIterationStartNanos;
    private volatile Long metricIntervalBetweenIterationsMilliseconds;
    private volatile Long metricLastIterationDurationMilliseconds;

    private final Logger logger;
    private final Logger infoLogger;
    private final Supplier<CompletableFuture<?>> taskSupplier;
    private final Duration updateInterval;
    private final Duration singleIterationTimeout;
    private final ScheduledExecutorService executor;

    // Source of unique run tokens, one per start() call.
    private final AtomicLong runTokenSequence = new AtomicLong();
    // Token of the currently active iteration chain, or NOT_RUNNING when stopped.
    private final AtomicLong activeRunToken = new AtomicLong(NOT_RUNNING);

    /**
     * Creates a task whose iterations are asynchronous.
     *
     * @param taskSupplier            called once per iteration; the next iteration is scheduled after the returned
     *                                future completes. Throwing, or returning {@code null}, counts as a failed
     *                                iteration.
     * @param updateInterval          desired spacing between the starts of consecutive iterations
     * @param singleIterationTimeout  if the iteration's future is not completed within this time it is completed
     *                                exceptionally via {@link CompletableFuture#orTimeout} and counted as failed and
     *                                hung; the work behind the future is not interrupted
     * @param executor                executor used to schedule iterations
     * @param metricsRegistryOptional registry that receives the {@code METRIC_*} gauges, if present
     * @param logger                  receives warnings and errors
     * @param logInfoMessages         whether per-iteration start/finish messages are logged at INFO level
     */
    public RepeatedTask(Supplier<CompletableFuture<?>> taskSupplier,
                        Duration updateInterval,
                        Duration singleIterationTimeout,
                        ScheduledExecutorService executor,
                        Optional<GaugeRegistry> metricsRegistryOptional,
                        Logger logger,
                        boolean logInfoMessages) {
        this.logger = logger;
        this.infoLogger = logInfoMessages ? logger : NOPLogger.NOP_LOGGER;
        this.taskSupplier = taskSupplier;
        this.updateInterval = updateInterval;
        this.singleIterationTimeout = singleIterationTimeout;
        this.executor = executor;

        metricsRegistryOptional.ifPresent(metricsRegistry -> {
            metricsRegistry.add(METRIC_FAILED_ITERATIONS_COUNT, new GolovanableGauge<>(metricFailedIterationsCount::get, "dmmm"));
            metricsRegistry.add(METRIC_FINISHED_ITERATIONS_COUNT, new GolovanableGauge<>(metricFinishedIterationsCount::get, "dmmm"));
            metricsRegistry.add(METRIC_HUNG_ITERATIONS_COUNT, new GolovanableGauge<>(metricHungIterationsCount::get, "dmmm"));
            metricsRegistry.add(METRIC_INTERVAL_BETWEEN_ITERATIONS_TIME, new GolovanableGauge<>(() -> metricIntervalBetweenIterationsMilliseconds, "axxx"));
            metricsRegistry.add(METRIC_LAST_ITERATION_DURATION_TIME, new GolovanableGauge<>(() -> metricLastIterationDurationMilliseconds, "axxx"));
            metricsRegistry.add(METRIC_CURRENT_ITERATION_DURATION_TIME, new GolovanableGauge<>(this::currentIterationDurationMilliseconds, "axxx"));
        });
    }

    /**
     * Creates a task whose iterations run a blocking {@link Runnable}.
     *
     * <p>Each iteration submits {@code singleIterationTask} to {@code executor} and completes when it returns or
     * throws. NOTE: on timeout the runnable keeps running (and occupying an executor thread) while the next
     * iteration is already being scheduled.
     *
     * @param singleIterationTask     the work performed by one iteration
     * @param updateInterval          desired spacing between the starts of consecutive iterations
     * @param singleIterationTimeout  iterations not finished within this time are counted as failed and hung
     * @param executor                executor used both to schedule iterations and to run the task
     * @param metricsRegistryOptional registry that receives the {@code METRIC_*} gauges, if present
     * @param logger                  receives warnings and errors
     * @param logInfoMessages         whether per-iteration start/finish messages are logged at INFO level
     */
    public RepeatedTask(Runnable singleIterationTask,
                        Duration updateInterval,
                        Duration singleIterationTimeout,
                        ScheduledExecutorService executor,
                        Optional<GaugeRegistry> metricsRegistryOptional,
                        Logger logger,
                        boolean logInfoMessages) {
        this(() -> {
            CompletableFuture<?> future = new CompletableFuture<>();
            executor.submit(() -> {
                try {
                    singleIterationTask.run();
                    future.complete(null);
                } catch (Throwable ex) {
                    future.completeExceptionally(ex);
                }
            });
            return future;
        }, updateInterval, singleIterationTimeout, executor, metricsRegistryOptional, logger, logInfoMessages);
    }

    /**
     * Starts a new iteration chain whose first iteration is scheduled immediately.
     * If the task is already started, only a warning is logged.
     */
    public void start() {
        long runToken = runTokenSequence.incrementAndGet();
        if (activeRunToken.compareAndSet(NOT_RUNNING, runToken)) {
            scheduleNextOrFail(runToken, 0);
        } else {
            logger.warn("Task is already scheduled.");
        }
    }

    /**
     * Stops the current iteration chain. An iteration that is already running is not interrupted, but it will not
     * schedule a successor; an iteration still waiting in the executor is skipped. If the task is not started, only
     * a warning is logged.
     */
    public void stop() {
        if (activeRunToken.getAndSet(NOT_RUNNING) == NOT_RUNNING) {
            logger.warn("Task is not scheduled, nothing to stop");
        }
    }

    /** Returns whether {@code runToken} identifies the chain started by the latest, not yet stopped, start(). */
    private boolean isCurrentRun(long runToken) {
        return activeRunToken.get() == runToken;
    }

    /** Milliseconds since the most recent iteration started (keeps growing until the next one starts), or null. */
    private Long currentIterationDurationMilliseconds() {
        Long startNanos = lastIterationStartNanos;
        return startNanos == null ? null : TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
    }

    /** Runs one iteration of chain {@code runToken}, unless that chain has been stopped or replaced meanwhile. */
    private void runSingleIteration(long runToken) {
        if (!isCurrentRun(runToken)) {
            // stop() (and possibly a new start()) happened while this iteration was waiting in the executor queue.
            logger.warn("Task scheduling was stopped. Skipping the next iteration.");
            return;
        }

        long startNanos = System.nanoTime();
        Long previousStartNanos = lastIterationStartNanos;
        if (previousStartNanos != null) {
            metricIntervalBetweenIterationsMilliseconds = TimeUnit.NANOSECONDS.toMillis(startNanos - previousStartNanos);
        }
        lastIterationStartNanos = startNanos;

        try {
            infoLogger.info("Starting iteration {}", metricFinishedIterationsCount.get());

            taskSupplier.get()
                .orTimeout(singleIterationTimeout.toNanos(), TimeUnit.NANOSECONDS)
                .whenComplete((result, error) -> completeIterationAndScheduleNext(runToken, startNanos, error));
        } catch (Exception error) {
            // The supplier threw (or returned null): account for it like any other failed iteration.
            completeIterationAndScheduleNext(runToken, startNanos, error);
        }
    }

    /**
     * Records the outcome of one iteration and schedules the next one of the same chain.
     *
     * @param runToken   chain the finished iteration belongs to
     * @param startNanos {@link System#nanoTime()} taken when the iteration started
     * @param error      failure of the iteration, or {@code null} on success
     */
    private void completeIterationAndScheduleNext(long runToken, long startNanos, Throwable error) {
        if (error != null) {
            // Timeouts raised inside dependent stages reach us wrapped in CompletionException.
            if (unwrapCompletionException(error) instanceof TimeoutException) {
                metricHungIterationsCount.incrementAndGet();
            }
            metricFailedIterationsCount.incrementAndGet();
            logger.error("Iteration {} was failed", metricFinishedIterationsCount.get(), error);
        }

        long durationMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
        metricLastIterationDurationMilliseconds = durationMillis;
        long finishedIteration = metricFinishedIterationsCount.getAndIncrement();
        infoLogger.info("Finished iteration {} in {} ms", finishedIteration, durationMillis);

        long nextDelayMillis = Math.max(0, updateInterval.toMillis() - durationMillis);
        scheduleNextOrFail(runToken, nextDelayMillis);
    }

    /** Strips {@link CompletionException} wrappers to expose the underlying failure. */
    private static Throwable unwrapCompletionException(Throwable error) {
        Throwable current = error;
        while (current instanceof CompletionException && current.getCause() != null) {
            current = current.getCause();
        }
        return current;
    }

    /**
     * Schedules the next iteration of chain {@code runToken} after the given delay, unless the chain is no longer
     * active. A scheduling failure is treated as fatal and triggers {@code ExitUtils.gracefulExit}.
     */
    private void scheduleNextOrFail(long runToken, long delayBeforeNextExecutionMilliseconds) {
        if (!isCurrentRun(runToken)) {
            logger.warn("Task scheduling was stopped. Skipping the next iteration.");
            return;
        }

        try {
            executor.schedule(() -> runSingleIteration(runToken), delayBeforeNextExecutionMilliseconds, TimeUnit.MILLISECONDS);
        } catch (Exception exception) {
            logger.error("Scheduling of next iteration was failed. Forcing application shutdown", exception);
            ExitUtils.gracefulExit(ExitUtils.EXECUTOR_SCHEDULING_FAILURE);
        }
    }

}
