package ru.yandex.solomon.gateway.tasks.removeShard;

import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;

import com.google.protobuf.Any;
import io.grpc.Status;
import io.grpc.Status.Code;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.coremon.api.task.RemoveShardResult;
import ru.yandex.gateway.api.task.RemoteTaskProgress;
import ru.yandex.gateway.api.task.RemoveShardParams;
import ru.yandex.gateway.api.task.RemoveShardProgress;
import ru.yandex.gateway.api.task.RemoveShardProgress.RemoveConf;
import ru.yandex.solomon.core.db.dao.memory.InMemoryShardDao;
import ru.yandex.solomon.core.db.model.Shard;
import ru.yandex.solomon.coremon.client.CoremonClientStub;
import ru.yandex.solomon.scheduler.ExecutionContext;
import ru.yandex.solomon.scheduler.ExecutionContextStub;
import ru.yandex.solomon.scheduler.ExecutionContextStub.Complete;
import ru.yandex.solomon.scheduler.ExecutionContextStub.Fail;
import ru.yandex.solomon.scheduler.ExecutionContextStub.Reschedule;
import ru.yandex.solomon.scheduler.Task;
import ru.yandex.solomon.scheduler.grpc.Proto;
import ru.yandex.solomon.scheduler.proto.Task.State;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.future.RetryConfig;

import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.core.AllOf.allOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

/**
 * @author Vladimir Gordiychuk
 */
/**
 * Tests for {@link RemoveShardTask}.
 *
 * <p>Each test runs the task against an {@link InMemoryShardDao}, a {@link CoremonClientStub} with a single
 * registered cluster {@code "test"}, and a {@link ManualScheduledExecutorService} backed by a
 * {@link ManualClock}. The await helpers below advance that clock explicitly while waiting for progress.
 *
 * @author Vladimir Gordiychuk
 */
public class RemoveShardTaskTest {

    @Rule
    public Timeout globalTimeout = Timeout.builder()
            .withTimeout(10, TimeUnit.MINUTES)
            .withLookingForStuckThread(true)
            .build();

    private RetryConfig retryConfig;
    private ManualClock clock;
    private ManualScheduledExecutorService timer;
    private InMemoryShardDao dao;
    private CoremonClientStub coremonClient;

    @Before
    public void setUp() throws Exception {
        // Default retry policy for tests: Integer.MAX_VALUE retries with a max delay of 0.
        retryConfig = RetryConfig.DEFAULT
                .withNumRetries(Integer.MAX_VALUE)
                .withMaxDelay(0);
        clock = new ManualClock();
        timer = new ManualScheduledExecutorService(1, clock);
        dao = new InMemoryShardDao();
        coremonClient = new CoremonClientStub();
        coremonClient.addCluster("test");
    }

    @After
    public void tearDown() {
        if (timer != null) {
            timer.shutdownNow();
        }
    }

    /**
     * When the remote coremon task is moved back to SCHEDULED with an executeAt 5 days ahead, the local task
     * is expected to reschedule itself at (about) the same instant, and the shard must still be present in
     * the DAO id map.
     */
    @Test
    public void rescheduleIdle() throws InterruptedException {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());

        var params = params(shard);
        var context = context(params, RemoveShardProgress.getDefaultInstance());

        var proc = remove(context);
        var future = proc.start();

        awaitScheduleTask(proc);
        var progress = proc.progress();
        var taskId = progress.getRemoveReplica(0).getRemoteTaskId();
        var taskRescheduled = coremonClient.taskById("test", taskId).toBuilder()
                .setState(State.SCHEDULED)
                .setExecuteAt(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(5))
                .build();

        coremonClient.putTask("test", taskRescheduled);
        awaitTaskUpdate(proc, progress);
        assertNotEquals(progress, proc.progress());
        future.join();

        var event = context.takeDoneEvent(Reschedule.class);
        // Allow 10 seconds of slack between the remote executeAt and the local reschedule time.
        assertEquals((double) taskRescheduled.getExecuteAt(), (double) event.executeAt(), 10_000d);
        assertNotEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /**
     * If every coremon call fails with ABORTED and the retry budget (3 retries) runs out, the task future
     * still completes normally. The task reschedules itself within the next day.
     */
    @Test
    public void rescheduleOnError() {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());

        retryConfig = RetryConfig.DEFAULT
                .withNumRetries(3)
                .withMaxDelay(0);

        coremonClient.beforeSupplier = () ->
                CompletableFuture.failedFuture(Status.ABORTED.withDescription("hi").asRuntimeException());

        var params = params(shard);
        var context = context(params, RemoveShardProgress.getDefaultInstance());

        var proc = remove(context);
        var future = proc.start();
        var status = future.thenApply(unused -> Status.OK).exceptionally(Status::fromThrowable).join();
        // JUnit's argument order is (message, expected, actual).
        assertEquals(status.toString(), Code.OK, status.getCode());

