package ru.yandex.solomon.scheduler.dao.memory;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntPredicate;
import java.util.function.LongPredicate;
import java.util.function.LongUnaryOperator;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.Lists;
import com.google.protobuf.Any;
import io.grpc.Status;

import ru.yandex.solomon.scheduler.ScheduledTask;
import ru.yandex.solomon.scheduler.Task;
import ru.yandex.solomon.scheduler.Task.State;
import ru.yandex.solomon.scheduler.dao.SchedulerDao;

import static java.util.function.Predicate.not;

/**
 * In-memory implementation of {@link SchedulerDao} that keeps tasks in a {@link ConcurrentHashMap}.
 *
 * <p>Each stored task is paired with a sequence number ({@code seqNo}) used for optimistic
 * concurrency control: most mutations succeed only when the stored {@code seqNo} is not greater
 * than the caller's {@code seqNo}, and then store the caller's value. Every asynchronous operation
 * first waits for the future returned by {@link #beforeSupplier} (when set) and then runs on the
 * default async executor of {@link CompletableFuture}. Every successful update counts down the
 * latch obtained through {@link #updateSync(String)}, so callers can wait for the next update of
 * a task.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class InMemorySchedulerDao implements SchedulerDao {
    /** Stored tasks (with their sequence numbers) keyed by task id. */
    private final ConcurrentMap<String, Record> taskById = new ConcurrentHashMap<>();
    /** Latch per task id, counted down after each successful update; see {@link #updateSync(String)}. */
    public final ConcurrentMap<String, CountDownLatch> updateSyncById = new ConcurrentHashMap<>();
    /**
     * Optional hook run before every asynchronous operation. The operation proceeds once the
     * returned future completes and fails if that future completes exceptionally.
     * {@code null} means "no hook".
     */
    public volatile Supplier<CompletableFuture<?>> beforeSupplier;

    /** Nothing to create for the in-memory store; returns an already completed future. */
    @Override
    public CompletableFuture<Void> createSchema() {
        return CompletableFuture.completedFuture(null);
    }

    /**
     * Asynchronous variant of {@link #syncAdd(Task)}.
     *
     * @return future completing with {@code true} if the task was added
     */
    @Override
    public CompletableFuture<Boolean> add(Task task) {
        return before().thenApplyAsync(ignored -> syncAdd(task));
    }

    /**
     * Stores the task with {@code seqNo = 0} unless a task with the same id is already present.
     *
     * @return {@code true} if the task was added, {@code false} if the id was already taken
     */
    public boolean syncAdd(Task task) {
        return taskById.putIfAbsent(task.id(), new Record(task, 0)) == null;
    }

    /** Asynchronous variant of {@link #syncGet(String)}. */
    @Override
    public CompletableFuture<Optional<Task>> get(String taskId) {
        return before().thenApplyAsync(ignored -> syncGet(taskId));
    }

    /** Returns the stored task with the given id, or an empty {@link Optional} if there is none. */
    public Optional<Task> syncGet(String taskId) {
        return Optional.ofNullable(taskById.get(taskId)).map(Record::task);
    }

    /**
     * Marks a not-yet-completed task as {@link State#COMPLETED} with the given result.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> complete(String taskId, Any result, long seqNo) {
        return before().thenApplyAsync(ignored ->
                update(taskId, seqNo, not(State.COMPLETED::equals), task -> task.toBuilder()
                        .setResult(result)
                        .setState(State.COMPLETED)
                        .build()));
    }

    /**
     * Marks a not-yet-completed task as {@link State#COMPLETED} with the given failure status.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> failed(String taskId, Status status, long seqNo) {
        return before().thenApplyAsync(ignored ->
                update(taskId, seqNo, not(State.COMPLETED::equals), task -> task.toBuilder()
                        .setStatus(status)
                        .setState(State.COMPLETED)
                        .build()));
    }

    /**
     * Moves a not-yet-completed task back to {@link State#SCHEDULED}. The task gets the new
     * execution time and progress, and its version is incremented.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> reschedule(String taskId, long executeAt, Any progress, long seqNo) {
        return before().thenApplyAsync(ignored ->
                update(taskId, seqNo, not(State.COMPLETED::equals), task -> task.toBuilder()
                        .setExecuteAt(executeAt)
                        .setProgress(progress)
                        .setVersion(task.version() + 1)
                        .setState(State.SCHEDULED)
                        .build()));
    }

    /**
     * Reschedules a not-yet-completed task guarded by its version instead of a sequence number.
     * The stored version must equal {@code expectedVersion}. On success the version is
     * incremented and the stored {@code seqNo} is bumped by one. The task state is left unchanged.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> rescheduleExternally(
        String taskId,
        long executeAt,
        Any progress,
        int expectedVersion)
    {
        return before().thenApplyAsync(ignored ->
            update(taskId, expectedVersion, seqNo -> seqNo + 1, not(State.COMPLETED::equals), task -> task.toBuilder()
                .setExecuteAt(executeAt)
                .setProgress(progress)
                .setVersion(task.version() + 1)
                .build()));
    }

    /**
     * Stores new progress for a task that is currently {@link State#RUNNING} and increments its
     * version.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> progress(String taskId, Any progress, long seqNo) {
        return before().thenApplyAsync(ignored ->
                update(taskId, seqNo, State.RUNNING::equals, task -> task.toBuilder()
                        .setProgress(progress)
                        .setVersion(task.version() + 1)
                        .build()));
    }

    /**
     * Feeds every stored task to {@code consumer}.
     *
     * <p>The tasks are copied into a list up front. They are then delivered in batches of a
     * random size (at least 1000), and the {@link #beforeSupplier} hook runs again before each
     * batch.
     */
    @Override
    public CompletableFuture<Void> list(Consumer<Task> consumer) {
        return before(() -> {
            var root = new CompletableFuture<Void>();
            var future = root;
            var copy = taskById.values().stream().map(Record::task).collect(Collectors.toList());
            // nextInt's upper bound is exclusive, so the batch size lies in [1000, max(size / 3, 1001)).
            int maxBatchSize = Math.max(copy.size() / 3, 1001);
            int batchSize = ThreadLocalRandom.current().nextInt(1000, maxBatchSize);
            for (var batch : Lists.partition(copy, batchSize)) {
                // The previous stage's value is always null (Void); only its completion matters.
                future = future.thenCompose(ignored -> before(() -> {
                    batch.forEach(consumer);
                    return null;
                }));
            }
            // Start the chain only after it has been fully assembled.
            root.complete(null);
            return future;
        }).thenCompose(future -> future);
    }

    /**
     * Sets the state of a not-yet-completed task.
     *
     * @return future completing with {@code true} if the task was updated
     */
    @Override
    public CompletableFuture<Boolean> changeState(String taskId, State state, long seqNo) {
        return before().thenApplyAsync(ignored ->
                update(taskId, seqNo, not(State.COMPLETED::equals), task -> task.toBuilder()
                        .setState(state)
                        .build()));
    }

    /**
     * Returns up to {@code limit} tasks that are due ({@code executeAt <= now}) and are either
     * {@link State#RUNNING} or {@link State#SCHEDULED}. Results come in the map's iteration
     * order, which {@link ConcurrentHashMap} does not specify.
     */
    @Override
    public CompletableFuture<List<ScheduledTask>> listScheduled(long now, int limit) {
        return before().thenApplyAsync(ignored -> taskById.values()
                .stream()
                .map(Record::task)
                .filter(task -> {
                    if (task.executeAt() > now) {
                        return false;
                    }

                    var state = task.state();
                    return state == State.RUNNING || state == State.SCHEDULED;
                })
                .limit(limit)
                .map(task -> new ScheduledTask(task.executeAt(), task.id(), task.type(), task.params()))
                .collect(Collectors.toList()));
    }

    /**
     * Returns the latch for {@code taskId}, which the next successful update of that task counts
     * down. A latch that has already been released is replaced by a fresh one with count 1.
     */
    public CountDownLatch updateSync(String taskId) {
        return updateSyncById.compute(taskId, (s, prev) -> {
            if (prev == null || prev.getCount() == 0) {
                return new CountDownLatch(1);
            }

            return prev;
        });
    }

    /**
     * Sequence-number guarded update. It succeeds when the stored seqNo is {@code <= seqNo} and
     * the state matches; the stored seqNo then becomes {@code seqNo}. The version is not checked.
     */
    private boolean update(
        String taskId,
        long seqNo,
        Predicate<State> statePredicate,
        Function<Task, Task> fn)
    {
        return update(
            taskId,
            taskSeqNo -> taskSeqNo <= seqNo,
            i -> seqNo,
            always -> true,
            statePredicate,
            fn);
    }

    /**
     * Version-guarded update. It succeeds when the task version equals {@code version} and the
     * state matches. The seqNo is not checked, and {@code seqNoOp} computes the new stored seqNo.
     */
    private boolean update(
        String taskId,
        int version,
        LongUnaryOperator seqNoOp,
        Predicate<State> statePredicate,
        Function<Task, Task> fn)
    {
        return update(
            taskId,
            always -> true,
            seqNoOp,
            taskVersion -> taskVersion == version,
            statePredicate,
            fn);
    }

    /**
     * Applies {@code fn} to the stored task if all predicates hold. The new record is installed
     * with a single compare-and-replace against the record that was read. If a concurrent
     * modification happened in between, the update is rejected rather than retried.
     *
     * @return {@code true} if the new record was stored; the task's update latch is then counted down
     */
    private boolean update(
        String taskId,
        LongPredicate seqNoPredicate,
        LongUnaryOperator seqNoOp,
        IntPredicate versionPredicate,
        Predicate<State> statePredicate,
        Function<Task, Task> fn)
    {
        var current = taskById.get(taskId);
        if (current == null) {
            return false;
        }

        if (!seqNoPredicate.test(current.seqNo())) {
            return false;
        }

        if (!versionPredicate.test(current.task().version())) {
            return false;
        }

        if (!statePredicate.test(current.task().state())) {
            return false;
        }

        var newTask = fn.apply(current.task());
        var newSeqNo = seqNoOp.applyAsLong(current.seqNo());
        boolean success = taskById.replace(taskId, current, new Record(newTask, newSeqNo));
        if (success) {
            updateSync(taskId).countDown();
        }
        return success;
    }

    /** Runs {@code fn} asynchronously once the {@link #beforeSupplier} hook future completes. */
    private <T> CompletableFuture<T> before(Supplier<T> fn) {
        return before().thenApplyAsync(ignored -> fn.get());
    }

    /** Returns the hook future from {@link #beforeSupplier}, or a completed future if no hook is set. */
    private CompletableFuture<?> before() {
        var copy = beforeSupplier;
        if (copy == null) {
            return CompletableFuture.completedFuture(null);
        }

        return copy.get();
    }

    /** Stored task together with its optimistic-concurrency sequence number. */
    private record Record(Task task, long seqNo) {
    }
}
