package ru.yandex.solomon.scheduler;

import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Supplier;

import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.protobuf.Message;
import io.grpc.Status;
import io.grpc.Status.Code;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.util.Proto;
import ru.yandex.solomon.util.actors.PingActorRunner;
import ru.yandex.solomon.util.future.RetryCompletableFuture;
import ru.yandex.solomon.util.future.RetryConfig;

import static java.util.concurrent.CompletableFuture.completedFuture;

/**
 * Drives a single task through the {@link State} sequence
 * {@code LOCK -> START -> LOAD -> RUN -> DONE}.
 *
 * <p>Every step is an asynchronous call executed through {@code retry(...)}. When a step finishes
 * with an OK {@link Status} the pipeline advances to the next state; a non-OK status (or an
 * exception, converted with {@link Status#fromThrowable}) moves the pipeline straight to
 * {@code DONE}. While the pipeline is in {@code RUN}, {@code extendLock} is pinged by
 * {@code extendLockActor} to prolong the lease on the task lock. The future returned by
 * {@link #start()} is completed with the resulting status.
 *
 * @author Vladimir Gordiychuk
 */
public class TaskPipeline {
    private static final Logger logger = LoggerFactory.getLogger(TaskPipeline.class);

    /** Lease duration requested each time the lock is acquired or extended. */
    private static final long LEASE_TIME_MILLIS = 30_000;
    /**
     * Interval between lock-extension pings; also used as the safety margin when deciding
     * that the current lease is about to expire (see {@code extendLock}).
     */
    private static final long LEASE_EXTEND_INTERVAL_MILLIS = 5_000;
    // NOTE(review): delay values are presumably milliseconds - confirm against RetryConfig.
    private static final RetryConfig RETRY_CONFIG = RetryConfig.DEFAULT
            .withDelay(10)
            .withMaxDelay(60_000)
            .withNumRetries(10);

    /** CAS access to {@link #state}; used by {@code switchState}. */
    private static final AtomicReferenceFieldUpdater<TaskPipeline, State> stateUpdater =
            AtomicReferenceFieldUpdater.newUpdater(TaskPipeline.class, State.class, "state");

    private final String taskId;
    private final TaskMetrics.Type metrics;
    private final TaskDeps deps;
    /** Periodically invokes {@code extendLock}; scheduled once the lock has been acquired. */
    private final PingActorRunner extendLockActor;
    /** Result handed out by {@link #start()}. */
    private final CompletableFuture<Status> doneFuture = new CompletableFuture<>();

    /** Current pipeline step; changed only through {@code switchState}. */
    private volatile State state = State.LOCK;
    /** Task definition, assigned by {@code stepLoad}. */
    private Task task;
    /** Sequence number of the acquired lock, assigned by {@code stepLock}; stays 0 until then. */
    private long seqNo;
    /** Epoch millis at which the currently requested lease ends. */
    private long expiredAt;

    /**
     * @param taskId  id of the task to process
     * @param metrics metrics sink for this task type
     * @param deps    collaborators (DAOs, clock, executor, timer, node identity, task executor)
     */
    public TaskPipeline(String taskId, TaskMetrics.Type metrics, TaskDeps deps) {
        this.taskId = taskId;
        this.metrics = metrics;
        this.deps = deps;
        this.extendLockActor = PingActorRunner.newBuilder()
                .executor(deps.executor)
                .timer(deps.timer)
                .onPing(this::extendLock)
                .pingInterval(Duration.ofMillis(LEASE_EXTEND_INTERVAL_MILLIS))
                .operation("extend task stepLock")
                .build();
    }

    /**
     * Starts the pipeline from the current state.
     *
     * @return future completed with the final status of the pipeline
     */
    public CompletableFuture<Status> start() {
        nextStep();
        return doneFuture;
    }

