package ru.yandex.solomon.scheduler;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;

import io.grpc.Status;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.locks.LockDetail;
import ru.yandex.solomon.scheduler.ProgressOperator.Fail;
import ru.yandex.solomon.scheduler.ProgressOperator.Ok;
import ru.yandex.solomon.scheduler.ProgressOperator.Stop;
import ru.yandex.solomon.util.actors.PingActorRunner;
import ru.yandex.solomon.util.future.RetryConfig;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static ru.yandex.misc.concurrent.CompletableFutures.unwrapCompletionException;
import static ru.yandex.solomon.util.future.RetryCompletableFuture.runWithRetries;

/**
 * Task scheduler that picks due tasks from {@code deps.dao} and runs them locally through {@link TaskPipeline}.
 *
 * <p>Each ping of the internal {@link PingActorRunner} walks the {@link State} machine:
 * <ol>
 *     <li>{@code LIST_SCHEDULED} - load tasks whose execute time has come;</li>
 *     <li>{@code LIST_RUNNING} - load current locks to learn which of them are already running;</li>
 *     <li>{@code START} - start pipelines for the remaining candidates in random order while fewer than
 *     {@code maxInflight} pipelines are running on this node.</li>
 * </ol>
 *
 * @author Vladimir Gordiychuk
 */
