package ru.yandex.persqueue.read.impl.actor;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.google.protobuf.ByteString;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.persqueue.YdbPersqueueV1;
import com.yandex.ydb.persqueue.YdbPersqueueV1.CommitCookie;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Commit;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.InitRequest;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Read;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Released;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.StartRead;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Assigned;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Committed;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch.Batch;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch.MessageData;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch.PartitionData;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.InitResponse;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.PartitionStatus;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Release;
import org.hamcrest.Matchers;

import ru.yandex.persqueue.read.EventSubscriber;
import ru.yandex.persqueue.read.event.PartitionStreamDestroyEvent;
import ru.yandex.persqueue.read.impl.protocol.ClientRequestProto;
import ru.yandex.persqueue.read.impl.protocol.ServerResponseProto;
import ru.yandex.persqueue.rpc.OutReadSessionObserverStub;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;

/**
 * Scripted server side of a PersQueue migration streaming read session, for tests.
 *
 * <p>Server responses are pushed to the client under test via {@code observer.observer.onNext};
 * client requests are taken from the {@link OutReadSessionObserverStub} and unwrapped with
 * {@code ClientRequestProto.getRequest}. Every {@link Assigned} sent by {@link #sendAssign} is
 * remembered by its assign id, so later expectations can check cluster/topic/partition.
 *
 * <p>Client {@link Read} requests that arrive while another request is awaited are counted in
 * {@code expectedRead} instead of failing the test. Each {@link #sendData} call decrements that
 * counter, and {@link #expectNoReads()} asserts it is back to zero. This class performs no
 * synchronization of its own.
 *
 * @author Vladimir Gordiychuk
 */
public class ServerSession {
    private final OutReadSessionObserverStub observer;
    /** Session name; the client is expected to use {@code name + "/consumer"} as its consumer. */
    private final String name;
    /** Cluster put into every {@link Assigned} message this session sends. */
    private final String cluster;
    /** Client-side subscriber used to observe and confirm partition stream events. */
    private final EventSubscriber subscriber;
    /** Session id announced by the most recent {@link #sendInit()}. */
    private int sessionId;
    /** Next assign id to hand out; starts from a random value. */
    private int assignId;
    /** Number of client Read requests received minus data batches sent. */
    private int expectedRead;
    /** Assignments sent to the client, keyed by assign id. */
    private final Map<Long, Assigned> assignmentById = new HashMap<>();

    /**
     * @param observer   stub connecting this fake server to the client; must not be null
     * @param name       session name, used to derive the expected consumer name
     * @param cluster    cluster reported in assignments
     * @param subscriber client-side event subscriber used by the {@code expect*} helpers
     */
    public ServerSession(OutReadSessionObserverStub observer, String name, String cluster, EventSubscriber subscriber) {
        assertNotNull("On start should be initialized session on server side", observer);
        this.name = name;
        this.cluster = cluster;
        this.subscriber = subscriber;
        this.observer = observer;
        this.assignId = ThreadLocalRandom.current().nextInt();
    }

    /**
     * Takes the client's {@link InitRequest} and checks its consumer name and the two topics
     * {@code /topic/alice} and {@code /topic/bob} (in that order).
     */
    public void expectInit() {
        InitRequest init = observer.takeRequest(InitRequest.class);
        assertEquals(name + "/consumer", init.getConsumer());
        assertEquals(2, init.getTopicsReadSettingsCount());
        assertEquals("/topic/alice", init.getTopicsReadSettings(0).getTopic());
        assertEquals("/topic/bob", init.getTopicsReadSettings(1).getTopic());
    }

    /** Sends an {@link InitResponse} carrying a freshly generated random session id. */
    public void sendInit() {
        sessionId = ThreadLocalRandom.current().nextInt();
        observer.observer.onNext(ServerResponseProto.init(InitResponse.newBuilder()
                .setSessionId(Integer.toString(sessionId))
                .build()));
    }

    /**
     * Runs a full assignment round trip. It sends an assignment, waits for the client-side assign
     * event, confirms it, and expects the resulting {@link StartRead}.
     *
     * @return the assign id that was sent
     */
    public long expectAssign(String topic, long partition) {
        var id = sendAssign(topic, partition);
        var assign = subscriber.expectAssign(id, cluster, topic, partition);
        assign.confirm();
        expectConfirmAssign(id);
        return id;
    }

    /**
     * Runs a full graceful release round trip with committed offset 42. It sends the release,
     * checks the client-side {@link PartitionStreamDestroyEvent}, confirms it, and expects the
     * resulting {@link Released}.
     */
    public void expectRelease(long assignId) {
        sendRelease(assignId, 42, false);
        var event = subscriber.takeEvent(PartitionStreamDestroyEvent.class);
        var partitionStream = event.getPartitionStream();
        assertEquals(assignId, partitionStream.getAssignId());
        assertEquals(42, event.getCommittedOffset());
        event.confirm();
        expectConfirmRelease(assignId);
    }