    /**
     * Executes the step that corresponds to the current {@link #state} and, once it finishes,
     * either advances to the following state, or finalizes the pipeline.
     */
    private void nextStep() {
        long stepStartNanos = System.nanoTime();
        // Remember the step being executed: `state` can be changed concurrently by extendLock().
        var initState = state;
        retry(() -> switch (state) {
            case LOCK -> stepLock();
            case START -> stepStart();
            case LOAD -> stepLoad();
            case RUN -> stepRun();
            case DONE -> stepDone();
        }).exceptionally(e -> {
            return Status.fromThrowable(e);
        }).thenAccept(status -> {
            var currentState = state;
            if (currentState == State.DONE) {
                // Either the DONE step itself has just run, or the pipeline was already
                // switched to DONE elsewhere (see extendLock()).
                doneFuture.complete(status);
                return;
            }

            if (!status.isOk()) {
                // Step failed: skip remaining steps, run the DONE step and report the failure status.
                switchState(currentState, State.DONE);
                long startNanos = System.nanoTime();
                var future = retry(this::stepDone).thenApply(ignore -> status)
                        .whenComplete((ignore, e) -> {
                            // NOTE(review): this call passes a start timestamp, whereas the success
                            // path below passes an elapsed duration. One of the two probably
                            // disagrees with TaskMetrics.Type#spendTime - confirm and align.
                            metrics.spendTime(State.DONE, startNanos);
                        });
                CompletableFutures.whenComplete(future, doneFuture);
                return;
            }

            // States are declared in execution order, so the successor is the next enum constant.
            // DONE (the last constant) never reaches this point because of the early return above.
            var next = State.VALUES[currentState.ordinal() + 1];
            switchState(currentState, next);
            metrics.spendTime(initState, System.nanoTime() - stepStartNanos);
            nextStep();
        });
    }

    /**
     * LOCK step: tries to acquire the task lock for this node with a lease of
     * {@link #LEASE_TIME_MILLIS}.
     *
     * @return OK when the returned lock is owned by this node (the lock sequence number is
     *         remembered and lease extension is scheduled), ABORTED when it is owned by another node
     */
    private CompletableFuture<Status> stepLock() {
        expiredAt = deps.clock.millis() + LEASE_TIME_MILLIS;
        return deps.locksDao.acquireLock(taskId, deps.node, Instant.ofEpochMilli(expiredAt))
                .thenApply(lock -> {
                    if (!deps.node.equals(lock.owner())) {
                        logger.debug("{} {} task acquired by other process: {}", deps.node, taskId, lock);
                        return Status.ABORTED.withDescription("acquired by other thread");
                    } else {
                        logger.debug("{} {} acquired lock {}", deps.node, taskId, lock.seqNo());
                        seqNo = lock.seqNo();
                        extendLockActor.schedule();
                        return Status.OK;
                    }
                });
    }

    /**
     * START step: asks the DAO to move the task into {@code RUNNING} using the lock sequence number.
     *
     * @return OK when the DAO reports success, ABORTED otherwise
     */
    private CompletableFuture<Status> stepStart() {
        return deps.dao.changeState(taskId, Task.State.RUNNING, seqNo)
                .thenApply(success -> {
                    if (!success) {
                        logger.debug("{} {} unable to start task", deps.node, taskId);
                        return Status.ABORTED.withDescription("unable to start task");
                    } else {
                        return Status.OK;
                    }
                });
    }

    /**
     * LOAD step: reads the task from the DAO and stores it in {@link #task}.
     *
     * @return OK when the task exists, NOT_FOUND otherwise
     */
    private CompletableFuture<Status> stepLoad() {
        return deps.dao.get(taskId).thenApply(optTask -> {
            if (optTask.isEmpty()) {
                logger.debug("{} {} task not found", deps.node, taskId);
                return Status.NOT_FOUND.withDescription("Task not found by id " + taskId);
            } else {
                task = optTask.get();
                return Status.OK;
            }
        });
    }

    /**
     * RUN step: hands the task to {@code deps.taskExecutor}.
     *
     * <p>If the task's {@code executeAt} is still in the future according to {@code deps.clock},
     * the task is moved back to {@code SCHEDULED} and the step reports ABORTED instead.
     *
     * @return future completed through the {@link Context} passed to the task executor
     */
    private CompletableFuture<Status> stepRun() {
        if (task.executeAt() > deps.clock.millis()) {
            return retry(() -> deps.dao.changeState(taskId, Task.State.SCHEDULED, seqNo))
                    .thenApply(ignore -> Status.ABORTED.withDescription("too earlier to run task"));
        }

        logger.debug("{} {} starting task", deps.node, taskId);
        var context = new Context();
        deps.taskExecutor.execute(context);
        return context.future;
    }

