package ru.yandex.stockpile.cluster.balancer;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.TimeUnit;

import io.grpc.Status;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import ru.yandex.kikimr.client.kv.KvTabletIdAndGen;
import ru.yandex.kikimr.client.kv.inMem.KikimrKvClientInMem;
import ru.yandex.solomon.locks.ReadOnlyDistributedLockStub;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.util.file.SimpleFileStorage;
import ru.yandex.stockpile.internal.api.TAssignShardRequest;
import ru.yandex.stockpile.internal.api.TPingRequest;
import ru.yandex.stockpile.internal.api.TShardAssignment;
import ru.yandex.stockpile.internal.api.TUnassignShardRequest;
import ru.yandex.stockpile.server.shard.StockpileLocalShards;
import ru.yandex.stockpile.server.shard.StockpileShard;
import ru.yandex.stockpile.server.shard.StockpileShardGlobals;
import ru.yandex.stockpile.server.shard.test.StockpileShardTestContext;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link StockpileLocalShardsState}: assign / unassign / ping handling guarded by the
 * leader lock (owner, seqNo and request deadline checks), and restoring persisted assignments
 * when the state is recreated over the same storage directory.
 *
 * @author Vladimir Gordiychuk
 */
// TODO: drop spring from tests (gordiychuk@)
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = {
    StockpileShardTestContext.class
})
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
public class StockpileLocalShardsStateTest {
    private static final Logger logger = LoggerFactory.getLogger(StockpileLocalShardsStateTest.class);

    @Rule
    public TestName testName = new TestName();
    @Rule
    public Timeout globalTimeout = Timeout.builder()
        .withLookingForStuckThread(true)
        .withTimeout(1, TimeUnit.MINUTES)
        .build();
    @Rule
    public TemporaryFolder tmp = new TemporaryFolder();

    // Directory backing SimpleFileStorage; kept across a simulated restart.
    private Path storage;
    @Autowired
    private KikimrKvClientInMem kikimrKvClientInMem;
    @Autowired
    private StockpileShardGlobals shardGlobals;
    private ManualClock clock;
    private StockpileLocalShards shards;
    private ReadOnlyDistributedLockStub lock;
    private StockpileLocalShardsState state;

    @Before
    public void setUp() throws Exception {
        logger.info("Run setUp for test: {}", testName.getMethodName());

        clock = new ManualClock();
        lock = new ReadOnlyDistributedLockStub(clock);
        storage = tmp.newFolder().toPath().resolve(testName.getMethodName());
        restartService();
    }

    @Test
    public void rejectAssignFromExpiredLeader() {
        long leaderSeqNo = 42L;
        // Nobody holds the lock, so the request is expected to be rejected.
        lock.setOwner(null);

        Status status = syncAssign(makeAssignment(leaderSeqNo, 2, createTablet()));
        assertEquals(status.getDescription(), Status.Code.ABORTED, status.getCode());
    }

    @Test
    public void rejectPingFromExpiredLeader() {
        long leaderSeqNo = 42L;
        lock.setOwner(null);

        Status status = syncPing(makePing(leaderSeqNo, 42));
        assertEquals(status.getDescription(), Status.Code.ABORTED, status.getCode());
    }

    @Test
    public void rejectAssignBySeqNoMismatch() throws InterruptedException {
        long aliceSeqNo = lock.setOwner("alice");

        // NOTE(review): purpose of this tiny real-time sleep is unclear (possibly to make the
        // stub's next seqNo differ) — confirm before removing.
        TimeUnit.NANOSECONDS.sleep(10);
        clock.passedTime(1, TimeUnit.MINUTES);
        // Leadership moves to bob; alice's seqNo becomes stale.
        lock.setOwner("bob");

        Status status = syncAssign(makeAssignment(aliceSeqNo, 3, createTablet()));
        assertEquals(status.getDescription(), Status.Code.ABORTED, status.getCode());
    }

    @Test
    public void rejectAssignByDeadline() {
        long leaderSeqNo = lock.setOwner("alice");
        TAssignShardRequest request = makeAssignment(leaderSeqNo, 5, createTablet())
            .toBuilder()
            // Deadline already 5 seconds in the past relative to the test clock.
            .setExpiredAt(clock.millis() - 5_000)
            .build();

        Status status = syncAssign(request);
        assertEquals(status.getDescription(), Status.Code.DEADLINE_EXCEEDED, status.getCode());
    }

    @Test
    public void successAssign() {
        long aliceSeqNo = lock.setOwner("alice");
        KvTabletIdAndGen tabletOne = createTablet();
        {
            Status status = syncAssign(makeAssignment(aliceSeqNo, 1, tabletOne));
            assertEquals(Status.Code.OK, status.getCode());
            assertShardStarted(1, tabletOne);
        }

        // A new leader must be able to assign shards too, without dropping earlier ones.
        long bobSeqNo = lock.setOwner("bob");
        KvTabletIdAndGen tabletTwo = createTablet();
        {
            Status status = syncAssign(makeAssignment(bobSeqNo, 2, tabletTwo));
            assertEquals(Status.Code.OK, status.getCode());
            assertShardStarted(2, tabletTwo);
        }

        assertNotNull(shards.getShardById(1));
        assertNotNull(shards.getShardById(2));
    }