public class TaskSchedulerImpl implements TaskScheduler, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(TaskSchedulerImpl.class);

    /**
     * Retry policy for {@link #reschedule}: 10 retries, 1 ms delay, 50 ms max delay.
     * {@link FailedRescheduleException} is excluded from retries so that a failure requested by the
     * caller's {@link ProgressOperator} is surfaced immediately.
     */
    private static final RetryConfig RETRY_CONFIG = RetryConfig.DEFAULT
        .withNumRetries(10)
        .withDelay(TimeUnit.MILLISECONDS.toMillis(1))
        .withMaxDelay(TimeUnit.MILLISECONDS.toMillis(50))
        .withExceptionFilter(t -> !(unwrapCompletionException(t) instanceof FailedRescheduleException));

    /** Maximum number of pipelines allowed to run concurrently on this node. */
    private final int maxInflight;
    private final TaskDeps deps;
    private final PingActorRunner actor;

    // Actor state. NOTE(review): assumes PingActorRunner invokes act() serially; confirm.
    private State state = State.LIST_SCHEDULED;
    private List<ScheduledTask> scheduledCandidates = List.of();
    /** Snapshot of locks from the last LIST_RUNNING step, keyed by lock id. */
    private Map<String, LockDetail> runningTasks = Map.of();
    /** Pipelines started by this instance; entries are removed when the pipeline future completes. */
    private final ConcurrentHashMap<String, TaskPipeline> localRunningTasks = new ConcurrentHashMap<>();
    private final TaskMetrics metrics = new TaskMetrics();

    /**
     * Creates the scheduler and immediately schedules its periodic actor.
     *
     * @param interval    ping interval of the actor
     * @param maxInflight maximum number of pipelines running concurrently on this node
     * @param deps        DAOs, executors, clock and timer used by the scheduler
     */
    public TaskSchedulerImpl(Duration interval, int maxInflight, TaskDeps deps) {
        this.deps = deps;
        this.maxInflight = maxInflight;
        this.actor = PingActorRunner.newBuilder()
                .operation("task_scheduler")
                .executor(deps.executor)
                .pingInterval(interval)
                .timer(deps.timer)
                .onPing(this::act)
                .build();

        this.actor.schedule();
    }

    /** Returns the metrics collected by this scheduler. */
    public TaskMetrics metrics() {
        return metrics;
    }

    @Override
    public CompletableFuture<Optional<Task>> getTask(String taskId) {
        return deps.dao.get(taskId);
    }

    /**
     * Persists a new task. If the DAO accepted it and it is already due, an actor run is forced
     * so the task does not wait for the next ping.
     */
    @Override
    public CompletableFuture<Void> schedule(Task task) {
        return deps.dao.add(task).thenAccept(success -> {
            if (!success) {
                return;
            }

            metrics.getByType(task.type()).schedule.inc();
            if (task.executeAt() <= deps.clock.millis()) {
                forceAct();
            }
        });
    }

    /**
     * Moves a task to a new execute time and updates its progress via {@code progressOperator}.
     * Transient failures are retried according to {@link #RETRY_CONFIG}.
     *
     * @return {@code true} if rescheduled; {@code false} if the task is completed or the operator
     *     returned {@link Stop}. When the operator returns {@link Fail}, the future fails with the
     *     cause supplied by the operator.
     */
    @Override
    public CompletableFuture<Boolean> reschedule(String taskId, long executeAt, ProgressOperator progressOperator) {
        var future = new CompletableFuture<Boolean>();

        runWithRetries(() -> tryReschedule(taskId, executeAt, progressOperator), RETRY_CONFIG)
            .whenComplete((r, t) -> {
                if (t != null) {
                    // Unwrap the operator-requested failure so callers see the original cause.
                    if (unwrapCompletionException(t) instanceof FailedRescheduleException fre) {
                        future.completeExceptionally(fre.getCause());
                    } else {
                        future.completeExceptionally(t);
                    }
                } else {
                    future.complete(r);
                }
            });

        return future;
    }

    /** Requests an out-of-band actor run. */
    public void forceAct() {
        actor.forcePing();
    }

    /**
     * Single reschedule attempt: reads the task, applies the operator and writes the result
     * conditioned on the read {@code task.version()}.
     *
     * <p>A missing task fails with {@code NOT_FOUND}. A rejected write fails with {@code ABORTED}
     * (presumably a concurrent modification) so that {@link #RETRY_CONFIG} retries the attempt.
     */
    private CompletableFuture<Boolean> tryReschedule(
        String taskId,
        long executeAt,
        ProgressOperator progressOperator)
    {
        return getTask(taskId).thenCompose(t -> {
            var task = t.orElseThrow(() -> Status.NOT_FOUND.withDescription(taskId).asRuntimeException());
            if (task.state() == Task.State.COMPLETED) {
                return completedFuture(Boolean.FALSE);
            }

            var result = progressOperator.apply(task.progress());
            if (result instanceof Stop) {
                return completedFuture(Boolean.FALSE);
            }
            if (result instanceof Fail fail) {
                // Wrapped so RETRY_CONFIG's exception filter stops retrying.
                return failedFuture(new FailedRescheduleException(fail.cause()));
            }

            return deps.dao.rescheduleExternally(taskId, executeAt, ((Ok) result).newProgress(), task.version())
                .thenCompose(success -> {
                    if (success) {
                        metrics.getByType(task.type()).rescheduleExternally.inc();
                        if (task.state() == Task.State.SCHEDULED && executeAt <= deps.clock.millis()) {
                            forceAct();
                        }

                        return completedFuture(Boolean.TRUE);
                    }

                    return failedFuture(Status.ABORTED.withDescription("unable to reschedule").asRuntimeException());
                });
        });
    }

    /**
     * One actor iteration: refresh the candidate list and start as many tasks as the local
     * in-flight limit allows. The iteration is skipped while the limit is already reached.
     */
    private CompletableFuture<Void> act(int attempt) {
        // '>=' rather than '==': if the local set ever exceeds the limit, an exact-match check
        // would stop throttling instead of continuing to hold back new starts.
        if (localRunningTasks.size() >= maxInflight) {
            return completedFuture(null);
        }

        return stepListScheduled()
                .thenRun(this::filterCandidates)
                .thenCompose(unused -> stepListRunning())
                .thenRun(this::filterCandidates)
                .thenRun(this::stepStart);
    }

    /** LIST_SCHEDULED: loads tasks due at the current clock time into {@link #scheduledCandidates}. */
    private CompletableFuture<Void> stepListScheduled() {
        if (state != State.LIST_SCHEDULED) {
            return completedFuture(null);
        }

        // Presumably over-fetches so enough candidates remain after dropping already-running tasks.
        var limit = Math.max(runningTasks.size() + (maxInflight * 100), 1000);
        return deps.dao.listScheduled(deps.clock.millis(), limit)
                .thenAccept(tasks -> {
                    scheduledCandidates = tasks;
                    metrics.updateScheduled(tasks);
                    state = state.next();
                });
    }

    /** LIST_RUNNING: refreshes the lock snapshot; skipped when there are no candidates to check. */
    private CompletableFuture<Void> stepListRunning() {
        if (state != State.LIST_RUNNING) {
            return completedFuture(null);
        }

        if (scheduledCandidates.isEmpty()) {
            state = state.next();
            return completedFuture(null);
        }

        return deps.locksDao.listLocks()
                .thenAccept(locks -> {
                    runningTasks = locks.stream().collect(Collectors.toMap(LockDetail::id, Function.identity()));
                    state = state.next();
                });
    }

    /**
     * START: starts a pipeline for each candidate that obtains a permit from
     * {@code deps.taskExecutor}, until {@code maxInflight} pipelines run locally.
     */
    private void stepStart() {
        for (var scheduled : scheduledCandidates) {
            // runningTasks is a snapshot that does not change inside this loop; the limit has to be
            // checked against the local map, which grows with every pipeline started below.
            if (localRunningTasks.size() >= maxInflight) {
                break;
            }

            var permit = deps.taskExecutor.acquire(scheduled.id(), scheduled.type(), scheduled.params());
            if (permit == null) {
                // The executor declined this task; try the next candidate.
                continue;
            }

            String id = scheduled.id();
            var typeMetrics = metrics.getByType(scheduled.type());

            TaskPipeline pipeline = new TaskPipeline(id, typeMetrics, deps);
            localRunningTasks.put(id, pipeline);

            long startNanos = System.nanoTime();
            var future = pipeline.start();
            typeMetrics.lagMillis.record(deps.clock.millis() - scheduled.executeAt());
            typeMetrics.pipeline.forFuture(future);

            future.whenComplete((status, e) -> {
                // Two-arg remove: only drop the entry if it still maps to this pipeline.
                localRunningTasks.remove(id, pipeline);
                if (e != null) {
                    typeMetrics.completePipeline(Status.fromThrowable(e).getCode());
                    logger.warn("Task pipeline {} failed", id, e);
                    permit.release();
                } else if (!status.isOk()) {
                    typeMetrics.completePipeline(status.getCode());
                    logger.warn("Task {} failed: {}", id, status);
                    permit.cancel();
                } else {
                    typeMetrics.elapsedMillis.record(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos));
                    typeMetrics.completePipeline(status.getCode());
                    logger.debug("Task {} completed", id);
                    permit.release();
                }

                // A slot was freed: try to start more tasks right away.
                forceAct();
            });
        }

        state = state.next();
    }

    /**
     * Drops candidates already running locally or held by a lock that has not expired yet,
     * then shuffles the rest.
     */
    private void filterCandidates() {
        if (scheduledCandidates.isEmpty()) {
            return;
        }

        long now = deps.clock.millis();
        var result = new ArrayList<ScheduledTask>(scheduledCandidates.size());
        for (var task : scheduledCandidates) {
            if (localRunningTasks.containsKey(task.id())) {
                continue;
            }

            var lock = runningTasks.get(task.id());
            if (lock != null && lock.expiredAt().toEpochMilli() > now) {
                continue;
            }

            result.add(task);
        }

        Collections.shuffle(result);
        scheduledCandidates = result;
    }

    /** Stops the periodic actor. */
    @Override
    public void close() {
        actor.close();
    }

    /** Steps of a single actor iteration, in cyclic order. */
    public enum State {
        LIST_SCHEDULED, LIST_RUNNING, START;

        private State next() {
            return next(this);
        }

        /** Returns the step that follows {@code state}; START wraps back to LIST_SCHEDULED. */
        public static State next(State state) {
            return switch (state) {
                case LIST_SCHEDULED -> LIST_RUNNING;
                case LIST_RUNNING -> START;
                case START -> LIST_SCHEDULED;
            };
        }
    }
}