        var event = context.takeDoneEvent(Reschedule.class);
        var delay = event.executeAt() - System.currentTimeMillis();
        assertThat(delay, allOf(
                lessThanOrEqualTo(TimeUnit.DAYS.toMillis(1)),
                greaterThanOrEqualTo(0L)));
        assertNotEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /**
     * Once the remote coremon task completes successfully, the local task reschedules itself roughly 7 days
     * ahead (within one day). The shard stays in the DAO id map.
     */
    @Test
    public void rescheduleRemoteComplete() throws InterruptedException {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());

        var params = params(shard);
        var context = context(params, RemoveShardProgress.getDefaultInstance());

        var proc = remove(context);
        var future = proc.start();

        awaitScheduleTask(proc);

        var progress = proc.progress();
        var taskId = progress.getRemoveReplica(0).getRemoteTaskId();
        var taskCompleted = coremonClient.taskById("test", taskId).toBuilder()
                .setState(State.COMPLETED)
                .build();

        coremonClient.putTask("test", taskCompleted);
        awaitTaskUpdate(proc, progress);
        assertNotEquals(progress, proc.progress());
        future.join();

        var event = context.takeDoneEvent(Reschedule.class);
        var delay = event.executeAt() - System.currentTimeMillis();
        assertEquals((double) TimeUnit.DAYS.toMillis(7), (double) delay, TimeUnit.DAYS.toMillis(1));
        assertNotEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /**
     * If the remote coremon task completes with an INTERNAL status, the local task finishes with a
     * {@link Fail} event carrying that status code.
     */
    @Test
    public void fail() throws InterruptedException {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());

        var params = params(shard);
        var context = context(params, RemoveShardProgress.getDefaultInstance());

        var proc = remove(context);
        var future = proc.start();

        awaitScheduleTask(proc);

        var progress = proc.progress();
        var taskId = progress.getRemoveReplica(0).getRemoteTaskId();
        var taskCompleted = coremonClient.taskById("test", taskId).toBuilder()
                .setState(State.COMPLETED)
                .setStatus(Proto.toProto(Status.INTERNAL.withDescription("hi")))
                .build();

        coremonClient.putTask("test", taskCompleted);
        awaitTaskUpdate(proc, progress);
        assertNotEquals(progress, proc.progress());
        future.join();

        var event = context.takeDoneEvent(Fail.class);
        var status = Status.fromThrowable(event.throwable());
        assertEquals(status.toString(), Code.INTERNAL, status.getCode());
        assertNotEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /**
     * First run: the remote task completes. The remote task is then deleted from coremon, and a second run
     * starts from the first run's progress. That second run is expected to reschedule 30 to 60 days ahead.
     */
    @Test
    public void rescheduleWhenTaskRemoved() throws InterruptedException {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());

        var removeOne = remove(context(params(shard), RemoveShardProgress.getDefaultInstance()));
        {
            var future = removeOne.start();

            awaitScheduleTask(removeOne);

            var progress = removeOne.progress();
            var taskId = progress.getRemoveReplica(0).getRemoteTaskId();
            var taskCompleted = coremonClient.taskById("test", taskId).toBuilder()
                    .setState(State.COMPLETED)
                    .build();

            coremonClient.putTask("test", taskCompleted);
            awaitTaskUpdate(removeOne, progress);
            assertNotEquals(progress, removeOne.progress());
            future.join();
        }

        var progress = removeOne.progress();
        var taskId = progress.getRemoveReplica(0).getRemoteTaskId();
        coremonClient.removeTaskById("test", taskId);

        var context = context(params(shard), removeOne.progress());
        var removeTwo = remove(context);
        removeTwo.start().join();

        var event = context.takeDoneEvent(Reschedule.class);
        var delay = event.executeAt() - System.currentTimeMillis();
        assertThat(delay, allOf(
                lessThanOrEqualTo(TimeUnit.DAYS.toMillis(60)),
                greaterThanOrEqualTo(TimeUnit.DAYS.toMillis(30))));
        assertNotEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /**
     * Setup: the shard is already deleted from the DAO, and the progress says the conf removal is complete
     * and the remote task was removed 35 days ago. The task must then finish with a {@link Complete} event
     * reporting the remote result (cluster "test", 10 removed metrics), and the DAO id map must be empty.
     */
    @Test
    public void releaseNumId() {
        var shard = randomShard();
        assertTrue(dao.insert(shard).join());
        assertTrue(dao.deleteOne(shard.getProjectId(), shard.getFolderId(), shard.getId()).join());

        var initProgress = RemoveShardProgress.newBuilder()
                .setRemoveConf(RemoveConf.newBuilder().setComplete(true).build())
                .addRemoveReplica(RemoteTaskProgress.newBuilder()
                        .setClusterId("test")
                        .setRemoteTaskId("already_removed_task")
                        .setComplete(true)
                        .setRemoteTask(ru.yandex.solomon.scheduler.proto.Task.newBuilder()
                                .setState(State.COMPLETED)
                                .setResult(Any.pack(RemoveShardResult.newBuilder()
                                        .setRemovedMetrics(10)
                                        .build()))
                                .build())
                        .setRemoteTaskCompletedAt(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(60L))
                        .setRemoteTaskRemovedAt(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(35L))
                        .build())
                .build();

        var context = context(params(shard), initProgress);
        var proc = remove(context);
        proc.start().join();

        var event = context.takeDoneEvent(Complete.class);
        assertNotNull(event);

        var result = RemoveShardTaskProto.result(event.result());
        assertEquals("test", result.getResults(0).getClusterId());
        assertEquals(10, result.getResults(0).getRemovedMetrics());
        assertEquals(new Int2ObjectOpenHashMap<>(), dao.findAllIdToShardId().join());
    }