    @Test(timeout = 10_000)
    public void restoreAssignmentsOnRestart() throws InterruptedException {
        long aliceSeqNo = lock.setOwner("alice");
        KvTabletIdAndGen tabletOne = createTablet();
        {
            Status status = syncAssign(makeAssignment(aliceSeqNo, 1, tabletOne));
            assertEquals(Status.Code.OK, status.getCode());
            assertNotNull(shards.getShardById(1));
        }

        awaitAssignmentsStateFile();

        restartService();
        {
            StockpileShard shard = shards.getShardById(1);
            assertNotNull(shard);
            assertEquals(1, shard.shardId);
            assertEquals(tabletOne.getTabletId(), shard.kvTabletId);
            assertEquals(tabletOne.getGen(), shard.getGeneration());
        }
    }

    @Test(timeout = 10_000)
    public void ignoreCorruptedCacheOnRestart() throws InterruptedException, IOException {
        long aliceSeqNo = lock.setOwner("alice");
        {
            Status status = syncAssign(makeAssignment(aliceSeqNo, 1, createTablet()));
            assertEquals(Status.Code.OK, status.getCode());
            assertNotNull(shards.getShardById(1));
        }

        Path targetFile = awaitAssignmentsStateFile();
        Files.write(
            targetFile,
            "corrupted file for test".getBytes(StandardCharsets.UTF_8),
            StandardOpenOption.TRUNCATE_EXISTING);

        restartService();
        assertNull(shards.getShardById(1));
    }

    @Test
    public void rejectUnassignBySeqNoMismatch() {
        long aliceSeqNo = lock.setOwner("alice");

        KvTabletIdAndGen tablet = createTablet();
        Status assign = syncAssign(makeAssignment(aliceSeqNo, 1, tablet));
        assertEquals(Status.Code.OK, assign.getCode());

        clock.passedTime(1, TimeUnit.MINUTES);
        // Leadership moves to bob; an unassign carrying alice's seqNo must be rejected.
        lock.setOwner("bob");

        Status unassign = syncUnassign(aliceSeqNo, 1, false);
        assertEquals(unassign.getDescription(), Status.Code.ABORTED, unassign.getCode());

        StockpileShard shard = shards.getShardById(1);
        assertNotNull(shard);
        shard.waitForInitializedOrAnyError();
        assertFalse(shard.isStop());
    }

    @Test
    public void forceUnassign() {
        long aliceSeqNo = lock.setOwner("alice");

        KvTabletIdAndGen tablet = createTablet();
        Status assign = syncAssign(makeAssignment(aliceSeqNo, 1, tablet));
        assertEquals(Status.Code.OK, assign.getCode());

        StockpileShard shard = shards.getShardById(1);
        assertNotNull(shard);
        shard.waitForInitializedOrAnyError();

        Status unassign = syncUnassign(aliceSeqNo, 1, false);
        assertEquals(Status.Code.OK, unassign.getCode());
        assertTrue(shard.isStop());
        assertNull(shards.getShardById(1));
    }

    @Test
    public void gracefulUnassign() {
        long aliceSeqNo = lock.setOwner("alice");

        KvTabletIdAndGen tablet = createTablet();
        Status assign = syncAssign(makeAssignment(aliceSeqNo, 1, tablet));
        assertEquals(Status.Code.OK, assign.getCode());

        StockpileShard shard = shards.getShardById(1);
        assertNotNull(shard);
        shard.waitForInitializedOrAnyError();

        Status unassign = syncUnassign(aliceSeqNo, 1, true);
        assertEquals(Status.Code.OK, unassign.getCode());
        assertTrue(shard.isStop());
        assertNull(shards.getShardById(1));
    }

    @Test
    public void unassignNotExistShard() {
        long aliceSeqNo = lock.setOwner("alice");

        // Unassigning an unknown shard is expected to succeed as a no-op.
        Status unassign = syncUnassign(aliceSeqNo, 42, true);
        assertEquals(Status.Code.OK, unassign.getCode());
        assertNull(shards.getShardById(42));
    }

    @Test
    public void shutdown() {
        long aliceSeqNo = lock.setOwner("alice");

        KvTabletIdAndGen tablet = createTablet();
        Status assign = syncAssign(makeAssignment(aliceSeqNo, 1, tablet));
        assertEquals(Status.Code.OK, assign.getCode());

        StockpileShard shard = shards.getShardById(1);
        assertNotNull(shard);
        shard.waitForInitializedOrAnyError();

        shards.gracefulShutdown().join();
        assertTrue(shard.isStop());
        assertNull(shards.getShardById(1));
    }

