package ru.yandex.solomon.alert.unroll;

import java.time.Clock;
import java.time.Duration;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.WillNotClose;
import javax.annotation.concurrent.ThreadSafe;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.alert.domain.Alert;
import ru.yandex.solomon.alert.domain.AlertKey;
import ru.yandex.solomon.alert.rule.AlertRuleDeadlines;
import ru.yandex.solomon.alert.rule.UnrollDeadlines;
import ru.yandex.solomon.util.async.InFlightLimiter;
import ru.yandex.solomon.util.time.DurationUtils;

/**
 * Default {@link UnrollExecutor}: periodically re-runs the {@link MultiAlertUnrollFunction}
 * created for each registered alert and forwards the result to that alert's
 * {@link UnrollConsumer} whenever it differs from the previously delivered result.
 *
 * <p>At most one {@link Task} is tracked per {@link AlertKey}; registering an alert again
 * cancels the task it replaces. Scheduled runs execute on the supplied executor and go
 * through an {@link InFlightLimiter} constructed with a limit of 10.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class UnrollExecutorImpl implements UnrollExecutor {
    private static final Logger logger = LoggerFactory.getLogger(UnrollExecutorImpl.class);

    /** Base value passed to {@code DurationUtils.backoff} when retrying a task. */
    private static final long MIN_RETRY_DELAY_MILLIS = 1_000L;

    private final Clock clock;
    @WillNotClose
    private final ScheduledExecutorService executor;
    private final UnrollExecutorMetrics metrics;
    private final MultiAlertUnrollFactory unrollFactory;
    private final Duration unrollTimeout;
    private final long unrollIntervalMillis;
    private final ConcurrentMap<AlertKey, Task> taskByAlertId = new ConcurrentHashMap<>();
    private final InFlightLimiter inflightLimiter = new InFlightLimiter(10);

    /**
     * @param clock          source of the current instant for unroll deadlines and debug logs
     * @param executor       executor that runs and schedules unrolls; never closed by this class
     * @param registry       registry handed to {@link UnrollExecutorMetrics}
     * @param unrollFactory  creates the unroll function for each registered alert
     * @param unrollInterval regular period between unrolls of the same alert; also the bound of
     *                       the random scheduling delay and the limit passed to
     *                       {@code DurationUtils.backoff}. Must be at least 1 ms, because
     *                       {@link ThreadLocalRandom#nextLong(long)} rejects a non-positive bound
     * @param unrollTimeout  time budget given to {@link UnrollDeadlines#of} for every unroll
     */
    public UnrollExecutorImpl(
            Clock clock,
            @WillNotClose ScheduledExecutorService executor,
            MetricRegistry registry,
            MultiAlertUnrollFactory unrollFactory,
            Duration unrollInterval,
            Duration unrollTimeout)
    {
        this.clock = clock;
        this.executor = executor;
        this.metrics = new UnrollExecutorMetrics(registry);
        this.unrollFactory = unrollFactory;
        this.unrollIntervalMillis = unrollInterval.toMillis();
        this.unrollTimeout = unrollTimeout;

        metrics.observe(inflightLimiter);
    }

    /**
     * Registers {@code alert}, replacing and cancelling any task for the same key. The first
     * unroll is scheduled after a random delay within one unroll interval, which spreads
     * alerts registered at the same time across the interval.
     */
    @Override
    public void unroll(Alert alert, UnrollConsumer consumer) {
        Task task = prepareTask(alert, consumer);
        schedule(task, ThreadLocalRandom.current().nextLong(unrollIntervalMillis));
    }

    /**
     * Registers {@code alert} like {@link #unroll(Alert, UnrollConsumer)}, but submits the
     * first unroll to the executor immediately, bypassing the in-flight limiter. When that
     * run completes (successfully or not), the task is scheduled again after a random delay
     * within one unroll interval.
     */
    @Override
    public void unrollNow(Alert alert, UnrollConsumer consumer) {
        Task task = prepareTask(alert, consumer);
        executor.execute(() -> {
            unroll(task).whenComplete((ignored, e) -> {
                schedule(task, ThreadLocalRandom.current().nextLong(unrollIntervalMillis));
            });
        });
    }

    /**
     * Cancels every registered task and removes it from the registry. The executor itself is
     * not shut down. Runs already in progress are not interrupted; when they finish they see
     * the cancellation, skip delivering their result, and are not rescheduled.
     */
    @Override
    public void close() {
        Iterator<Map.Entry<AlertKey, Task>> it = taskByAlertId.entrySet().iterator();
        while (it.hasNext()) {
            Task task = it.next().getValue();
            task.cancel();
            it.remove();
        }
    }

    /**
     * Creates a task for {@code alert}, registers it under the alert key, and cancels the task
     * it replaces, if any.
     */
    private Task prepareTask(Alert alert, UnrollConsumer consumer) {
        MultiAlertUnrollFunction fn = unrollFactory.create(alert);
        Task task = new Task(alert.getKey(), fn, consumer);
        Task prev = taskByAlertId.put(alert.getKey(), task);
        if (prev != null) {
            prev.cancel();
        }
        return task;
    }

    /**
     * Schedules the next run of {@code task} after {@code delayMillis}. A cancelled task is
     * dropped from the registry instead.
     */
    private void schedule(Task task, long delayMillis) {
        if (task.isCanceled()) {
            removeCanceled(task);
            return;
        }

        if (logger.isDebugEnabled()) {
            logger.debug("{} will be unroll next time at {}", task.key, clock.instant().plusMillis(delayMillis));
        }
        task.future = executor.schedule(() -> execScheduledUnroll(task), delayMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Body of a scheduled run. Throttled by the in-flight limiter; it unrolls the task unless
     * it was cancelled, then schedules the next run regardless of the outcome.
     */
    private void execScheduledUnroll(Task task) {
        inflightLimiter.run(() -> {
            if (task.isCanceled()) {
                removeCanceled(task);
                return CompletableFuture.completedFuture(null);
            }

            logger.debug("{} start unrolling", task.key);
            return unroll(task).handle((ignored, e) -> {
                schedule(task, nextDelayMillis(task));
                return null;
            });
        });
    }

    /**
     * Chooses the randomized delay before the next scheduled run of {@code task}.
     *
     * <p>While no result has been delivered, or the last delivered result does not allow
     * deletes, the delay comes from {@code DurationUtils.backoff} fed with the incremented
     * attempt counter. Once the last delivered result allows deletes, the counter is reset and
     * the regular interval is used. Only the last delivered result is consulted, not whether
     * the run that just finished failed.
     */
    private long nextDelayMillis(Task task) {
        UnrollResult prev = task.prevUnroll;
        if (prev == null || !prev.allowDelete) {
            long delay = DurationUtils.backoff(MIN_RETRY_DELAY_MILLIS, unrollIntervalMillis, task.attempt.incrementAndGet());
            return DurationUtils.randomize(delay);
        }

        task.attempt.set(0);
        return DurationUtils.randomize(unrollIntervalMillis);
    }

    /**
     * Runs the task's unroll function once with a fresh deadline.
     *
     * <p>A result that differs from the last delivered one is passed to the consumer and
     * remembered. If the task was cancelled meanwhile, the result is discarded and the task is
     * removed from the registry. Failures are logged; the duration and outcome are always
     * reported to metrics.
     *
     * @return a future that completes exceptionally if the unroll function or the consumer fails
     */
    private CompletableFuture<Void> unroll(Task task) {
        AlertRuleDeadlines deadlines = UnrollDeadlines.of(clock.instant(), unrollTimeout);

        // Measure with the monotonic clock: wall-clock time can jump and distort the metric.
        long startNanos = System.nanoTime();
        return CompletableFuture.completedFuture(deadlines)
                .thenCompose(task.fn::unroll)
                .thenAccept(fresh -> {
                    if (task.isCanceled()) {
                        removeCanceled(task);
                        return;
                    }

                    logger.debug("{} alert unroll into {} groups", task.key, fresh.labels.size());
                    // Only notify the consumer about actual changes.
                    if (!Objects.equals(task.prevUnroll, fresh)) {
                        task.consumer.accept(fresh);
                        task.prevUnroll = fresh;
                    }
                })
                .whenComplete((ignored, e) -> {
                    boolean failed = e != null;
                    if (failed) {
                        logger.error("Failed unrolling for alert {}", task.key, e);
                    }
                    long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
                    metrics.unrollFinished(task.key.getProjectId(), elapsedMillis, failed);
                });
    }

    /**
     * Removes {@code task} only while it is still the registered task for its key, so a
     * replacement registered for the same alert is left untouched.
     */
    private void removeCanceled(Task task) {
        taskByAlertId.remove(task.key, task);
    }

    /**
     * Per-alert unroll state. {@link #cancel()} marks the task cancelled and cancels its
     * currently scheduled run without interrupting one already running. The task also counts
     * as cancelled once its consumer reports {@link UnrollConsumer#isCanceled()}.
     */
    @ThreadSafe
    private static class Task {
        private final AlertKey key;
        private final MultiAlertUnrollFunction fn;
        private final UnrollConsumer consumer;
        /** Last result handed to the consumer; {@code null} until the first delivery. */
        private volatile UnrollResult prevUnroll = null;
        private volatile boolean canceled;
        /** Most recently scheduled run; {@code null} until the first schedule. */
        private volatile Future<?> future;
        /** Consecutive backoff retries; reset once a delivered result allows deletes. */
        private final AtomicInteger attempt = new AtomicInteger();

        private Task(AlertKey key, MultiAlertUnrollFunction fn, UnrollConsumer consumer) {
            this.key = key;
            this.fn = fn;
            this.consumer = consumer;
        }

        private void cancel() {
            this.canceled = true;
            Future<?> future = this.future;
            if (future != null) {
                future.cancel(false);
            }
        }

        public boolean isCanceled() {
            return consumer.isCanceled() || canceled;
        }
    }
}
