package ru.yandex.persqueue.read.impl.actor;

import java.util.concurrent.TimeUnit;

import com.yandex.ydb.core.Status;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.persqueue.YdbPersqueueV1;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Commit;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.InitRequest;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Read;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Released;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.StartRead;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Assigned;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Committed;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.InitResponse;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.PartitionStatus;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Release;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;

import ru.yandex.persqueue.read.impl.protocol.ClientRequestProto;
import ru.yandex.persqueue.read.impl.protocol.ServerResponseProto;
import ru.yandex.persqueue.read.impl.protocol.handler.EventHandlerStub;
import ru.yandex.persqueue.rpc.OutReadSessionObserverStub;
import ru.yandex.persqueue.rpc.PqRpcStub;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link ClientServerSession} running against in-memory RPC stubs.
 *
 * <p>Outbound traffic is verified on the server-side stub returned by
 * {@code PqRpcStub.getActiveReadObserver()}. Inbound traffic is injected through that stub's
 * {@code observer}, and the resulting callbacks are read from {@code EventHandlerStub.events}.
 *
 * @author Vladimir Gordiychuk
 */
public class ClientServerSessionTest {
    @Rule
    public final Timeout timeout = Timeout.builder()
            .withTimeout(5, TimeUnit.SECONDS)
            .build();
    @Rule
    public final TestName name = new TestName();
    private EventHandlerStub handler;
    private ClientServerSession session;
    private PqRpcStub rpc;
    private OutReadSessionObserverStub serverSession;

    /**
     * Creates a fresh session per test.
     *
     * <p>The session is named after the running test method. The test then captures the
     * server-side observer that the RPC stub exposes for it.
     */
    @Before
    public void setUp() {
        handler = new EventHandlerStub();
        rpc = new PqRpcStub("test");
        session = new ClientServerSession(name.getMethodName(), rpc, () -> {}, handler);
        serverSession = rpc.getActiveReadObserver();
        assertNotNull(serverSession);
    }

    @Test
    public void sendInit() {
        testMessageSend(ClientRequestProto.init(InitRequest.newBuilder()
                .setConsumer("test/test")
                .build(), ""));
    }

    @Test
    public void sendRead() {
        testMessageSend(ClientRequestProto.read(Read.newBuilder().build(), ""));
    }

    @Test
    public void sendStartRead() {
        testMessageSend(ClientRequestProto.startRead(StartRead.newBuilder()
                .setCluster("test")
                .setAssignId(1)
                .build(), ""));
    }

    @Test
    public void sendCommit() {
        testMessageSend(ClientRequestProto.commit(Commit.newBuilder()
                .addCookies(YdbPersqueueV1.CommitCookie.newBuilder()
                        .setAssignId(123)
                        .build())
                .build(), ""));
    }

    @Test
    public void sendRelease() {
        testMessageSend(ClientRequestProto.release(Released.newBuilder()
                .setAssignId(42)
                .setCluster("myCluster")
                .build(), ""));
    }

    @Test
    public void sendStatus() {
        testMessageSend(ClientRequestProto.partitionStatus(MigrationStreamingReadClientMessage.Status.newBuilder()
                .setAssignId(22)
                .setCluster("1223")
                .build(), ""));
    }

    /** Closing the client session must complete the stream as seen by the server stub. */
    @Test
    public void closeSession() {
        session.close();
        serverSession.expectComplete();
    }

    @Test
    public void receiveInit() throws InterruptedException {
        testMessageReceive(ServerResponseProto.init(InitResponse.newBuilder()
                .setSessionId("mySession")
                .build()));
    }

    @Test
    public void receiveDataBatch() throws InterruptedException {
        testMessageReceive(ServerResponseProto.dataBatch(DataBatch.newBuilder()
                .addPartitionData(DataBatch.PartitionData.newBuilder()
                        .addBatches(DataBatch.Batch.newBuilder()
                                .setIp("hi")
                                .build())
                        .build())
                .build()));
    }

    @Test
    public void receiveAssign() throws InterruptedException {
        testMessageReceive(ServerResponseProto.assign(Assigned.newBuilder()
                .setAssignId(123)
                .setCluster("my")
                .build()));
    }

