package ru.yandex.solomon.alert.evaluation;

import java.time.Clock;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalLong;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.WillNotClose;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.alert.EvaluationStatus;
import ru.yandex.solomon.alert.domain.AlertKey;
import ru.yandex.solomon.alert.executor.local.LocalAlertExecutorOptions;
import ru.yandex.solomon.alert.rule.AlertRule;
import ru.yandex.solomon.alert.rule.AlertRuleDeadlines;
import ru.yandex.solomon.alert.rule.AlertRuleFairDeadlines;
import ru.yandex.solomon.alert.rule.EvaluationState;
import ru.yandex.solomon.staffOnly.annotations.ManagerMethod;
import ru.yandex.solomon.util.ExceptionUtils;
import ru.yandex.solomon.util.time.DurationUtils;

/**
 * @author Vladimir Gordiychuk
 */
public class TaskExecutorImpl implements TaskExecutor {
    private static final Logger logger = LoggerFactory.getLogger(TaskExecutorImpl.class);
    /**
     * Status substituted when an evaluation does not complete within {@code alertRuleTimeoutMillis}.
     * {@link #completeTask} recognises it by reference identity, so this exact instance must be used.
     */
    private static final EvaluationStatus ALERT_RULE_DEADLINE = EvaluationStatus.DEADLINE
        .withDescription("Common alert rule timeout");

    private final Clock clock;
    @WillNotClose
    private final ExecutorService executor;
    @WillNotClose
    private final ScheduledExecutorService timer;
    private final LocalAlertExecutorOptions options;
    private final TaskExecutorMetrics metrics;

    /**
     * Tasks waiting for a regular (non-first) evaluation, sorted by evaluation time.
     */
    private final BlockingQueue<Task> queue = new PriorityBlockingQueue<>(10_000);

    /**
     * Tasks that were not evaluated before and require warmup (their first evaluation).
     */
    private final BlockingQueue<Task> warmupQueue = new PriorityBlockingQueue<>(1000);

    /** The current task of every tracked alert; stale tasks are replaced or removed. */
    private final ConcurrentMap<AlertKey, Task> taskByAlertKey = new ConcurrentHashMap<>(10_000);
    /** Evaluations currently in flight (warmup and regular). */
    private final AtomicInteger countActiveTask = new AtomicInteger(0);
    /** Warmup evaluations currently in flight (a subset of {@link #countActiveTask}). */
    private final AtomicInteger countWarmupTask = new AtomicInteger(0);

    private final ActorRunner actor;
    private final TaskEvaluationMetrics evaluationMetrics;
    /**
     * Wall-clock millis at which the pending timer wake-up fires. It is only touched from
     * {@link #act()} and the methods it calls.
     */
    private long scheduledAt;
    /** Pending timer wake-up, if any; only touched from {@link #act()} and its callees. */
    private ScheduledFuture<?> scheduled;
    private volatile boolean cancel;
    /** Upper bound on {@link #taskByAlertKey} size; {@code 0} means unlimited. */
    private volatile long maxAlertRule;

    public TaskExecutorImpl(
            Clock clock,
            @WillNotClose ExecutorService executor,
            @WillNotClose ScheduledExecutorService timer,
            LocalAlertExecutorOptions options,
            MetricRegistry registry,
            TaskEvaluationMetrics evaluationMetrics)
    {
        this.clock = clock;
        this.executor = executor;
        this.timer = timer;
        this.options = options;
        this.maxAlertRule = options.getMaxAlertRules();
        this.metrics = new TaskExecutorMetrics(registry);
        this.actor = new ActorRunner(this::act, executor);
        this.evaluationMetrics = evaluationMetrics;
        metrics.lazyEvalInFlight(countActiveTask::get);
        metrics.lazyWarmupInFlight(countWarmupTask::get);
        metrics.lazyQueueSize(taskByAlertKey::size);
        metrics.lazyWarmupQueueSize(warmupQueue::size);
        metrics.lazyRegularQueueSize(queue::size);
        // The available-capacity gauge only makes sense when a limit was configured at startup.
        if (maxAlertRule > 0) {
            metrics.lazyQueueAvailableCapacity(() -> Math.max(0L, maxAlertRule - taskByAlertKey.size()));
        }
    }

