package ru.yandex.solomon.gateway.tasks;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Predicate;

import com.google.protobuf.StringValue;
import io.grpc.Status;
import io.grpc.Status.Code;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.coremon.api.task.RemoveShardResult;
import ru.yandex.gateway.api.task.RemoteTaskProgress;
import ru.yandex.solomon.scheduler.proto.Task;
import ru.yandex.solomon.scheduler.proto.Task.State;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.Proto;
import ru.yandex.solomon.util.future.RetryConfig;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link RemoteTask}: scheduling a task through {@link RemoteTaskClientStub}, polling its
 * state, cancellation, interruption, and resuming from an existing {@link RemoteTaskProgress}.
 *
 * <p>Time is driven by a {@link ManualClock} and {@link ManualScheduledExecutorService}. Every wait
 * in this class therefore advances the clock explicitly instead of sleeping.
 *
 * @author Vladimir Gordiychuk
 */
public class RemoteTaskTest {

    @Rule
    public Timeout globalTimeout = Timeout.builder()
            .withTimeout(1, TimeUnit.MINUTES)
            .withLookingForStuckThread(true)
            .build();

    private ManualClock clock;
    private ManualScheduledExecutorService timer;
    private RemoteTaskClientStub remoteTaskClient;

    @Before
    public void setUp() throws Exception {
        // Retry forever without delay so transient stub failures never end a test prematurely.
        var retryConfig = RetryConfig.DEFAULT
                .withNumRetries(Integer.MAX_VALUE)
                .withMaxDelay(0);
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);
        remoteTaskClient = new RemoteTaskClientStub(retryConfig);
    }

    @After
    public void tearDown() {
        if (timer != null) {
            timer.shutdownNow();
        }
    }

    @Test
    public void remoteTaskKnownAsDone() {
        var progress = RemoteTaskProgress.newBuilder()
                .setComplete(true)
                .setClusterId("test")
                .setRemoteTaskCompletedAt(System.currentTimeMillis())
                .setRemoteTaskRemovedAt(System.currentTimeMillis())
                .build();

        // Any call to the remote client fails, so the progress must be accepted as-is.
        remoteTaskClient.beforeSupplier = () -> {
            throw Status.UNAVAILABLE.asRuntimeException();
        };

        var proc = remoteTask(progress);
        proc.start(false).join();
        assertEquals(progress, proc.progress());
    }

    @Test
    public void remoteTaskAlreadyDone() {
        var task = remoteTaskClient.prepareTask();

        var progress = RemoteTaskProgress.newBuilder()
                .setRemoteTaskId(task.getId())
                .setClusterId("test")
                .setRemoteTask(task)
                .build();

        var expectedTask = progress.getRemoteTask().toBuilder()
                .setState(State.COMPLETED)
                .setExecuteAt(System.currentTimeMillis() - 100L)
                .build();
        remoteTaskClient.putTask(expectedTask);

        var proc = remoteTask(progress);
        proc.start(false).join();

        var expectProgress = progress.toBuilder()
                .setRemoteTask(expectedTask)
                .setRemoteTaskCompletedAt(proc.progress().getRemoteTaskCompletedAt())
                .build();
        assertEquals(expectProgress, proc.progress());
        assertNotEquals(0, proc.progress().getRemoteTaskCompletedAt());
    }

    @Test
    public void remoteTaskAlreadyRemoved() {
        var task = remoteTaskClient.prepareTask();
        var progress = RemoteTaskProgress.newBuilder()
                .setRemoteTaskId(task.getId())
                .setClusterId("test")
                .setRemoteTask(task)
                .build();

        remoteTaskClient.putTask(progress.getRemoteTask().toBuilder()
                .setState(State.COMPLETED)
                .setExecuteAt(System.currentTimeMillis() - 100L)
                .build());

        // First run: the task is observed as completed but is still present on the remote side.
        var procOne = remoteTask(progress);
        procOne.start(false).join();
        assertNotEquals(0, procOne.progress().getRemoteTaskCompletedAt());
        assertEquals(0, procOne.progress().getRemoteTaskRemovedAt());

        // Second run: nothing changed remotely, so the progress stays the same.
        var procTwo = remoteTask(procOne.progress());
        procTwo.start(false).join();
        assertEquals(procOne.progress(), procTwo.progress());

        // Third run: the task is gone remotely, so the removal is recorded and the process completes.
        remoteTaskClient.removeTaskById(task.getId());
        var procThree = remoteTask(procTwo.progress());
        procThree.start(false).join();

        assertNotEquals(0, procThree.progress().getRemoteTaskCompletedAt());
        assertNotEquals(0, procThree.progress().getRemoteTaskRemovedAt());
        assertTrue(procThree.progress().getComplete());
    }

    @Test
    public void scheduleTaskAndPoll() throws InterruptedException {
        var proc = remoteTask("test");
        var future = proc.start(false);

        // schedule task
        String taskId = awaitScheduleTask(proc);
        assertNotNull(remoteTaskClient.getTaskById(taskId));
        assertFalse(proc.progress().getComplete());

        var expectInit = remoteTaskClient.getTaskById(taskId);
        var init = awaitTaskUpdate(proc, Task.getDefaultInstance());
        assertEquals(expectInit, init);
        assertFalse(proc.progress().getComplete());

        // pull update
        var expectUpdate = init.toBuilder()
                .setState(State.RUNNING)
                .setProgress(Proto.pack(StringValue.of("blah blah, progressing")))
                .build();
        remoteTaskClient.putTask(expectUpdate);
        var update = awaitTaskUpdate(proc, init);
        assertEquals(expectUpdate, update);
        assertFalse(future.isDone());
        assertFalse(proc.progress().getComplete());

        // pull complete
        var expectComplete = update.toBuilder()
                .setState(State.COMPLETED)
                .setResult(Proto.pack(RemoveShardResult.newBuilder().setRemovedMetrics(42).build()))
                .build();
        remoteTaskClient.putTask(expectComplete);
        var complete = awaitTaskUpdate(proc, update);
        // JUnit convention: expected value first, actual second.
        assertEquals(expectComplete, complete);
        future.join();
        assertFalse(proc.progress().getComplete());
    }

    @Test
    public void scheduleStatusIsNotOk() {
        remoteTaskClient.beforeSupplier = () -> {
            throw Status.NOT_FOUND.asRuntimeException();
        };

        var proc = remoteTask("not_exist_anymore");
        var status = proc.start(false).thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();
        assertEquals(status.toString(), Code.NOT_FOUND, status.getCode());
        assertFalse(proc.progress().getComplete());

        remoteTaskClient.beforeSupplier = () -> {
            throw Status.UNKNOWN.asRuntimeException();
        };

        var proc2 = remoteTask("unknown_status");
        var status2 = proc2.start(false).thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();
        assertEquals(status2.toString(), Code.UNKNOWN, status2.getCode());
        assertFalse(proc2.progress().getComplete());
    }

    @Test
    public void taskAlreadyTtlOnNode() {
        var task = remoteTaskClient.prepareTask();
        remoteTaskClient.removeTaskById(task.getId());

        var progress = RemoteTaskProgress.getDefaultInstance()
                .toBuilder()
                .setClusterId("test")
                .setRemoteTask(task)
                .setRemoteTaskId(task.getId())
                .setRemoteTaskCompletedAt(System.currentTimeMillis())
                .build();

        var proc = remoteTask(progress);
        proc.start(false).join();
        assertTrue(proc.progress().getComplete());
    }

    @Test
    public void cancelRunningTask() throws InterruptedException {
        var proc = remoteTask("test");
        var future = proc.start(false);

        // schedule task
        String taskId = awaitScheduleTask(proc);
        assertNotNull(remoteTaskClient.getTaskById(taskId));
        assertFalse(proc.progress().getComplete());

        proc.close();

        var status = future.thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();
        assertEquals(status.toString(), Code.CANCELLED, status.getCode());
        assertFalse(proc.progress().getComplete());
    }

    @Test
    public void scheduleAndInterrupt() throws InterruptedException {
        // schedule task
        var scheduleProc = remoteTaskNoRemove("test");
        var future = scheduleProc.start(false);

        String taskId = awaitScheduleTask(scheduleProc);
        assertNotNull(remoteTaskClient.getTaskById(taskId));
        assertFalse(scheduleProc.progress().getComplete());
        awaitTaskUpdate(scheduleProc, Task.getDefaultInstance());
        scheduleProc.close();
        future.thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();

        // interrupt task (using progress)
        var interruptProc = remoteTaskNoRemove(scheduleProc.progress());
        remoteTaskClient.interruptId = taskId;
        var interruptFuture = interruptProc.start(true);

        var before = interruptProc.progress().getRemoteTask();
        var expectUpdate = before.toBuilder()
            .setState(State.COMPLETED)
            .setProgress(Proto.pack(StringValue.of("interrupted")))
            .build();
        remoteTaskClient.putTask(expectUpdate);
        // Keep driving the manual clock until the interrupt run finishes. Waiting only for "the
        // task differs from before" and then calling join() can block if the process still
        // needs a timer tick that nobody advances. The asserts below check the observed task.
        awaitProgress(interruptFuture::isDone);

        interruptFuture.join();
        var interruptProgress = interruptProc.progress();
        assertTrue(interruptProgress.getComplete());
        assertTrue(interruptProgress.getInterrupted());
        assertEquals(expectUpdate, interruptProgress.getRemoteTask());
    }

    @Test
    public void scheduleAndInterruptFromScratch() throws InterruptedException {
        // schedule task
        var scheduleProc = remoteTaskNoRemove("test");
        var future = scheduleProc.start(false);

        String taskId = awaitScheduleTask(scheduleProc);
        assertNotNull(remoteTaskClient.getTaskById(taskId));
        assertFalse(scheduleProc.progress().getComplete());
        awaitTaskUpdate(scheduleProc, Task.getDefaultInstance());
        scheduleProc.close();
        future.thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();

        // interrupt task (from scratch)
        var interruptProc = remoteTaskNoRemove("test");
        remoteTaskClient.interruptId = taskId;
        var interruptFuture = interruptProc.start(true);

        // interruptProc starts without a remote task, so the baseline comes from scheduleProc.
        var before = scheduleProc.progress().getRemoteTask();
        var expectUpdate = before.toBuilder()
            .setState(State.COMPLETED)
            .setProgress(Proto.pack(StringValue.of("interrupted")))
            .build();
        remoteTaskClient.putTask(expectUpdate);
        // interruptProc's initial remote task is the default instance, which never equals
        // `before`. Waiting for "task differs from before" could therefore return immediately
        // without advancing the clock. Drive the clock until the interrupt run actually finishes.
        awaitProgress(interruptFuture::isDone);

        interruptFuture.join();
        var interruptProgress = interruptProc.progress();
        assertTrue(interruptProgress.getComplete());
        assertTrue(interruptProgress.getInterrupted());
        assertEquals(expectUpdate, interruptProgress.getRemoteTask());
    }

    @Test
    public void interruptUnknown() {
        var interruptTask = remoteTask("test");
        remoteTaskClient.interruptId = "42";

        interruptTask.start(true).join();
        var progress = interruptTask.progress();
        assertTrue(progress.getComplete());
        assertTrue(progress.getInterrupted());
        assertEquals("test", progress.getClusterId());
        assertFalse(progress.hasRemoteTask());
    }

    /**
     * Drives the manual clock until {@code proc} records a non-empty remote task id.
     *
     * @return the recorded remote task id
     */
    private String awaitScheduleTask(RemoteTask proc) throws InterruptedException {
        awaitProgress(() -> !proc.progress().getRemoteTaskId().isEmpty());
        return proc.progress().getRemoteTaskId();
    }

    /**
     * Drives the manual clock until the remote task stored in {@code proc}'s progress is no longer
     * equal to {@code prev}.
     *
     * @return the new remote task snapshot
     */
    private Task awaitTaskUpdate(RemoteTask proc, Task prev) throws InterruptedException {
        awaitProgress(() -> !proc.progress().getRemoteTask().equals(prev));
        return proc.progress().getRemoteTask();
    }

    /**
     * Repeatedly waits for the next remote-client call, advancing the manual clock, until
     * {@code done} reports true.
     */
    private void awaitProgress(BooleanSupplier done) throws InterruptedException {
        while (!done.getAsBoolean()) {
            awaitRemoteTaskClientCall(done);
        }
    }

    /**
     * Installs a {@code beforeSupplier} hook on the stub that releases a latch and lets the call
     * proceed. The hook is presumably invoked before each stub operation; confirm in
     * RemoteTaskClientStub. The method then advances the manual clock by 10 seconds per iteration
     * until that hook fires or {@code done} becomes true.
     */
    private void awaitRemoteTaskClientCall(BooleanSupplier done) throws InterruptedException {
        CountDownLatch sync = new CountDownLatch(1);
        remoteTaskClient.beforeSupplier = () -> {
            sync.countDown();
            return completedFuture(null);
        };

        // The 1ns await makes this an (intentional) spin that keeps pushing the clock forward,
        // so timer-scheduled work in RemoteTask gets a chance to run.
        while (!sync.await(1, TimeUnit.NANOSECONDS)) {
            clock.passedTime(10, TimeUnit.SECONDS);
            if (done.getAsBoolean()) {
                return;
            }
        }
    }

    private RemoteTask remoteTask(String clusterId) {
        return remoteTask(initProgress(clusterId));
    }

    /**
     * Creates a task that counts as complete only once the remote task is both completed and
     * removed (or the progress is already marked complete).
     */
    private RemoteTask remoteTask(RemoteTaskProgress progress) {
        return remoteTask(
            progress,
            p -> p.getComplete() || p.getRemoteTaskCompletedAt() > 0 && p.getRemoteTaskRemovedAt() > 0);
    }

    private RemoteTask remoteTaskNoRemove(String clusterId) {
        return remoteTaskNoRemove(initProgress(clusterId));
    }

    /**
     * Creates a task that counts as complete as soon as the remote task is completed, without
     * waiting for its removal.
     */
    private RemoteTask remoteTaskNoRemove(RemoteTaskProgress progress) {
        return remoteTask(
            progress,
            p -> p.getComplete() || p.getRemoteTaskCompletedAt() > 0);
    }

    private RemoteTask remoteTask(RemoteTaskProgress progress, Predicate<RemoteTaskProgress> isComplete) {
        return new RemoteTask(
            "test",
            ForkJoinPool.commonPool(),
            timer,
            progress,
            remoteTaskClient,
            isComplete
        );
    }

    /** Returns an empty progress that only carries {@code clusterId}. */
    private RemoteTaskProgress initProgress(String clusterId) {
        return RemoteTaskProgress.newBuilder()
                .setClusterId(clusterId)
                .build();
    }
}