    @Test
    public void receiveRelease() throws InterruptedException {
        testMessageReceive(ServerResponseProto.release(Release.newBuilder()
                .setAssignId(123)
                .setCluster("my")
                .build()));
    }

    @Test
    public void receiveCommit() throws InterruptedException {
        testMessageReceive(ServerResponseProto.commit(Committed.newBuilder()
                .addCookies(YdbPersqueueV1.CommitCookie.newBuilder()
                        .setAssignId(123)
                        .build())
                .build()));
    }

    @Test
    public void receivePartitionStatus() throws InterruptedException {
        testMessageReceive(ServerResponseProto.partitionStatus(PartitionStatus.newBuilder()
                .setCluster("test")
                .build()));
    }

    /** A server message carrying a failure status must surface as an error event. */
    @Test
    public void receiveError() throws InterruptedException {
        testMessageReceive(ServerResponseProto.failed(Status.of(StatusCode.UNAUTHORIZED)));
    }

    /** Server-side stream completion must surface as a {@code complete} event. */
    @Test
    public void receiveComplete() throws InterruptedException {
        replyComplete();

        var event = handler.events.take();
        assertTrue(event.toString(), event.complete);
    }

    /** A transport-level {@code onError} must surface with the same status. */
    @Test
    public void receiveTransportError() throws InterruptedException {
        var expectStatus = Status.of(StatusCode.TRANSPORT_UNAVAILABLE);
        replyError(expectStatus);

        var event = handler.events.take();
        assertEquals(event.toString(), expectStatus, event.error);
    }

    /**
     * A failure response from the server is reported to the handler.
     *
     * <p>Afterwards, nothing further must be emitted towards the server.
     */
    @Test
    public void initErrorByServerSide() throws InterruptedException {
        var error = Status.of(StatusCode.UNAUTHORIZED);
        reply(ServerResponseProto.failed(error));

        var handleEvent = handler.events.take();
        assertEquals(handleEvent.toString(), error, handleEvent.error);
        assertNull(serverSession.pullEvent());
    }

    /**
     * Aborting on the client side must be observed by the server stub as an error.
     *
     * <p>The test then completes the stream from the server side.
     */
    @Test
    public void initErrorByClientSide() {
        var error = Status.of(StatusCode.INTERNAL_ERROR);
        session.abort(error);

        serverSession.expectError();
        replyComplete();
    }

    /**
     * Sends {@code req} through the session and checks what the server stub received.
     *
     * <p>The stub must have received exactly that message and no further event.
     *
     * @param req client message to send
     */
    private void testMessageSend(MigrationStreamingReadClientMessage req) {
        session.send(req);
        assertEquals(req, serverSession.takeMessage());
        assertNull(serverSession.pullEvent());
    }

    /**
     * Pushes a server message into the session's inbound observer.
     *
     * <p>Then calls {@code processInbound()} on the session.
     */
    private void reply(MigrationStreamingReadServerMessage req) {
        serverSession.observer.onNext(req);
        session.processInbound();
    }

    /** Signals {@code onError(status)} on the inbound observer, then calls {@code processInbound()}. */
    private void replyError(Status status) {
        serverSession.observer.onError(status);
        session.processInbound();
    }

    /** Signals {@code onCompleted()} on the inbound observer, then calls {@code processInbound()}. */
    private void replyComplete() {
        serverSession.observer.onCompleted();
        session.processInbound();
    }

    /**
     * Injects {@code req} and checks the event delivered to the handler.
     *
     * <p>For a successful response, the event's message must equal the response payload. For a
     * failed one, the event's error must equal the response status.
     *
     * @param req server message to inject
     * @throws InterruptedException if interrupted while waiting for the handler event
     */
    private void testMessageReceive(MigrationStreamingReadServerMessage req) throws InterruptedException {
        reply(req);

        var event = handler.events.take();
        // Expected value (derived from the injected request) comes first and the observed event
        // second, per JUnit's assertEquals contract, so failure output reads correctly.
        if (ServerResponseProto.isSuccess(req)) {
            assertNotNull(event.toString(), event.message);
            assertEquals(event.toString(), ServerResponseProto.getResponse(req), event.message);
        } else {
            assertEquals(event.toString(), ServerResponseProto.getStatus(req), event.error);
        }
    }
}