    public TaskExecutorMetrics getMetrics() {
        return metrics;
    }

    /** Returns the live map of current tasks by alert key (not a copy). */
    public Map<AlertKey, Task> getTasks() {
        return taskByAlertKey;
    }

    /**
     * Changes the maximum number of tracked alert rules; {@code 0} removes the limit.
     * When the new limit is below the current number of tasks, tasks are cancelled and evicted
     * in map iteration order until the map fits.
     *
     * @param limit new maximum number of tracked alert rules, or {@code 0} for unlimited
     */
    @ManagerMethod
    public void setMaxAlertRule(int limit) {
        this.maxAlertRule = limit;
        if (limit == 0) {
            return;
        }

        var it = taskByAlertKey.values().iterator();
        while (it.hasNext() && limit < taskByAlertKey.size()) {
            var task = it.next();
            task.cancel();
            it.remove();
        }
    }

    @VisibleForTesting
    void scheduleAct() {
        actor.schedule();
    }

    /**
     * Registers {@code newTask} as the current task of its alert and queues it for warmup.
     *
     * <p>If an active task already exists for the same alert, the one whose consumer has the
     * greater seqNo wins (ties go to {@code newTask}), and the loser is cancelled. An accepted task
     * is dropped and cancelled again if it pushes the map over {@link #maxAlertRule}.
     */
    @Override
    public void scheduleNewTask(Task newTask) {
        Task actual = taskByAlertKey.compute(newTask.getAlertKey(), (k, v) -> {
            if (v == null || !v.isActive()) {
                return newTask;
            }

            int compare = v.getConsumer().getSeqNo().compareTo(newTask.getConsumer().getSeqNo());
            if (compare <= 0) {
                v.cancel();
                return newTask;
            } else {
                return v;
            }
        });

        if (Objects.equals(actual, newTask)) {
            if (isMaxTaskLimitReach()) {
                taskByAlertKey.remove(newTask.getAlertKey(), newTask);
                newTask.cancel();
                return;
            }

            scheduleWarmup(newTask);
        } else {
            newTask.cancel();
        }
    }

    /** Returns {@code true} when a limit is set and the task map currently exceeds it. */
    private boolean isMaxTaskLimitReach() {
        long limit = this.maxAlertRule;
        if (limit == 0) {
            return false;
        }

        return taskByAlertKey.size() > limit;
    }

    /**
     * Removes and cancels the current task of {@code alertKey}.
     *
     * @return {@code true} if a task was present
     */
    @Override
    public boolean cancelTask(AlertKey alertKey) {
        Task previous = taskByAlertKey.remove(alertKey);
        if (previous == null) {
            return false;
        }

        previous.cancel();
        return true;
    }

    /** Puts a task into the regular queue and wakes the actor. */
    private void schedule(Task task) {
        queue.add(task);
        if (logger.isDebugEnabled()) {
            logger.debug("{} scheduled on evaluate {} hash {}", task.getAlertKey(), task.getIterationTime(), task.hashCode());
        }
        actor.schedule();
    }

    /** Puts a task into the warmup queue and wakes the actor. */
    private void scheduleWarmup(Task task) {
        warmupQueue.add(task);
        if (logger.isDebugEnabled()) {
            logger.debug("{} scheduled warmup on evaluate {} hash {}", task.getAlertKey(), task.getIterationTime(), task.hashCode());
        }
        actor.schedule();
    }

    /** Arranges a wake-up for the earliest head of either queue; no-op when both are empty. */
    private void rescheduleFirstFromQueue() {
        OptionalLong delay = getNextTaskDelay();
        if (delay.isEmpty()) {
            return;
        }

        scheduleActWithDelay(delay.getAsLong());
    }

