package ru.yandex.solomon.scheduler;

import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.collect.Iterables;
import com.google.protobuf.Any;
import com.google.protobuf.Int32Value;
import io.grpc.Status;
import io.grpc.Status.Code;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.solomon.locks.dao.memory.InMemoryLocksDao;
import ru.yandex.solomon.scheduler.Task.State;
import ru.yandex.solomon.scheduler.dao.memory.InMemorySchedulerDao;
import ru.yandex.solomon.scheduler.handlers.ContextTaskHandler;
import ru.yandex.solomon.scheduler.handlers.FutureTaskHandler;
import ru.yandex.solomon.scheduler.handlers.Tasks;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.Proto;
import ru.yandex.solomon.util.host.HostUtils;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static ru.yandex.solomon.scheduler.handlers.Tasks.anyAsNumber;
import static ru.yandex.solomon.scheduler.handlers.Tasks.contextTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.futureTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.getTask;
import static ru.yandex.solomon.scheduler.handlers.Tasks.incTask;

/**
 * Tests for {@link TaskPipeline}: lock acquisition, lease extension and loss, persistence of
 * results and progress, retries on DAO failures, and the complete/reschedule/cancel/progress
 * operations a running task performs through its context.
 *
 * <p>Time is driven manually through {@link ManualClock} and {@link ManualScheduledExecutorService};
 * scheduler state and locks live in in-memory DAOs that are recreated before every test.
 *
 * @author Vladimir Gordiychuk
 */
public class TaskPipelineTest {

    /** Fails any test that runs longer than 30 seconds and looks for a stuck thread when it does. */
    @Rule
    public Timeout timeout = Timeout.builder()
            .withLookingForStuckThread(true)
            .withTimeout(30, TimeUnit.SECONDS)
            .build();

    private ManualClock clock;
    private ManualScheduledExecutorService timer;
    private TaskExecutorStub taskExecutor;
    private InMemorySchedulerDao dao;
    private InMemoryLocksDao locksDao;

