package ru.yandex.solomon.gateway.tasks;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;

import com.google.common.collect.ImmutableList;
import io.grpc.Status;
import io.grpc.Status.Code;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.gateway.api.task.RemoteTaskProgress;
import ru.yandex.solomon.scheduler.proto.Task;
import ru.yandex.solomon.scheduler.proto.Task.State;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.future.RetryConfig;

import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toUnmodifiableList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;

/**
 * Tests for {@link ScatterGather} composed of {@link RemoteTask} sub-tasks.
 *
 * <p>Every test uses two stubbed clusters, {@code "sas"} and {@code "vla"}, each backed by a
 * {@link RemoteTaskClientStub}. Scheduling goes through a {@link ManualScheduledExecutorService}
 * driven by a {@link ManualClock}, which the wait helpers advance explicitly.
 *
 * @author Vladimir Gordiychuk
 */
public class ScatterGatherTest {
    @Rule
    public Timeout globalTimeout = Timeout.builder()
            .withTimeout(1, TimeUnit.MINUTES)
            .withLookingForStuckThread(true)
            .build();

    private ManualClock clock;
    private ManualScheduledExecutorService timer;
    private Map<String, RemoteTaskClientStub> clientByClusterId;

    @Before
    public void setUp() throws Exception {
        // Effectively unlimited retries with a zero maximum delay; presumably so a sub-task keeps
        // talking to its stub until the test changes the stubbed state.
        var retryConfig = RetryConfig.DEFAULT
            .withNumRetries(Integer.MAX_VALUE)
            .withMaxDelay(0);
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);