    /**
     * DONE step: stops lease extension and releases the lock.
     *
     * @return always OK on normal completion; the release result itself is ignored
     */
    private CompletableFuture<Status> stepDone() {
        extendLockActor.close();
        // seqNo is only assigned after a successful acquire in stepLock(); 0 means there is nothing to release.
        if (seqNo == 0) {
            return completedFuture(Status.OK);
        }

        return deps.locksDao.releaseLock(taskId, deps.node)
                .thenApply(success -> Status.OK);
    }

    /**
     * Ping callback of {@link #extendLockActor}: prolongs the lock lease while the pipeline is in
     * {@code RUN}.
     *
     * <p>When the lease is about to expire (less than {@link #LEASE_EXTEND_INTERVAL_MILLIS} left)
     * or the DAO refuses to extend it, the pipeline is terminated with FAILED_PRECONDITION:
     * the actor is closed, {@link #doneFuture} is completed and the state is switched to {@code DONE}.
     *
     * @param attempt value supplied by the ping actor; not used here
     */
    private CompletableFuture<Void> extendLock(int attempt) {
        if (state != State.RUN) {
            return completedFuture(null);
        }

        if (deps.clock.millis() + LEASE_EXTEND_INTERVAL_MILLIS >= expiredAt) {
            logger.debug("{} {} lock expired", deps.node, taskId);
            extendLockActor.close();
            doneFuture.complete(Status.FAILED_PRECONDITION.withDescription("lease expired"));
            switchState(State.RUN, State.DONE);
            return completedFuture(null);
        }

        // Named differently from the field on purpose: the field is updated only after
        // the DAO confirmed the extension.
        long newExpiredAt = deps.clock.millis() + LEASE_TIME_MILLIS;
        return deps.locksDao.extendLockTime(taskId, deps.node, Instant.ofEpochMilli(newExpiredAt))
                .thenAccept(success -> {
                    if (success) {
                        this.expiredAt = newExpiredAt;
                        // Argument order (node, taskId) matches every other log statement of this class.
                        logger.debug("{} {} extended lock {}", deps.node, taskId, seqNo);
                    } else if (state == State.RUN) {
                        logger.debug("{} {} lose lock {} ownership", deps.node, taskId, seqNo);
                        extendLockActor.close();
                        doneFuture.complete(Status.FAILED_PRECONDITION.withDescription("lose lock ownership"));
                        switchState(State.RUN, State.DONE);
                    }
                });
    }

    /**
     * Atomically replaces {@code expect} with {@code next}; does nothing (and logs nothing) when
     * the current state is not {@code expect}.
     */
    private void switchState(State expect, State next) {
        if (stateUpdater.compareAndSet(this, expect, next)) {
            logger.debug("{} {} change state {} -> {}", deps.node, taskId, expect, next);
        }
    }

    /**
     * Runs {@code supplier} through {@link RetryCompletableFuture} with {@link #RETRY_CONFIG},
     * logging a warning with the current state via the config's stats callback.
     */
    private <T> CompletableFuture<T> retry(Supplier<CompletableFuture<T>> supplier) {
        var config = RETRY_CONFIG.withStats((timeSpentMillis, cause) -> {
            logger.warn("{} {} failed in state {}", deps.node, taskId, state);
        });

        return RetryCompletableFuture.runWithRetries(supplier, config);
    }

    /** Pipeline steps; declaration order is the execution order (see {@code nextStep}). */
    enum State {
        LOCK, START, LOAD, RUN, DONE;

        /** Cached {@link #values()} to avoid cloning the array on each lookup. */
        private static final State[] VALUES = values();
    }

    /**
     * Execution context handed to the task executor during the RUN step. Terminal callbacks
     * ({@code complete}, {@code fail}, {@code reschedule}, {@code cancel}) complete {@link #future},
     * which in turn finishes the RUN step.
     */
    private class Context implements ExecutionContext {
        /** Result of the RUN step. */
        private final CompletableFuture<Status> future = new CompletableFuture<>();