    /**
     * Returns millis until the head of {@code source} is due (negative if overdue), or empty when
     * {@code source} has no elements.
     */
    private OptionalLong getDelayToEvaluate(BlockingQueue<Task> source) {
        Task task = source.peek();
        if (task == null) {
            return OptionalLong.empty();
        }

        long delayMillis = task.getEvaluateAt() - clock.millis();
        return OptionalLong.of(delayMillis);
    }

    /** Returns the smaller of the warmup and regular head delays, or empty if both queues are empty. */
    private OptionalLong getNextTaskDelay() {
        OptionalLong warmup = getDelayToEvaluate(warmupQueue);
        OptionalLong regular = getDelayToEvaluate(queue);
        return Stream.of(warmup, regular)
            .filter(OptionalLong::isPresent)
            .mapToLong(OptionalLong::getAsLong)
            .min();
    }

    /**
     * Ensures the actor runs again no later than {@code delayMillis} from now.
     *
     * <p>Due work ({@code delayMillis <= 0}) wakes the actor immediately. Otherwise a timer is
     * armed, unless the already armed timer fires in the future and no later than required.
     */
    private void scheduleActWithDelay(long delayMillis) {
        if (cancel) {
            return;
        }

        if (delayMillis <= 0) {
            actor.schedule();
            return;
        }

        long now = clock.millis();
        long evaluateAt = now + delayMillis;
        if (scheduledAt <= now || scheduledAt > evaluateAt) {
            if (scheduled != null) {
                scheduled.cancel(false);
            }
            scheduledAt = evaluateAt;
            try {
                // Wake the actor instead of invoking act() on the timer thread: every other
                // wake-up goes through the actor, and running act() outside it would let two
                // act() loops race on the in-flight counters and on scheduledAt/scheduled.
                scheduled = timer.schedule(actor::schedule, delayMillis, TimeUnit.MILLISECONDS);
            } catch (RejectedExecutionException e) {
                // Timer refused the wake-up (typically because it is shutting down); deliberately ignored.
            }
        }
    }

    /** Queues {@code nextTask} only if {@code prevTask} is still the current task of its alert. */
    private void scheduleNextIteration(Task prevTask, Task nextTask) {
        if (taskByAlertKey.replace(prevTask.getAlertKey(), prevTask, nextTask)) {
            schedule(nextTask);
        }
    }

    /**
     * Records the result of one evaluation, reports it to the task's consumer and, if the task is
     * still active, schedules the follow-up task (a retry or the next iteration).
     */
    private void completeTask(Task task, EvaluationStatus status, long startTimeNano) {
        // Identity check: only the shared timeout instance from runTask counts as a timeout.
        if (status == ALERT_RULE_DEADLINE) {
            evaluationMetrics.evalTimeout(task.getAlert());
        }
        metrics.markEvaluationCompleted();
        evaluationMetrics.completeEval(task.getAlert(), task.getStatus(), status, startTimeNano);
        if (logger.isDebugEnabled()) {
            logger.debug("{} - for {} at {}, details {}",
                status.getCode(),
                task.getAlertKey(),
                task.getIterationTime(),
                status.getDescription());
        }

        if (!task.isActive()) {
            taskByAlertKey.remove(task.getAlertKey(), task);
            return;
        }

        Task nextTask = prepareNextTask(task, status);
        try {
            // For a retry (attempt > 0) report the state derived from this status; otherwise the
            // next task already carries that new state.
            final var reportState = nextTask.getAttempt() > 0
                ? task.nextState(status)
                : nextTask.getState();
            var from = task.getState() == null ? null : task.getState().getStatus().getCode();
            var to = status.getCode();
            if (from != to) {
                logger.info("{} -> {} for {} at {}, details {}", from, to, task.getAlertKey(), task.getIterationTime(), status);
            }
            task.getConsumer().consume(Objects.requireNonNull(reportState));
        } catch (Throwable e) {
            logger.error("Failed report about evaluation status {} for alert {}", status, task.getAlert(), e);
        }
        scheduleNextIteration(task, nextTask);
    }