    /**
     * Ensures at least one client {@link Read} request is outstanding. If none has been counted
     * yet, the next client message is taken and must be a {@code Read}.
     */
    public void expectRead() {
        if (expectedRead == 0) {
            var req = ClientRequestProto.getRequest(observer.takeMessage());
            // Verify before counting so the counter only ever reflects real Read requests.
            assertThat(req, Matchers.instanceOf(Read.class));
            expectedRead++;
        }

        assertNotEquals(0, expectedRead);
    }

    /**
     * Drains pending client messages, counting any {@link Read} requests, then asserts that no
     * Read request is left unanswered.
     */
    public void expectNoReads() {
        while (true) {
            var event = observer.pullEvent();
            if (event == null) {
                break;
            }

            assertNotNull(event.toString(), event.message);
            // NOTE(review): both pullEvent() and takeMessage() are called per iteration; if
            // pullEvent() removes the event from the stub's queue, takeMessage() reads the next
            // message rather than the pulled one. Confirm against OutReadSessionObserverStub.
            var req = ClientRequestProto.getRequest(observer.takeMessage());
            if (req instanceof Read) {
                expectedRead++;
            }
        }

        assertEquals(0, expectedRead);
    }

    /** Takes the next non-Read client request, requires a {@link Commit}, and checks its first cookie. */
    public void expectCommit(long assignId, long cookie) {
        Commit commit = takeRequest(Commit.class);
        // JUnit's assertEquals takes (message, expected, actual).
        assertEquals(commit.toString(), assignId, commit.getCookies(0).getAssignId());
        assertEquals(commit.toString(), cookie, commit.getCookies(0).getPartitionCookie());
    }

    /** Sends a {@link Committed} response containing a single cookie. */
    public void sendCommitted(long assignId, long cookie) {
        observer.observer.onNext(ServerResponseProto.commit(Committed.newBuilder()
                .addCookies(CommitCookie.newBuilder()
                        .setAssignId(assignId)
                        .setPartitionCookie(cookie)
                        .build())
                .build()));
    }

    /** Delegates to {@code observer.expectComplete()}. */
    public void expectComplete() {
        observer.expectComplete();
    }

    /**
     * Takes the next non-Read client request, which must be a {@link StartRead} matching the
     * assignment with {@code assignId} and starting at read/commit offset 0.
     */
    public void expectConfirmAssign(long assignId) {
        StartRead confirm = takeRequest(StartRead.class);

        var assign = requireAssignment(assignId);
        assertEquals(confirm.toString(), assign.getAssignId(), confirm.getAssignId());
        assertEquals(confirm.toString(), assign.getCluster(), confirm.getCluster());
        assertEquals(confirm.toString(), assign.getTopic(), confirm.getTopic());
        assertEquals(confirm.toString(), assign.getPartition(), confirm.getPartition());
        assertEquals(confirm.toString(), 0, confirm.getReadOffset());
        assertEquals(confirm.toString(), 0, confirm.getCommitOffset());
    }

    /**
     * Takes the next non-Read client request, which must be a partition status request matching
     * the assignment with {@code assignId}.
     */
    public void expectPartitionStatus(long assignId) {
        var status = takeRequest(YdbPersqueueV1.MigrationStreamingReadClientMessage.Status.class);
        var assign = requireAssignment(assignId);
        assertEquals(status.toString(), assign.getAssignId(), status.getAssignId());
        assertEquals(status.toString(), assign.getCluster(), status.getCluster());
        assertEquals(status.toString(), assign.getTopic(), status.getTopic());
        assertEquals(status.toString(), assign.getPartition(), status.getPartition());
    }

    /** Sends a failure response with {@code error} and then completes the server stream. */
    public void sendError(Status error) {
        observer.observer.onNext(ServerResponseProto.failed(error));
        observer.observer.onCompleted();
    }

    /** Completes the server stream. */
    public void sendComplete() {
        observer.observer.onCompleted();
    }

    /**
     * Sends an {@link Assigned} message for {@code topic}/{@code partition} with end offset 765
     * and remembers it under a new assign id.
     *
     * @return the assign id that was sent
     */
    public long sendAssign(String topic, long partition) {
        var assign = Assigned.newBuilder()
                .setCluster(cluster)
                .setAssignId(assignId++)
                .setPartition(partition)
                .setTopic(YdbPersqueueV1.Path.newBuilder()
                        .setPath(topic)
                        .build())
                .setEndOffset(765)
                .build();
        assertNull(assignmentById.put(assign.getAssignId(), assign));
        observer.observer.onNext(ServerResponseProto.assign(assign));
        return assign.getAssignId();
    }