        var sasClient = new RemoteTaskClientStub(retryConfig);
        var vlaClient = new RemoteTaskClientStub(retryConfig);
        clientByClusterId = Map.of("sas", sasClient, "vla", vlaClient);
    }

    @After
    public void tearDown() {
        if (timer != null) {
            timer.shutdownNow();
        }
    }

    /**
     * Progress entries that are already marked complete must be reported back unchanged,
     * even though every stubbed client call would fail with UNAVAILABLE.
     */
    @Test
    public void alreadyDone() {
        var progress = clientByClusterId.keySet()
                .stream()
                .map(clusterId -> RemoteTaskProgress.newBuilder()
                        .setComplete(true)
                        .setClusterId(clusterId)
                        .build())
                .collect(toList());

        clientByClusterId.values().forEach(
            client -> client.beforeSupplier = () -> failedFuture(Status.UNAVAILABLE.asRuntimeException())
        );

        var proc = scatterGather(progress);
        proc.start(false).join();
        assertEquals(progress, proc.progress());
    }

    /**
     * The aggregate future stays pending while any remote task is incomplete. It completes
     * once both remote tasks are switched to COMPLETED.
     */
    @Test
    public void doneOnlyWhenAllSubTasksDone() throws InterruptedException {
        var proc = scatterGather();
        var future = proc.start(false);

        awaitScheduleTask(proc);
        assertFalse(future.isDone());

        var scheduled = proc.progress();
        var one = scheduled.get(0);
        clientByClusterId.get(one.getClusterId()).putTask(one.getRemoteTask().toBuilder()
                .setState(State.COMPLETED)
                .build());
        awaitTaskUpdate(proc, scheduled);

        var oneDone = proc.progress();
        assertNotEquals(0, oneDone.get(0).getRemoteTaskCompletedAt());
        assertFalse(future.isDone());

        var two = oneDone.get(1);
        clientByClusterId.get(two.getClusterId()).putTask(two.getRemoteTask().toBuilder()
                .setState(State.COMPLETED)
                .build());
        awaitTaskUpdate(proc, oneDone);

        var twoDone = proc.progress();
        assertNotEquals(0, twoDone.get(1).getRemoteTaskCompletedAt());
        future.join();
    }

    /**
     * When every remote task is SCHEDULED an hour into the future, the aggregate future completes.
     * The reported progress keeps those scheduled remote tasks.
     */
    @Test
    public void completeWhenAllSubTasksIdle() throws InterruptedException {
        var proc = scatterGather();
        var future = proc.start(false);

        awaitScheduleTask(proc);
        assertFalse(future.isDone());

        var scheduled = proc.progress();
        var one = scheduled.get(0);
        var expectOne = one.toBuilder()
                .setRemoteTask(one.getRemoteTask().toBuilder()
                        .setState(State.SCHEDULED)
                        .setExecuteAt(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(1))
                        .build())
                .build();

        clientByClusterId.get(expectOne.getClusterId()).putTask(expectOne.getRemoteTask());
        awaitTaskUpdate(proc, scheduled);

        var oneDone = proc.progress();
        var expectedOneDone = new ArrayList<>(oneDone);
        expectedOneDone.set(0, expectOne);
        assertEquals(expectedOneDone, oneDone);
        assertFalse(future.isDone());

        var two = oneDone.get(1);
        var expectTwo = two.toBuilder()
                .setRemoteTask(two.getRemoteTask().toBuilder()
                        .setState(State.SCHEDULED)
                        .setExecuteAt(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(1))
                        .build())
                .build();

        clientByClusterId.get(expectTwo.getClusterId()).putTask(expectTwo.getRemoteTask());
        awaitTaskUpdate(proc, oneDone);

        var twoDone = proc.progress();
        var expectedTwoDone = List.of(expectOne, expectTwo);
        assertEquals(expectedTwoDone, twoDone);

        future.join();
        assertEquals(expectedTwoDone, proc.progress());
    }

    /** Closing the process while its sub-tasks are running fails the future with CANCELLED. */
    @Test
    public void cancelRunningTask() throws InterruptedException {
        var proc = scatterGather();
        var future = proc.start(false);

        awaitScheduleTask(proc);
        assertFalse(future.isDone());

        proc.close();

        var status = future.thenApply(ignore -> Status.OK).exceptionally(Status::fromThrowable).join();
        assertEquals(status.toString(), Code.CANCELLED, status.getCode());
    }

    /** Waits until every progress entry has a remote task id and a non-default remote task. */
    private void awaitScheduleTask(ScatterGather<RemoteTaskProgress> proc) throws InterruptedException {
        awaitProgress(() -> proc.progress()
                .stream()
                .noneMatch(r -> r.getRemoteTaskId().isEmpty()
                        || r.getRemoteTask().equals(Task.getDefaultInstance())));
    }

    /** Waits until the reported progress differs from {@code prev}. */
    private void awaitTaskUpdate(ScatterGather<RemoteTaskProgress> proc, List<RemoteTaskProgress> prev) throws InterruptedException {
        awaitProgress(() -> !proc.progress().equals(prev));
    }

    /**
     * Keeps waiting for remote client calls until {@code condition} holds.
     *
     * @param condition the state the test is waiting for; checked before every wait round
     */
    private void awaitProgress(BooleanSupplier condition) throws InterruptedException {
        while (!condition.getAsBoolean()) {
            awaitRemoteTaskClientCall(condition);
        }
    }

    /**
     * Waits for the next call into any stubbed client, or until {@code condition} holds.
     *
     * <p>Replaces the {@code beforeSupplier} hook of every stub with one that releases a latch and
     * returns a completed future. It then spins, advancing the manual clock by 10 seconds per round
     * until the latch is released or {@code condition} becomes true.
     *
     * @param condition early-exit predicate re-checked after each clock advance
     */
    private void awaitRemoteTaskClientCall(BooleanSupplier condition) throws InterruptedException {
        CountDownLatch sync = new CountDownLatch(1);
        ImmutableList.copyOf(clientByClusterId.values()).reverse().forEach(client -> client.beforeSupplier = () -> {
            sync.countDown();
            return CompletableFuture.completedFuture(null);
        });

        // A 1ns timeout makes await() an effectively non-blocking check of the latch.
        while (!sync.await(1, TimeUnit.NANOSECONDS)) {
            clock.passedTime(10, TimeUnit.SECONDS);
            if (condition.getAsBoolean()) {
                return;
            }
        }
    }

    /** Builds a process with one empty progress entry per stubbed cluster. */
    private ScatterGather<RemoteTaskProgress> scatterGather() {
        return scatterGather(
            clientByClusterId.keySet().stream()
                .map(clusterId -> RemoteTaskProgress.newBuilder().setClusterId(clusterId).build())
                .collect(toUnmodifiableList()));
    }

    /**
     * Builds a process with one {@link RemoteTask} per progress entry, each using the stub of its
     * cluster.
     *
     * <p>A sub-task counts as done when its progress is marked complete, or when the remote task
     * has both a completion and a removal timestamp.
     */
    private ScatterGather<RemoteTaskProgress> scatterGather(List<RemoteTaskProgress> progress) {
        return progress.stream()
            .map(pr -> new RemoteTask(
                "test",
                ForkJoinPool.commonPool(),
                timer,
                pr,
                clientByClusterId.get(pr.getClusterId()),
                p -> p.getComplete() || (p.getRemoteTaskCompletedAt() > 0 && p.getRemoteTaskRemovedAt() > 0)
            ))
            .collect(collectingAndThen(toList(), ScatterGather::new));
    }
}