    /**
     * Builds the follow-up task.
     *
     * <p>Retryable statuses keep the same iteration (truncated to the max lag) and re-run after a
     * randomized backoff of 1s up to {@code evalIntervalMillis}. Other statuses advance to the next
     * iteration ({@code iteration + evalIntervalMillis}, truncated to the max lag). That iteration
     * runs with jitter when in the future, or immediately when already due.
     */
    private Task prepareNextTask(Task task, EvaluationStatus status) {
        long now = clock.millis();
        if (isRetryable(status)) {
            long nextIteration = truncateMaxLag(now, task.getIterationTimeMillis());
            long backoff = DurationUtils.backoff(1_000, options.getEvalIntervalMillis(), task.getAttempt() + 1);
            long evaluateAt = now + DurationUtils.randomize(backoff);
            logger.info("{} retry attempt {} caused by {}", task.getAlertKey(), task.getAttempt(), status);
            return task.withRetry(evaluateAt, nextIteration);
        }

        final EvaluationState newState = task.nextState(status);
        long nextMillis = truncateMaxLag(now, task.getIterationTimeMillis() + options.getEvalIntervalMillis());
        long jitter = randomBelow(options.getMaxEvaluationJitterMillis());
        long evaluateAt = nextMillis > now ? nextMillis + jitter : now;
        return task.withEvalTime(evaluateAt, nextMillis, newState);
    }

    /** Transient failures ({@code UNAVAILABLE}, {@code DEADLINE}) are retried for the same iteration. */
    private boolean isRetryable(EvaluationStatus status) {
        return switch (status.getErrorCode()) {
            case UNAVAILABLE, DEADLINE -> true;
            default -> false;
        };
    }

    /**
     * Returns {@code iterationMillis} unchanged while it lags {@code now} by less than
     * {@code maxEvaluationLagMillis}. Otherwise it skips ahead to a random point within the last
     * {@code evalIntervalMillis} before {@code now}.
     */
    private long truncateMaxLag(long now, long iterationMillis) {
        if (now < iterationMillis) {
            return iterationMillis;
        }

        long lagMillis = now - iterationMillis;
        if (lagMillis < options.getMaxEvaluationLagMillis()) {
            return iterationMillis;
        }

        return now - randomBelow(options.getEvalIntervalMillis());
    }

    /**
     * Returns a uniformly random value in {@code [0, bound)}, or {@code 0} when {@code bound} is
     * not positive. {@link ThreadLocalRandom#nextLong(long)} throws for a non-positive bound. Such a
     * throw inside {@link #completeTask} would skip {@link #scheduleNextIteration}, so the alert
     * would silently stop being evaluated.
     */
    private static long randomBelow(long bound) {
        return bound > 0 ? ThreadLocalRandom.current().nextLong(bound) : 0L;
    }

    /** Returns {@code true} (and records the overflow metric) when the in-flight limit is reached. */
    private boolean isQuoteOverflow() {
        if (countActiveTask.get() >= options.getMaxConcurrentAlertRules()) {
            metrics.evalOverflow();
            return true;
        }

        return false;
    }

    /**
     * Polls the first active task from {@code queue}. Inactive tasks encountered on the way are
     * discarded and unmapped (only if still mapped to that exact task).
     */
    @Nullable
    private Task pollNext(Queue<Task> queue) {
        while (true) {
            Task task = queue.poll();
            if (task == null) {
                return null;
            }

            if (task.isActive()) {
                return task;
            }

            taskByAlertKey.remove(task.getAlertKey(), task);
        }
    }

    /**
     * Actor body: starts every due task allowed by the concurrency limits (warmup first), then arms
     * a wake-up for the next pending task. After {@link #close()} it just clears the queues and
     * cancels the pending timer.
     */
    private void act() {
        if (cancel) {
            if (scheduled != null) {
                scheduled.cancel(false);
            }
            warmupQueue.clear();
            queue.clear();
            return;
        }

        boolean repeat;
        do {
            repeat = warmupRun() || regularRun();
        } while (repeat);
        rescheduleFirstFromQueue();
    }