    /** Creates a fresh manual clock, timer, executor stub and in-memory DAOs for each test. */
    @Before
    public void setUp() {
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);
        taskExecutor = new TaskExecutorStub(Tasks.handlers());
        dao = new InMemorySchedulerDao();
        locksDao = new InMemoryLocksDao(clock);
    }

    /** Stops the manual timer created in {@link #setUp()}. */
    @After
    public void tearDown() {
        timer.shutdownNow();
    }

    /** A pipeline for an id that is absent from the DAO finishes with {@code ABORTED}. */
    @Test
    public void taskNotExists() {
        var status = runPipeline("not_exists_task").join();
        assertEquals(status.toString(), Code.ABORTED, status.getCode());
    }

    /** When another owner already holds the task lock, the pipeline aborts and skips evaluation. */
    @Test
    public void taskLockedByOther() {
        var task = incTask();
        assertTrue(dao.add(task).join());

        var lock = locksDao.acquireLock(task.id(), "alice", clock.instant().plusSeconds(30)).join();
        assertNotEquals(0L, lock.seqNo());

        var status = runPipeline(task.id()).join();
        assertEquals(status.toString(), Code.ABORTED, status.getCode());
        assertEquals("skip evaluation", 0, (int) taskExecutor.runNumber(getTask()).join());
    }

    /** A free task is evaluated, and its result is stored with state {@code COMPLETED}. */
    @Test
    public void taskRun() {
        var task = incTask();
        assertTrue(dao.add(task).join());

        var status = runPipeline(task.id()).join();
        assertEquals(status.toString(), Code.OK, status.getCode());
        assertEquals(1, taskResultAsNumber(task.id()));
    }

    /**
     * Each of several inc tasks stores its own result. The expected value grows by one per task,
     * which suggests the inc handler uses a shared counter (TODO confirm in {@code Tasks}).
     */
    @Test
    public void saveSuccessResult() {
        for (int index = 1; index <= 3; index++) {
            var task = incTask();
            assertTrue(dao.add(task).join());

            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.OK, status.getCode());
            assertEquals(index, taskResultAsNumber(task.id()));
        }
    }

    /**
     * Two nodes race for the same task. The node that loses the lock aborts immediately; the
     * winner completes once the task's future is resolved.
     */
    @Test
    public void taskLockedWhenItRun() {
        var task = futureTask();
        assertTrue(dao.add(task).join());

        var futures = Stream.of("alice", "bob")
                .parallel()
                .map(node -> runPipeline(task.id(), node))
                .collect(Collectors.toList());

        // The node that lost the lock finishes first with ABORTED;
        // the winner is still blocked on the task's future.
        {
            CompletableFuture.anyOf(futures.toArray(new CompletableFuture[0])).join();

            var failed = futures.stream()
                    .filter(CompletableFuture::isDone)
                    .findFirst()
                    .orElseThrow()
                    .join();

            assertEquals(failed.toString(), Code.ABORTED, failed.getCode());
        }

        // Resolve the task on the winning node; its pipeline completes successfully.
        {
            var future = futures.stream()
                    .filter(Predicate.not(CompletableFuture::isDone))
                    .findFirst()
                    .orElseThrow();

            FutureTaskHandler.futureByTaskId(task.id())
                    .complete(Any.pack(Int32Value.of(42)));

            var status = future.join();
            assertEquals(status.toString(), Code.OK, status.getCode());
        }
    }

    /**
     * If the task never finishes while the clock advances in 5-minute jumps, the pipeline
     * eventually fails with {@code FAILED_PRECONDITION}.
     */
    @Test
    public void leaseExpired() throws InterruptedException {
        var task = futureTask();
        assertTrue(dao.add(task).join());

        CountDownLatch sync = new CountDownLatch(1);
        var future = runPipeline(task.id()).whenComplete((status, throwable) -> sync.countDown());

        do {
            clock.passedTime(5, TimeUnit.MINUTES);
        } while (!sync.await(1, TimeUnit.MILLISECONDS));

        var status = future.join();
        assertEquals(status.toString(), Code.FAILED_PRECONDITION, status.getCode());
    }

    /**
     * While a task keeps running and time advances in small steps, the lock's expiration keeps
     * moving forward. The result is saved once the task finishes.
     */
    @Test
    public void extendLeaseDuringEvaluation() throws InterruptedException {
        var task = futureTask();
        assertTrue(dao.add(task).join());

        CountDownLatch sync = new CountDownLatch(1);
        var future = runPipeline(task.id()).whenComplete((ignore, e) -> sync.countDown());

        // Wait until the task is running (its lock is visible).
        while (locksDao.listLocks().join().isEmpty()) {
            TimeUnit.MILLISECONDS.sleep(1);
        }

        // The task is present in the running list.
        var initial = Iterables.getOnlyElement(locksDao.listLocks().join());
        assertEquals(task.id(), initial.id());
        assertThat(initial.expiredAt(), greaterThan(clock.instant()));

        // Each step advances the clock by 1s and waits 1ms of real time: the task runs more than 40s in total.
        for (int i = 0; i < 40; i++) {
            clock.passedTime(1, TimeUnit.SECONDS);
            assertFalse(sync.await(1, TimeUnit.MILLISECONDS));
        }

        // The lease keeps being extended for the running task.
        var last = Iterables.getOnlyElement(locksDao.listLocks().join());
        assertEquals(task.id(), last.id());
        assertThat(last.expiredAt(), greaterThan(clock.instant()));
        assertThat(last.expiredAt(), greaterThan(initial.expiredAt()));

        // Finish the task and check the result.
        FutureTaskHandler.futureByTaskId(task.id())
                .complete(Any.pack(Int32Value.of(42)));

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());
        assertEquals(42, taskResultAsNumber(task.id()));
    }

    /**
     * If the lock is released and taken by another owner while the task runs, the pipeline fails
     * with {@code FAILED_PRECONDITION}. A result produced afterwards is not saved.
     */
    @Test
    public void leaseReplaced() throws InterruptedException {
        var task = futureTask();
        assertTrue(dao.add(task).join());

        CountDownLatch sync = new CountDownLatch(1);
        var future = runPipeline(task.id(), "alice").whenComplete((status, throwable) -> sync.countDown());

        // Wait until the task is running (its lock is visible).
        while (locksDao.listLocks().join().isEmpty()) {
            TimeUnit.MILLISECONDS.sleep(1);
        }

        // The task is present in the running list, owned by alice.
        var lock = Iterables.getOnlyElement(locksDao.listLocks().join());
        assertEquals(task.id(), lock.id());
        assertEquals("alice", lock.owner());
        assertThat(lock.expiredAt(), greaterThan(clock.instant()));

        assertTrue(locksDao.releaseLock(task.id(), "alice").join());

        var bobLock = locksDao.acquireLock(task.id(), "bob", clock.instant().plus(1, ChronoUnit.HOURS)).join();
        assertEquals("bob", bobLock.owner());

        do {
            clock.passedTime(1, TimeUnit.SECONDS);
        } while (!sync.await(1, TimeUnit.MILLISECONDS));

        var status = future.join();
        assertEquals(status.toString(), Code.FAILED_PRECONDITION, status.getCode());

        // A late result from the displaced owner must not be persisted.
        FutureTaskHandler.futureByTaskId(task.id())
                .complete(Any.pack(Int32Value.of(42)));

        var resultTask = dao.get(task.id()).join().orElseThrow();
        assertNotEquals(State.COMPLETED, resultTask.state());
        assertEquals(Any.getDefaultInstance(), resultTask.result());
    }

    /** Transient DAO failures (up to four random ones per run) are retried, and the task still completes. */
    @Test
    public void retryDaoError() {
        for (int index = 1; index <= 5; index++) {
            var task = incTask();
            assertTrue(dao.syncAdd(task));

            AtomicInteger errorCount = new AtomicInteger();
            beforeDao(() -> {
                if (ThreadLocalRandom.current().nextBoolean() && errorCount.incrementAndGet() < 5) {
                    return failedFuture(new RuntimeException("hi"));
                }
                return completedFuture(null);
            });

            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.OK, status.getCode());
            assertEquals(index, taskResultAsNumber(task.id()));
        }
    }

    /**
     * A DAO error carrying {@code INVALID_ARGUMENT} is not retried: the pipeline finishes with
     * that status.
     */
    @Test
    public void stopRetry() {
        var task = incTask();
        assertTrue(dao.add(task).join());

        beforeDao(() -> {
            throw Status.INVALID_ARGUMENT.withDescription("hi").asRuntimeException();
        });

        var status = runPipeline(task.id()).join();
        assertEquals(status.toString(), Code.INVALID_ARGUMENT, status.getCode());
    }

    /** A completed task is not run again, even after every lock would have expired. */
    @Test
    public void unableRunTaskTwice() {
        var task = incTask();
        assertTrue(dao.add(task).join());

        {
            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.OK, status.getCode());
            assertEquals(1, taskResultAsNumber(task.id()));
        }

        // Guarantees that all locks have expired.
        clock.passedTime(1, TimeUnit.HOURS);
        {
            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.ABORTED, status.getCode());
            assertEquals(1, taskResultAsNumber(task.id()));
        }
    }

    /** No lock remains after a successful run or after an aborted second run of the same task. */
    @Test
    public void releaseLockAfterComplete() {
        var task = incTask();
        assertTrue(dao.add(task).join());

        {
            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.OK, status.getCode());
            assertEquals(1, taskResultAsNumber(task.id()));
            assertEquals(List.of(), locksDao.listLocks().join());
        }

        {
            // Run it again: the lock acquired by this attempt must also be released.
            var status = runPipeline(task.id()).join();
            assertEquals(status.toString(), Code.ABORTED, status.getCode());
            assertEquals(1, taskResultAsNumber(task.id()));
            assertEquals(List.of(), locksDao.listLocks().join());
        }
    }

    /**
     * A handler that fails with a gRPC status still moves the task to {@code COMPLETED} and stores
     * that status. The pipeline itself reports {@code OK}.
     */
    @Test
    public void completeTaskWithFailStatus() {
        var task = futureTask();
        assertTrue(dao.add(task).join());

        var expectStatus = Status.PERMISSION_DENIED.withDescription("expected error");

        FutureTaskHandler.futureByTaskId(task.id())
                .completeExceptionally(expectStatus.asRuntimeException());

        var status = runPipeline(task.id()).join();
        assertEquals(status.toString(), Code.OK, status.getCode());

        var result = dao.get(task.id()).join().orElseThrow();
        assertEquals(State.COMPLETED, result.state());
        // io.grpc.Status does not define value equality (equals is identity),
        // so compare the stable parts: code and description.
        var actualStatus = result.status();
        assertEquals(actualStatus.toString(), expectStatus.getCode(), actualStatus.getCode());
        assertEquals(actualStatus.toString(), expectStatus.getDescription(), actualStatus.getDescription());
    }

    /**
     * Rescheduling through the context ends the pipeline with {@code OK}, leaves the task
     * {@code SCHEDULED} and stores no result.
     */
    @Test
    public void rescheduleTask() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        assertFalse(future.isDone());
        context.reschedule(clock.millis() + 1, Any.getDefaultInstance()).join();

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());

        var result = dao.get(task.id()).join().orElseThrow();
        assertEquals(State.SCHEDULED, result.state());
        assertEquals(Any.getDefaultInstance(), result.result());
    }

    /** A task rescheduled to the current time can be picked up by a new pipeline and completed. */
    @Test
    public void rescheduleAnRun() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        // Reschedule.
        {
            var future = runPipeline(task.id());
            var context = ContextTaskHandler.contextByTaskId(task.id()).join();

            context.reschedule(clock.millis(), Any.getDefaultInstance()).join();

            var status = future.join();
            assertEquals(status.toString(), Code.OK, status.getCode());
        }

        // Run the rescheduled task.
        {
            // Drop the context registered by the first run so the second lookup returns the new one.
            ContextTaskHandler.clear();
            var future = runPipeline(task.id());
            var context = ContextTaskHandler.contextByTaskId(task.id()).join();

            context.complete(anyAsNumber(55)).join();
            var status = future.join();
            assertEquals(status.toString(), Code.OK, status.getCode());

            assertEquals(55, taskResultAsNumber(task.id()));
        }
    }

    /** A second {@code complete} is rejected with {@code FAILED_PRECONDITION}, and the first result is kept. */
    @Test
    public void unableCompleteTwice() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        var statusOne = statusOf(context.complete(anyAsNumber(1)));
        assertEquals(statusOne.toString(), Code.OK, statusOne.getCode());

        var statusTwo = statusOf(context.complete(anyAsNumber(2)));
        assertEquals(statusTwo.toString(), Code.FAILED_PRECONDITION, statusTwo.getCode());

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());

        assertEquals(1, taskResultAsNumber(task.id()));
    }

    /** {@code complete} after {@code reschedule} is rejected, and the task stays {@code SCHEDULED}. */
    @Test
    public void unableCompleteRescheduledTask() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        var statusOne = statusOf(context.reschedule(clock.millis() + 1, Any.getDefaultInstance()));
        assertEquals(statusOne.toString(), Code.OK, statusOne.getCode());

        var statusTwo = statusOf(context.complete(anyAsNumber(42)));
        assertEquals(statusTwo.toString(), Code.FAILED_PRECONDITION, statusTwo.getCode());

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());

        var result = dao.get(task.id()).join().orElseThrow();
        assertEquals(State.SCHEDULED, result.state());
        assertEquals(Any.getDefaultInstance(), result.result());
    }

    /**
     * Progress updates are persisted while the task runs. After completion they are rejected, and
     * the last accepted progress is kept.
     */
    @Test
    public void saveProgress() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        context.progress(anyAsNumber(1)).join();
        assertEquals(1, taskProgressAsNumber(task.id()));

        context.progress(anyAsNumber(2)).join();
        assertEquals(2, taskProgressAsNumber(task.id()));

        context.complete(anyAsNumber(42)).join();

        // Progress cannot be updated for a completed task.
        var progressStatus = statusOf(context.progress(anyAsNumber(3)));
        assertEquals(progressStatus.toString(), Code.FAILED_PRECONDITION, progressStatus.getCode());

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());
        assertEquals(42, taskResultAsNumber(task.id()));
        assertEquals(2, taskProgressAsNumber(task.id()));
    }

    /** {@code reschedule} stores the supplied progress together with the new execution time. */
    @Test
    public void rescheduleAlsoSaveProgress() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        var executedAt = clock.millis() + 10_000;
        context.reschedule(executedAt, anyAsNumber(32)).join();

        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());
        var result = dao.get(task.id()).join().orElseThrow();
        assertEquals(State.SCHEDULED, result.state());
        assertEquals(anyAsNumber(32), result.progress());
        assertEquals(executedAt, result.executeAt());
    }

    /** The context reports {@code isDone()} after the task is completed. */
    @Test
    public void contextDoneWhenCompleted() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        assertFalse(context.isDone());
        context.complete(anyAsNumber(42)).join();

        assertTrue(context.isDone());
    }

    /** The context reports {@code isDone()} after the task is rescheduled. */
    @Test
    public void contextDoneWhenReschedule() {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        assertFalse(context.isDone());
        context.reschedule(clock.millis() + 100, anyAsNumber(123)).join();

        assertTrue(context.isDone());
    }

    /**
     * When the pipeline fails with {@code FAILED_PRECONDITION} as time advances in 5-minute jumps,
     * the task's context also reports {@code isDone()}.
     */
    @Test
    public void contextDoneWhenLeaseExpired() throws InterruptedException {
        var task = contextTask();
        assertTrue(dao.add(task).join());

        CountDownLatch sync = new CountDownLatch(1);
        var future = runPipeline(task.id()).whenComplete((status, throwable) -> sync.countDown());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();
        assertFalse(context.isDone());

        do {
            clock.passedTime(5, TimeUnit.MINUTES);
        } while (!sync.await(1, TimeUnit.MILLISECONDS));

        var status = future.join();
        assertEquals(status.toString(), Code.FAILED_PRECONDITION, status.getCode());
        assertTrue(context.isDone());
    }

    /**
     * Cancelling through the context ends the pipeline with {@code OK}, leaves the task
     * {@code SCHEDULED} and releases the lock.
     */
    @Test
    public void cancelTask() {
        var task = contextTask();
        assertTrue(dao.syncAdd(task));

        var future = runPipeline(task.id());
        var context = ContextTaskHandler.contextByTaskId(task.id()).join();

        context.cancel().join();
        var status = future.join();
        assertEquals(status.toString(), Code.OK, status.getCode());

        var result = dao.syncGet(task.id()).orElseThrow();
        assertEquals(State.SCHEDULED, result.state());
        assertEquals(List.of(), locksDao.listLocks().join());
    }

    /**
     * A task scheduled one day ahead is not started early: the pipeline aborts and the stored task
     * is unchanged.
     */
    @Test
    public void unableStartRescheduledTaskTooEarlier() {
        var task = contextTask(clock.millis() + TimeUnit.DAYS.toMillis(1));
        assertTrue(dao.syncAdd(task));

        var future = runPipeline(task.id());
        // If the pipeline wrongly starts the task, complete it at once
        // so the final equality check detects the modification.
        ContextTaskHandler.contextByTaskId(task.id())
                .thenCompose(context -> context.complete(anyAsNumber(42)));

        var status = future.join();
        assertEquals(status.toString(), Code.ABORTED, status.getCode());

        var result = dao.syncGet(task.id()).orElseThrow();
        assertEquals(task, result);
    }

    /**
     * Installs {@code supplier} as the {@code beforeSupplier} hook of both the scheduler DAO and the
     * locks DAO. Tests use it to inject DAO failures.
     */
    private void beforeDao(Supplier<CompletableFuture<?>> supplier) {
        dao.beforeSupplier = supplier;
        locksDao.beforeSupplier = supplier;
    }

    /**
     * Waits for {@code future} and maps its outcome to a gRPC status: {@code OK} on success,
     * otherwise the status extracted from the failure by {@link Status#fromThrowable}.
     */
    private static Status statusOf(CompletableFuture<?> future) {
        return future.thenApply(o -> Status.OK)
                .exceptionally(Status::fromThrowable)
                .join();
    }

    /** Runs a pipeline for {@code taskId} on behalf of the local host (its FQDN). */
    private CompletableFuture<Status> runPipeline(String taskId) {
        return runPipeline(taskId, HostUtils.getFqdn());
    }

    /**
     * Builds a new {@link TaskPipeline} for {@code taskId} owned by {@code host} and starts it. An
     * exceptional completion is converted to a {@link Status}, so the returned future always
     * completes normally.
     */
    private CompletableFuture<Status> runPipeline(String taskId, String host) {
        var deps = new TaskDeps(host, taskExecutor, dao, locksDao, clock, ForkJoinPool.commonPool(), timer);
        var metrics = new TaskMetrics().getByType("total");
        var pipeline = new TaskPipeline(taskId, metrics, deps);
        return pipeline.start()
                .exceptionally(Status::fromThrowable);
    }

    /** Asserts that the task is {@code COMPLETED} and returns its result unpacked as an {@link Int32Value}. */
    private int taskResultAsNumber(String taskId) {
        var task = taskById(taskId);
        assertEquals(State.COMPLETED, task.state());
        return Proto.unpack(task.result(), Int32Value.class).getValue();
    }

    /** Returns the task's stored progress unpacked as an {@link Int32Value}. */
    private int taskProgressAsNumber(String taskId) {
        var task = taskById(taskId);
        return Proto.unpack(task.progress(), Int32Value.class).getValue();
    }

    /** Loads the task from the scheduler DAO, failing the test if it is absent. */
    private Task taskById(String taskId) {
        return dao.syncGet(taskId)
                .orElseThrow(() -> new AssertionError("task not found: " + taskId));
    }
}