    @Test
    public void successPing() {
        assertEquals(0, shards.totalShardsCount());
        long aliceSeqNo = lock.setOwner("alice");
        {
            var response = state.ping(makePing(aliceSeqNo, 42)).join();

            // The shard count announced by the leader is applied locally.
            assertEquals(42, shards.totalShardsCount());
            assertEquals(0, response.getShardSummaryCount());
        }

        {
            Status assign = syncAssign(makeAssignment(aliceSeqNo, 1, createTablet()));
            assertEquals(Status.Code.OK, assign.getCode());

            var response = state.ping(makePing(aliceSeqNo, 42)).join();

            assertEquals(1, response.getShardSummaryCount());
            var summary = response.getShardSummary(0);
            assertEquals(1, summary.getShardId());
        }
    }

    /**
     * Simulates a process restart: builds fresh {@link StockpileLocalShards} and
     * {@link StockpileLocalShardsState} over the same {@link #storage}, lock and clock.
     */
    private void restartService() {
        shards = new StockpileLocalShards();
        state = new StockpileLocalShardsState(new SimpleFileStorage(storage), shards, shardGlobals, lock, clock);
    }

    /**
     * Polls until {@link StockpileLocalShardsState#ASSIGNMENTS_STATE_FILE} is readable in
     * {@link #storage}. The file may appear some time after an assignment completes, hence the
     * polling. The wait is bounded only by the test timeout.
     *
     * @return path of the assignments state file
     * @throws InterruptedException if interrupted while waiting
     */
    private Path awaitAssignmentsStateFile() throws InterruptedException {
        Path file = storage.resolve(StockpileLocalShardsState.ASSIGNMENTS_STATE_FILE);
        while (!Files.isReadable(file)) {
            TimeUnit.MILLISECONDS.sleep(1L);
        }
        return file;
    }

    /**
     * Checks that {@code shardId} is registered locally, is bound to {@code tablet},
     * finishes initialization without being stopped, and that the in-memory KV reports the
     * tablet's generation.
     */
    private void assertShardStarted(int shardId, KvTabletIdAndGen tablet) {
        StockpileShard shard = shards.getShardById(shardId);
        assertNotNull(shard);
        assertEquals(tablet.getTabletId(), shard.kvTabletId);

        shard.waitForInitializedOrAnyError();
        assertFalse(shard.isStop());
        assertEquals(tablet.getGen(), kikimrKvClientInMem.getGeneration(tablet.getTabletId()));
    }

    /**
     * Runs an assign request to completion.
     *
     * @return {@link Status#OK} on success, otherwise the status derived from the failure
     *         via {@link Status#fromThrowable}
     */
    private Status syncAssign(TAssignShardRequest request) {
        return state.assignShard(request)
            .thenApply(ignore -> Status.OK)
            .exceptionally(Status::fromThrowable)
            .join();
    }

    /**
     * Runs a ping request to completion, discarding the response.
     *
     * @return {@link Status#OK} on success, otherwise the status derived from the failure
     *         via {@link Status#fromThrowable}
     */
    private Status syncPing(TPingRequest request) {
        return state.ping(request)
            .thenApply(ignore -> Status.OK)
            .exceptionally(Status::fromThrowable)
            .join();
    }

    /**
     * Runs an unassign request for {@code shardId} to completion.
     *
     * @return {@link Status#OK} on success, otherwise the status derived from the failure
     *         via {@link Status#fromThrowable}
     */
    private Status syncUnassign(long leaderSeqNo, int shardId, boolean graceful) {
        return state.unassignShard(TUnassignShardRequest.newBuilder()
            .setShardId(shardId)
            .setLeaderSeqNo(leaderSeqNo)
            .setGraceful(graceful)
            .build())
            .thenApply(ignore -> Status.OK)
            .exceptionally(Status::fromThrowable)
            .join();
    }

    /**
     * Creates a new in-memory KV tablet, then obtains a generation for it via
     * {@code incrementGeneration(tabletId, 0)}.
     */
    private KvTabletIdAndGen createTablet() {
        long tabletId = kikimrKvClientInMem.createKvTablet();
        long tabletGen = kikimrKvClientInMem.incrementGeneration(tabletId, 0).join();
        return new KvTabletIdAndGen(tabletId, tabletGen);
    }

    /** Builds a ping request from the given leader announcing {@code shardCount} shards. */
    private static TPingRequest makePing(long leaderSeqNo, int shardCount) {
        return TPingRequest.newBuilder()
            .setLeaderSeqNo(leaderSeqNo)
            .setShardCount(shardCount)
            .build();
    }

    /** Builds an assign request binding {@code shardId} to {@code tablet} on behalf of the given leader. */
    private TAssignShardRequest makeAssignment(long leaderSeqNo, int shardId, KvTabletIdAndGen tablet) {
        return TAssignShardRequest.newBuilder()
            .setLeaderSeqNo(leaderSeqNo)
            .setAssignment(TShardAssignment.newBuilder()
                .setShardId(shardId)
                .setTabletId(tablet.getTabletId())
                .setTabletGeneration(tablet.getGen())
                .build())
            .build();
    }
}