    /** Creates the task under test using the current {@link #retryConfig} and the shared test fixtures. */
    private RemoveShardTask remove(ExecutionContext context) {
        return new RemoveShardTask(retryConfig, coremonClient, dao, ForkJoinPool.commonPool(), timer, context);
    }

    /** Builds a shard whose project, shard, cluster and service identifiers are randomly suffixed. */
    private Shard randomShard() {
        var random = ThreadLocalRandom.current();
        return Shard.newBuilder()
                .setProjectId("project_id_" + random.nextLong())
                .setId("shard_id_" + random.nextLong())
                .setClusterId("cluster_id_" + random.nextLong())
                .setClusterName("cluster_name_" + random.nextLong())
                .setServiceId("service_id_" + random.nextLong())
                .setServiceName("service_name_" + random.nextLong())
                .build();
    }

    /** Builds task params (project id, shard id, numeric id) for the given shard. */
    private RemoveShardParams params(Shard shard) {
        return RemoveShardParams.newBuilder()
                .setProjectId(shard.getProjectId())
                .setShardId(shard.getId())
                .setNumId(shard.getNumId())
                .build();
    }

    /** Wraps the params and initial progress into a "remove_shard" task inside an {@link ExecutionContextStub}. */
    private ExecutionContextStub context(RemoveShardParams params, RemoveShardProgress progress) {
        var task = Task.newBuilder()
                .setId(UUID.randomUUID().toString())
                .setType("remove_shard")
                .setExecuteAt(System.currentTimeMillis())
                .setProgress(Any.pack(progress))
                .setParams(Any.pack(params))
                .build();

        return new ExecutionContextStub(task);
    }

    /**
     * Waits until the task's progress has at least one replica entry and every replica has a remote task id
     * and a non-default remote task snapshot.
     *
     * <p>The non-empty check matters: {@code Stream.noneMatch}/{@code allMatch} return {@code true} for an
     * empty list. Without it, the wait could end before any replica exists, and the caller's
     * {@code getRemoveReplica(0)} would fail.
     */
    private void awaitScheduleTask(RemoveShardTask proc) throws InterruptedException {
        awaitProgress(() -> {
            var replicas = proc.progress().getRemoveReplicaList();
            return !replicas.isEmpty() && replicas.stream().allMatch(RemoveShardTaskTest::isRemoteTaskScheduled);
        });
    }

    /** Returns whether the replica already has both a remote task id and a non-default remote task. */
    private static boolean isRemoteTaskScheduled(RemoteTaskProgress replica) {
        return !replica.getRemoteTaskId().isEmpty()
                && !replica.getRemoteTask().equals(ru.yandex.solomon.scheduler.proto.Task.getDefaultInstance());
    }

    /** Waits until the task's progress differs from {@code prev}. */
    private void awaitTaskUpdate(RemoveShardTask proc, RemoveShardProgress prev) throws InterruptedException {
        awaitProgress(() -> !proc.progress().equals(prev));
    }

    /** Repeatedly waits for a coremon call (advancing the manual clock) until {@code done} holds. */
    private void awaitProgress(BooleanSupplier done) throws InterruptedException {
        while (!done.getAsBoolean()) {
            awaitCoremonCall(done);
        }
    }

    /**
     * Installs a {@code beforeSupplier} hook on the coremon stub that records the next call and lets it
     * succeed. Then advances the manual clock in 10-second steps until that hook runs or {@code done} holds.
     *
     * <p>The hook stays installed after this method returns.
     */
    private void awaitCoremonCall(BooleanSupplier done) throws InterruptedException {
        CountDownLatch sync = new CountDownLatch(1);
        coremonClient.beforeSupplier = () -> {
            sync.countDown();
            return CompletableFuture.completedFuture(null);
        };

        while (!sync.await(1, TimeUnit.NANOSECONDS)) {
            clock.passedTime(10, TimeUnit.SECONDS);
            if (done.getAsBoolean()) {
                return;
            }
        }
    }
}
