package ru.yandex.solomon.scheduler;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import com.google.protobuf.Any;
import com.google.protobuf.Int32Value;
import com.google.protobuf.StringValue;
import io.grpc.Status;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.locks.dao.memory.InMemoryLocksDao;
import ru.yandex.solomon.scheduler.ProgressOperator.Fail;
import ru.yandex.solomon.scheduler.ProgressOperator.Ok;
import ru.yandex.solomon.scheduler.ProgressOperator.Stop;
import ru.yandex.solomon.scheduler.Task.State;
import ru.yandex.solomon.scheduler.dao.memory.InMemorySchedulerDao;
import ru.yandex.solomon.scheduler.handlers.ContextTaskHandler;
import ru.yandex.solomon.scheduler.handlers.CooperativeTaskHandler;
import ru.yandex.solomon.scheduler.handlers.Permits;
import ru.yandex.solomon.scheduler.handlers.Tasks;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.Proto;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static ru.yandex.solomon.scheduler.handlers.Tasks.anyAsNumber;
import static ru.yandex.solomon.scheduler.handlers.Tasks.contextTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.cooperativeTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.getTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.incTask;

/**
 * Tests for {@link TaskSchedulerImpl} wired to an in-memory scheduler DAO and an in-memory locks DAO.
 *
 * <p>Time is driven by a {@link ManualClock} shared with a {@link ManualScheduledExecutorService}, so tests
 * that wait for delayed work advance the clock explicitly instead of sleeping.
 *
 * @author Vladimir Gordiychuk
 */
public class TaskSchedulerImplTest {

    @Rule
    public Timeout timeout = Timeout.builder()
            .withTimeout(30, TimeUnit.SECONDS)
            .build();

    private ManualClock clock;
    private TaskExecutorStub executor;
    private InMemorySchedulerDao dao;
    private InMemoryLocksDao locksDao;
    private TaskSchedulerImpl scheduler;
    private ManualScheduledExecutorService timer;