    /**
     * Tries to start one due warmup task.
     *
     * @return {@code true} if a task was started
     */
    private boolean warmupRun() {
        if (isQuoteOverflow()) {
            return false;
        }

        Task task = pollNext(warmupQueue);
        if (task == null) {
            return false;
        }

        long now = clock.millis();
        if (task.getEvaluateAt() > now) {
            warmupQueue.add(task);
            return false;
        }

        if (countWarmupTask.get() >= options.getMaxConcurrentWarmupAlertRules()) {
            warmupQueue.add(task);
            return false;
        }

        countActiveTask.incrementAndGet();
        countWarmupTask.incrementAndGet();
        runSafeTask(now, task).whenComplete((ignore1, ignore2) -> {
            countActiveTask.decrementAndGet();
            countWarmupTask.decrementAndGet();
            // A slot was freed: let the actor start more work.
            actor.schedule();
        });

        return true;
    }

    /**
     * Tries to start one due regular task.
     *
     * @return {@code true} if a task was started
     */
    private boolean regularRun() {
        if (isQuoteOverflow()) {
            return false;
        }

        Task task = pollNext(queue);
        if (task == null) {
            return false;
        }

        long now = clock.millis();
        if (task.getEvaluateAt() > now) {
            queue.add(task);
            return false;
        }

        countActiveTask.incrementAndGet();
        runSafeTask(now, task).whenComplete((ignore1, ignore2) -> {
            countActiveTask.decrementAndGet();
            // A slot was freed: let the actor start more work.
            actor.schedule();
        });

        return true;
    }

    /**
     * Runs {@link #runTask} on {@code executor}. An exception thrown synchronously by
     * {@code runTask} completes the returned future exceptionally instead of escaping, so the
     * caller's completion callback still releases the concurrency slot.
     */
    private CompletableFuture<?> runSafeTask(long now, Task task) {
        return CompletableFuture.supplyAsync(() -> runTask(now, task), executor)
            .thenCompose(f -> f);
    }

    /**
     * Evaluates the task's rule for its iteration time with lag-adjusted deadlines. Failures are
     * mapped to an {@code ERROR} status, and evaluations exceeding {@code alertRuleTimeoutMillis}
     * to {@link #ALERT_RULE_DEADLINE}. The outcome is then handed to {@link #completeTask}.
     */
    private CompletableFuture<Void> runTask(long now, Task task) {
        AlertRule rule = task.getRule();
        Instant iteration = task.getIterationTime();
        long lag = now - iteration.toEpochMilli();
        AlertRuleDeadlines deadlines = AlertRuleFairDeadlines.adjustToLag(now, lag, options.getAlertRuleTimeoutMillis());
        logger.debug("{} - running now {}, deadline {}", task.getAlertKey(), iteration, deadlines);
        evaluationMetrics.startEval(rule.getAlert(), task.getIterationTimeMillis());
        long startTimeNano = System.nanoTime();
        return CompletableFutures.safeCall(() -> rule.eval(iteration, deadlines))
            .exceptionally(e -> {
                logger.debug("{} - for {} caused by {}",
                    EvaluationStatus.ERROR, task.getAlertKey(), e.getMessage(), e);
                return EvaluationStatus.ERROR.withDescription(Throwables.getStackTraceAsString(e));
            })
            .completeOnTimeout(ALERT_RULE_DEADLINE, options.getAlertRuleTimeoutMillis(), TimeUnit.MILLISECONDS)
            .thenAccept(status -> completeTask(task, status, startTimeNano))
            .exceptionally(throwable -> {
                // Last resort: completeTask itself failed.
                logger.error("Exception on eval no catch", throwable);
                ExceptionUtils.uncaughtException(throwable);
                return null;
            });
    }

    /**
     * Stops the executor: marks it cancelled, cancels and evicts every task, and wakes the actor
     * so it clears the queues and cancels the pending timer. Executor and timer are not closed.
     */
    @Override
    public void close() {
        cancel = true;
        var it = taskByAlertKey.values().iterator();
        while (it.hasNext()) {
            var task = it.next();
            task.cancel();
            it.remove();
        }
        actor.schedule();
    }
}