        @Override
        public Task task() {
            return task;
        }

        /** Stores the packed result as the task outcome and counts the task as completed with OK. */
        @Override
        public <T extends Message> CompletableFuture<?> complete(T result) {
            logger.debug("{} {} complete", deps.node, taskId);
            var packedResult = Proto.pack(result);
            return completeCall(() -> deps.dao.complete(taskId, packedResult, seqNo))
                    .thenAccept(ignore -> metrics.completeTask(Code.OK));
        }

        /**
         * Stores the failure status derived from {@code e}; when the status has no description,
         * the stack trace of {@code e} is used as one.
         */
        @Override
        public CompletableFuture<?> fail(Throwable e) {
            logger.warn("{} {} failed", deps.node, taskId, e);
            var status = Status.fromThrowable(e);
            if (Strings.isNullOrEmpty(status.getDescription())) {
                status = status.withDescription(Throwables.getStackTraceAsString(e));
            }
            var finalStatus = status;
            return completeCall(() -> deps.dao.failed(taskId, finalStatus, seqNo))
                    .thenAccept(ignore -> metrics.completeTask(finalStatus.getCode()));
        }

        /** Reschedules the task to {@code executeAt}, storing the packed progress. */
        @Override
        public <T extends Message> CompletableFuture<?> reschedule(long executeAt, T progress) {
            logger.debug("{} {} reschedule at {}", deps.node, taskId, executeAt);
            var packedProgress = Proto.pack(progress);
            return completeCall(() -> deps.dao.reschedule(taskId, executeAt, packedProgress, seqNo))
                    .thenAccept(ignore -> metrics.reschedule.inc());
        }

        /**
         * Stores intermediate progress. Unlike the terminal callbacks, a successful update does
         * not complete {@link #future}; a rejected update completes it with FAILED_PRECONDITION
         * and fails the returned future with the same status.
         */
        @Override
        public <T extends Message> CompletableFuture<?> progress(T progress) {
            if (isDone()) {
                return CompletableFuture.failedFuture(Status.FAILED_PRECONDITION
                        .withDescription("Task already done")
                        .asRuntimeException());
            }

            var packedProgress = Proto.pack(progress);
            return retry(() -> deps.dao.progress(taskId, packedProgress, seqNo))
                    .thenAccept(success -> {
                        if (!success) {
                            var status = Status.FAILED_PRECONDITION.withDescription("unable update progress, task running by other process");
                            future.complete(status);
                            throw status.asRuntimeException();
                        }

                        metrics.progress.inc();
                    });
        }

        /** Moves the task back to {@code SCHEDULED} and finishes the RUN step. */
        @Override
        public CompletableFuture<?> cancel() {
            return completeCall(() -> deps.dao.changeState(taskId, Task.State.SCHEDULED, seqNo))
                    .thenAccept(ignore -> metrics.cancel.inc());
        }

        /**
         * Executes a terminal DAO operation with retries and completes {@link #future} according
         * to the outcome: the failure converted with {@link Status#fromThrowable} on exception,
         * OK when the DAO returned {@code true}, FAILED_PRECONDITION when it returned {@code false}
         * (in which case the returned future also fails with that status).
         *
         * @return failed future with FAILED_PRECONDITION if the context is already done
         */
        private CompletableFuture<Boolean> completeCall(Supplier<CompletableFuture<Boolean>> supplier) {
            if (isDone()) {
                return CompletableFuture.failedFuture(Status.FAILED_PRECONDITION
                        .withDescription("Task already done")
                        .asRuntimeException());
            }

            return retry(supplier).whenComplete((success, e) -> {
                if (e != null) {
                    future.complete(Status.fromThrowable(e));
                } else if (success) {
                    future.complete(Status.OK);
                } else {
                    var status = Status.FAILED_PRECONDITION.withDescription("unable complete task because running by other process");
                    future.complete(status);
                    // Thrown from whenComplete so that the caller's future fails as well.
                    throw status.asRuntimeException();
                }
            });
        }

        /** @return {@code true} once the pipeline has left RUN or {@link #future} has been completed */
        @Override
        public boolean isDone() {
            return state != State.RUN || future.isDone();
        }
    }
}