    /** Sends a {@link Release} for a previously sent assignment. */
    public void sendRelease(long assignId, long committedOffset, boolean force) {
        var assign = requireAssignment(assignId);

        var release = Release.newBuilder()
                .setAssignId(assignId)
                .setCluster(assign.getCluster())
                .setTopic(assign.getTopic())
                .setPartition(assign.getPartition())
                .setForcefulRelease(force)
                .setCommitOffset(committedOffset)
                .build();

        observer.observer.onNext(ServerResponseProto.release(release));
    }

    /**
     * Sends one {@link DataBatch} for a previously sent assignment. The batch holds one message
     * per element of {@code data}, at consecutive offsets from {@code startOffset}. This answers
     * one outstanding Read request by decrementing the counter.
     *
     * @return the random partition cookie attached to the batch
     */
    public long sendData(long assignId, long startOffset, Map<String, String> meta, YdbPersqueueV1.Codec codec, ByteString... data) {
        var assign = requireAssignment(assignId);

        var cookie = ThreadLocalRandom.current().nextLong();
        var dataBatch = DataBatch.newBuilder()
                .addPartitionData(PartitionData.newBuilder()
                        .setTopic(assign.getTopic())
                        .setCluster(assign.getCluster())
                        .setPartition(assign.getPartition())
                        .setCookie(CommitCookie.newBuilder()
                                .setAssignId(assignId)
                                .setPartitionCookie(cookie)
                                .build())
                        .addBatches(Batch.newBuilder()
                                .setIp("localhost")
                                .setWriteTimestampMs(System.currentTimeMillis())
                                .addAllExtraFields(meta.entrySet()
                                        .stream()
                                        .map(entry -> YdbPersqueueV1.KeyValue.newBuilder()
                                                .setKey(entry.getKey())
                                                .setValue(entry.getValue())
                                                .build())
                                        .collect(Collectors.toList()))
                                .addAllMessageData(IntStream.range(0, data.length)
                                        .mapToObj(offset -> MessageData.newBuilder()
                                                .setCodec(codec)
                                                .setOffset(startOffset + offset)
                                                .setSeqNo(offset)
                                                .setData(data[offset])
                                                .setCreateTimestampMs(System.currentTimeMillis())
                                                .build())
                                        .collect(Collectors.toList()))
                                .build())
                        .build())
                .build();

        expectedRead--;
        observer.observer.onNext(ServerResponseProto.dataBatch(dataBatch));
        return cookie;
    }

    /** Sends a {@link PartitionStatus} for a previously sent assignment. */
    public void sendPartitionStatus(long assignId, long committedOffset, long endOffset, long writeWatermark) {
        var assign = requireAssignment(assignId);

        var status = PartitionStatus.newBuilder()
                .setAssignId(assignId)
                .setCluster(assign.getCluster())
                .setTopic(assign.getTopic())
                .setPartition(assign.getPartition())
                .setWriteWatermarkMs(writeWatermark)
                .setCommittedOffset(committedOffset)
                .setEndOffset(endOffset)
                .build();

        observer.observer.onNext(ServerResponseProto.partitionStatus(status));
    }

    /**
     * Takes the next non-Read client request, which must be a {@link Released} matching the
     * assignment with {@code assignId}.
     */
    public void expectConfirmRelease(long assignId) {
        Released confirm = takeRequest(Released.class);

        var assign = requireAssignment(assignId);
        assertEquals(confirm.toString(), assign.getAssignId(), confirm.getAssignId());
        assertEquals(confirm.toString(), assign.getCluster(), confirm.getCluster());
        assertEquals(confirm.toString(), assign.getTopic(), confirm.getTopic());
        assertEquals(confirm.toString(), assign.getPartition(), confirm.getPartition());
    }

    /**
     * Takes client messages until one is not a {@link Read}, counting skipped Reads, and asserts
     * that the request is an instance of {@code clazz}.
     *
     * @return the request cast to {@code clazz}
     */
    public <T> T takeRequest(Class<T> clazz) {
        while (true) {
            var req = ClientRequestProto.getRequest(observer.takeMessage());
            if (req instanceof Read) {
                expectedRead++;
                continue;
            }

            assertThat(req, instanceOf(clazz));
            return clazz.cast(req);
        }
    }

    /**
     * Returns the assignment previously sent under {@code assignId}. An unknown id fails the test
     * with a descriptive assertion instead of a NullPointerException.
     */
    private Assigned requireAssignment(long assignId) {
        var assign = assignmentById.get(assignId);
        assertNotNull("Unknown assignId " + assignId, assign);
        return assign;
    }
}