    @Before
    public void setUp() {
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);
        executor = new TaskExecutorStub(Tasks.handlers());
        dao = new InMemorySchedulerDao();
        locksDao = new InMemoryLocksDao(clock);
        var deps = new TaskDeps("host", executor, dao, locksDao, clock, ForkJoinPool.commonPool(), timer);
        scheduler = new TaskSchedulerImpl(Duration.ofSeconds(15), 10_000, deps);
    }

    @After
    public void tearDown() {
        // Permits.INSTANCE is static and reassigned by several tests, so reset it between tests.
        Permits.clear();
        if (scheduler != null) {
            scheduler.close();
        }

        if (timer != null) {
            timer.shutdownNow();
        }
    }

    @Test
    public void scheduleAndRun() throws InterruptedException {
        assertEquals(1, runNumber(incTask()));
    }

    @Test
    public void scheduleSequential() throws InterruptedException {
        for (int index = 0; index < 10; index++) {
            assertEquals(index + 1, runNumber(incTask()));
        }
    }

    @Test
    public void scheduleOnlyOnceSequential() throws InterruptedException {
        var inc = incTask();
        assertEquals(1, runNumber(inc));
        assertEquals(1, runNumber(inc));
        assertEquals(1, runNumber(inc));
        assertEquals(1, runNumber(getTask()));

        assertEquals(2, runNumber(incTask()));
        assertEquals(2, runNumber(getTask()));
    }

    @Test
    public void scheduleOnlyOnceParallel() throws InterruptedException {
        var inc = incTask();
        IntStream.range(0, 100)
                .parallel()
                .mapToObj(ignore -> schedule(inc))
                .collect(collectingAndThen(toList(), CompletableFutures::allOfVoid))
                .join();

        waitComplete(inc.id());
        assertEquals(1, taskResultAsNumber(inc.id()));
        assertEquals(1, runNumber(getTask()));
    }

    @Test
    public void parallelTaskExecution() throws InterruptedException {
        var tasks = IntStream.range(0, 100)
                .mapToObj(ignore -> incTask())
                .collect(toList());

        tasks.parallelStream()
                .map(this::schedule)
                .collect(collectingAndThen(toList(), CompletableFutures::allOfVoid))
                .join();

        for (var task : tasks) {
            waitComplete(task.id());
        }

        var results = tasks.stream()
                .mapToInt(task -> taskResultAsNumber(task.id()))
                .sorted()
                .toArray();

        assertArrayEquals(IntStream.rangeClosed(1, 100).toArray(), results);
        assertEquals(100, runNumber(getTask()));
    }

    @Test
    public void scheduleTask() throws InterruptedException {
        var now = clock.millis();
        var delayTask = incTask(now + TimeUnit.HOURS.toMillis(1));
        schedule(delayTask).join();
        assertFalse(waitComplete(delayTask.id(), 1, TimeUnit.MILLISECONDS));

        // get task happens before scheduled
        assertEquals(0, runNumber(getTask()));
        assertFalse(waitComplete(delayTask.id(), 1, TimeUnit.MILLISECONDS));

        clock.passedTime(1, TimeUnit.HOURS);
        while (!waitComplete(delayTask.id(), 1, TimeUnit.MILLISECONDS)) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }

        assertEquals(1, taskResultAsNumber(delayTask.id()));
        assertEquals(1, runNumber(getTask()));
    }

    @Test
    public void retryDaoError() throws InterruptedException {
        var availableError = new AtomicInteger(3);
        beforeDao(() -> {
            if (ThreadLocalRandom.current().nextBoolean() || availableError.get() == 0) {
                return completedFuture(null);
            }

            availableError.decrementAndGet();
            return failedFuture(new RuntimeException("hi"));
        });

        for (int index = 1; index <= 5; index++) {
            var task = incTask();
            assertTrue(dao.syncAdd(task));

            scheduler.forceAct();
            waitComplete(task.id());

            assertEquals(index, taskResultAsNumber(task.id()));
        }
    }

    @Test
    public void checkTaskToRunByDao() throws InterruptedException {
        var task = incTask();
        assertTrue(dao.syncAdd(task));

        while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS)) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }

        assertEquals(1, taskResultAsNumber(task.id()));
    }

    @Test
    public void rescheduleTask() throws InterruptedException {
        var task = contextTask();
        schedule(task).join();

        var context = ContextTaskHandler.contextByTaskId(task.id()).join();
        ContextTaskHandler.clear();
        context.reschedule(clock.millis() + 5_000, Any.getDefaultInstance()).join();

        var sync = new CountDownLatch(1);
        var rescheduleContextFuture = ContextTaskHandler.contextByTaskId(task.id()).whenComplete((ignore, e) -> sync.countDown());
        while (!sync.await(1, TimeUnit.MILLISECONDS)) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }
        var rescheduleContext = rescheduleContextFuture.join();
        rescheduleContext.complete(anyAsNumber(42)).join();
        waitComplete(task.id());

        assertEquals(42, taskResultAsNumber(task.id()));
    }

    @Test
    public void releasePermitWhenComplete() throws InterruptedException {
        Permits.INSTANCE = new Permits(2, new CountDownLatch(1));

        var task = contextTask();
        schedule(task).join();

        var context = ContextTaskHandler.contextByTaskId(task.id()).join();
        assertEquals(1, Permits.INSTANCE.available());

        context.complete(anyAsNumber(1)).join();

        Permits.INSTANCE.onRelease.await();
        assertEquals(2, Permits.INSTANCE.available());

        waitComplete(task.id());
        assertEquals(1, taskResultAsNumber(task.id()));
    }

    @Test
    public void avoidStartWithoutPermit() throws InterruptedException {
        Permits.INSTANCE = new Permits(0, new CountDownLatch(1));

        var task = incTask();
        schedule(task).join();

        {
            int cnt = 10;
            while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS) && cnt-- > 0) {
                clock.passedTime(1, TimeUnit.SECONDS);
            }
        }

        // task still not started because no permit
        var notCompleted = dao.syncGet(task.id()).orElseThrow();
        assertNotEquals(State.COMPLETED, notCompleted.state());

        Permits.INSTANCE = new Permits(1, new CountDownLatch(1));
        while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS)) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }

        assertEquals(1, taskResultAsNumber(task.id()));
    }

    @Test
    public void sequentialByOnePermit() throws InterruptedException {
        Permits.INSTANCE = new Permits(1, new CountDownLatch(100));
        var tasks = IntStream.range(0, 100)
                .mapToObj(ignore -> incTask())
                .collect(toList());

        tasks.forEach(task -> assertTrue(dao.syncAdd(task)));
        scheduler.forceAct();

        Permits.INSTANCE.onRelease.await();
        assertEquals(100, runNumber(getTask()));
    }

    @Test
    public void avoidStartAlreadyRunningByAnotherNode() throws InterruptedException {
        var task = incTask();
        locksDao.acquireLock(task.id(), "another_node", clock.instant().plus(1, ChronoUnit.HOURS)).join();
        assertTrue(dao.syncAdd(task));

        scheduler.forceAct();

        int cnt = 10;
        while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS) && cnt-- > 0) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }

        var notCompleted = dao.syncGet(task.id()).orElseThrow();
        assertNotEquals(State.COMPLETED, notCompleted.state());

        // previous lock expired, task can be executed by current node
        clock.passedTime(1, TimeUnit.HOURS);
        while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS)) {
            clock.passedTime(1, TimeUnit.SECONDS);
        }

        assertEquals(1, taskResultAsNumber(task.id()));
    }

    @Test
    public void rescheduleWhenScheduled() throws Exception {
        var task = cooperativeTask(clock.millis() + TimeUnit.MINUTES.toMillis(1));
        schedule(task).join();
        assertEquals(State.SCHEDULED, scheduler.getTask(task.id()).join().orElseThrow().state());
        assertTrue(taskProgressAsString(task.id()).isEmpty());

        var interruption = Proto.pack(StringValue.of(CooperativeTaskHandler.INTERRUPTED));
        assertTrue(scheduler.reschedule(task.id(), clock.millis() + 3_000, p -> new Ok(interruption)).join());

        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS));

        assertEquals(CooperativeTaskHandler.CANCELED, taskResultAsString(task.id()));
    }

    @Test
    public void rescheduleWhenRunning() throws Exception {
        var task = cooperativeTask();
        schedule(task).join();
        // a non-empty progress means the handler has started working on the task
        waitProgress(task.id());

        var interruption = Proto.pack(StringValue.of(CooperativeTaskHandler.INTERRUPTED));
        assertTrue(scheduler.reschedule(task.id(), clock.millis() + 3_000, p -> new Ok(interruption)).join());

        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS));

        assertEquals(CooperativeTaskHandler.CANCELED, taskResultAsString(task.id()));
    }

    @Test
    public void rescheduleWhenCompleted() throws Exception {
        var workToDo = 100;
        var task = cooperativeTask(clock.millis() + TimeUnit.SECONDS.toMillis(1), workToDo);
        schedule(task).join();
        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS));

        var interruption = Proto.pack(StringValue.of(CooperativeTaskHandler.INTERRUPTED));
        assertFalse(scheduler.reschedule(task.id(), clock.millis() + 3_000, p -> new Ok(interruption)).join());

        assertEquals(CooperativeTaskHandler.DONE, taskResultAsString(task.id()));
    }

    @Test
    public void rescheduleWhenNotFound() {
        var status = scheduler.reschedule("meh", 42, ProgressOperator.identity())
            .thenApply(i -> Status.OK)
            .exceptionally(Status::fromThrowable)
            .join();

        assertEquals(Status.Code.NOT_FOUND, status.getCode());
    }

    @Test
    public void rescheduleWhenContended() {
        var task = contextTask();
        schedule(task).join();

        ProgressOperator increment = progress -> {
            var prev = Proto.unpack(progress, Int32Value.getDefaultInstance());
            var update = prev.toBuilder().setValue(prev.getValue() + 1).build();
            return new Ok(Proto.pack(update));
        };
        IntStream.range(0, 8)
            .parallel()
            .mapToObj(i -> scheduler.reschedule(task.id(), clock.millis(), increment))
            .collect(collectingAndThen(toList(), CompletableFutures::allOf))
            .join();

        assertEquals(8, taskProgressAsNumber(task.id()));
    }

    @Test
    public void rescheduleWhenOperatorWantsToStop() throws InterruptedException {
        var workToDo = 100;
        var task = cooperativeTask(clock.millis(), workToDo);
        schedule(task).join();

        ProgressOperator stopOperator = p -> new Stop();
        assertFalse(scheduler.reschedule(task.id(), clock.millis(), stopOperator).join());

        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS));

        assertEquals(CooperativeTaskHandler.DONE, taskResultAsString(task.id()));
    }

    @Test
    public void causeExceptionWhenOperatorWantsToFail() throws InterruptedException {
        var workToDo = 100;
        var task = cooperativeTask(clock.millis(), workToDo);
        schedule(task).join();

        var epicFail = Status.DATA_LOSS.withDescription("epic fail");
        ProgressOperator failOperator = p -> new Fail(epicFail.asRuntimeException());
        var status = scheduler.reschedule(task.id(), clock.millis(), failOperator)
            .thenApply(i -> Status.OK)
            .exceptionally(Status::fromThrowable)
            .join();

        // io.grpc.Status.equals is identity-based by design; compare the observable parts instead
        assertEquals(epicFail.getCode(), status.getCode());
        assertEquals(epicFail.getDescription(), status.getDescription());

        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!waitComplete(task.id(), 1, TimeUnit.MILLISECONDS));

        assertEquals(CooperativeTaskHandler.DONE, taskResultAsString(task.id()));
    }

    /**
     * Installs {@code supplier} as the {@code beforeSupplier} hook of both the scheduler DAO and the locks DAO.
     * Presumably the in-memory DAOs run this hook before each operation, which lets a test inject failures
     * (TODO confirm against the DAO implementations).
     */
    private void beforeDao(Supplier<CompletableFuture<?>> supplier) {
        dao.beforeSupplier = supplier;
        locksDao.beforeSupplier = supplier;
    }

    /**
     * Schedules {@code task}, waits until it is COMPLETED and returns its result decoded as an int.
     */
    private int runNumber(Task task) throws InterruptedException {
        schedule(task).join();
        waitComplete(task.id());
        return taskResultAsNumber(task.id());
    }

    private CompletableFuture<Void> schedule(Task task) {
        return scheduler.schedule(task);
    }

    /**
     * Asserts that the task reaches the COMPLETED state within 30 seconds of wall-clock time.
     */
    private void waitComplete(String taskId) throws InterruptedException {
        assertTrue("task " + taskId + " not completed in time", waitComplete(taskId, 30, TimeUnit.SECONDS));
    }

    /**
     * Waits until the stored task is COMPLETED or the wall-clock deadline passes.
     *
     * @return {@code true} if the task is COMPLETED; {@code false} if the deadline is reached, or if the
     *         DAO update sync for the task is not released before the deadline
     */
    private boolean waitComplete(String taskId, int time, TimeUnit unit) throws InterruptedException {
        long deadline = System.nanoTime() + unit.toNanos(time);
        while (true) {
            // obtain the sync before reading state so an update between the read and the await is not missed
            var sync = dao.updateSync(taskId);
            var optTask = dao.syncGet(taskId);
            assertTrue(optTask.isPresent());
            if (optTask.get().state() == State.COMPLETED) {
                return true;
            }

            long waitNanos = deadline - System.nanoTime();
            if (waitNanos <= 0) {
                return false;
            }

            if (!sync.await(waitNanos, TimeUnit.NANOSECONDS)) {
                return false;
            }
        }
    }

    /**
     * Blocks until the task's progress, decoded as a string, is non-empty.
     * Re-checks at least once per millisecond instead of spinning; the overall wait is bounded
     * by the class-level {@link Timeout} rule.
     */
    private void waitProgress(String taskId) throws InterruptedException {
        while (true) {
            var sync = dao.updateSync(taskId);
            if (!taskProgressAsString(taskId).isEmpty()) {
                return;
            }
            // the result is irrelevant: either an update was signalled or 1ms elapsed, re-check in both cases
            sync.await(1, TimeUnit.MILLISECONDS);
        }
    }

    private int taskResultAsNumber(String taskId) {
        return taskResult(taskId, result -> Proto.unpack(result, Int32Value.class).getValue());
    }

    private String taskResultAsString(String taskId) {
        return taskResult(taskId, result -> Proto.unpack(result, StringValue.class).getValue());
    }

    /**
     * Asserts that the task exists and is COMPLETED, then applies {@code fn} to its result.
     */
    private <R> R taskResult(String taskId, Function<Any, R> fn) {
        var optTask = dao.syncGet(taskId);
        assertTrue(optTask.isPresent());
        var task = optTask.get();
        assertEquals(State.COMPLETED, task.state());
        return fn.apply(task.result());
    }

    private String taskProgressAsString(String taskId) {
        return taskProgress(taskId, progress -> Proto.unpack(progress, StringValue.getDefaultInstance()).getValue());
    }

    private int taskProgressAsNumber(String taskId) {
        return taskProgress(taskId, progress -> Proto.unpack(progress, Int32Value.getDefaultInstance()).getValue());
    }

    /**
     * Asserts that the task exists, then applies {@code fn} to its current progress (in any state).
     */
    private <R> R taskProgress(String taskId, Function<Any, R> fn) {
        var task = dao.syncGet(taskId);
        assertTrue(task.isPresent());
        var progress = task.orElseThrow().progress();
        return fn.apply(progress);
    }
}
