package ru.yandex.persqueue.read.impl.actor;

import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import com.yandex.ydb.core.Issue;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.discovery.DiscoveryProtos.EndpointInfo;
import com.yandex.ydb.persqueue.YdbPersqueueV1;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Commit;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.InitRequest;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Read;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.Released;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.StartRead;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadClientMessage.TopicReadSettings;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Assigned;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Committed;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.DataBatch;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.InitResponse;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.PartitionStatus;
import com.yandex.ydb.persqueue.YdbPersqueueV1.MigrationStreamingReadServerMessage.Release;
import com.yandex.ydb.persqueue.YdbPersqueueV1.Path;
import com.yandex.ydb.persqueue.YdbPersqueueV1.ReadParams;

import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.persqueue.read.PartitionStreamKey;
import ru.yandex.persqueue.read.impl.PartitionStreamImpl;
import ru.yandex.persqueue.read.impl.PartitionStreamMap;
import ru.yandex.persqueue.read.impl.actor.ActorEvents.Connect;
import ru.yandex.persqueue.read.impl.actor.ActorEvents.OnNodesDiscovered;
import ru.yandex.persqueue.read.impl.event.CommitAcknowledgementEventImpl;
import ru.yandex.persqueue.read.impl.event.PartitionStreamClosedEventImpl;
import ru.yandex.persqueue.read.impl.event.PartitionStreamCreateEventImpl;
import ru.yandex.persqueue.read.impl.event.PartitionStreamDestroyEventImpl;
import ru.yandex.persqueue.read.impl.event.PartitionStreamStatusEventImpl;
import ru.yandex.persqueue.read.impl.protocol.ClientRequestProto;
import ru.yandex.persqueue.read.impl.protocol.ServerResponseProtoConverter;
import ru.yandex.persqueue.read.impl.protocol.handler.TransportEventHandler;
import ru.yandex.persqueue.read.settings.ReadSessionSettings;
import ru.yandex.persqueue.rpc.PqRpc;
import ru.yandex.persqueue.rpc.RpcPool;

/**
 * @author Vladimir Gordiychuk
 */
public class ReadSessionActorImpl implements ReadSessionActor, ActorEventHandler, TransportEventHandler, AutoCloseable {
    /**
     * Maximum read size announced to the server in the init request; also used as the
     * per-read size when estimating the memory occupied by in-flight reads.
     */
    public static final int MAX_BATCH_SIZE = 1 << 20; // 1 MiB
    /** Threshold that {@code readInFlight} is compared against before another Read request is sent. */
    private static final int MAX_READ_INFLIGHT = 6;

    /** Cluster name; passed to the publisher when querying its current memory use. */
    private final String cluster;
    /** Endpoint used to obtain the cluster-level rpc and to label the client-server session. */
    private final String endpoint;
    /** Source of {@link PqRpc} instances for both the cluster endpoint and the selected node. */
    private final RpcPool rpcPool;
    /** Read session configuration (consumer, topics, executor, retry and memory limits). */
    private final ReadSessionSettings settings;
    /** Runs {@code act()} on the executor taken from the settings. */
    private final ActorRunner actor;
    /** Queue of pending actor events; enqueueing schedules the actor. */
    private final EventQueue<ActorEvent> actorEvents;
    /** Retry policy holder; its callback re-sends a {@code Connect} event. */
    private final ReadSessionRetryContext retryContext;
    /** Sink for messages and partition stream events delivered to the user. */
    private final EventPublisher publisher;
    /** Currently assigned partition streams of the active session. */
    private final PartitionStreamMap partitionStreamMap = new PartitionStreamMap();

    /** Rpc to the cluster endpoint, used for node discovery; {@code null} when not held. */
    private PqRpc clusterRpc;
    /** Rpc to the node picked after discovery; {@code null} when not held. */
    private PqRpc nodeRpc;
    /** Active client-server session; {@code null} while disconnected. */
    private ClientServerSession session;

    /** Number of Read requests sent for which no data batch has been received yet. */
    private int readInFlight = 0;
    /** Last token handed out by {@code nextToken()}; empty string when none was sent. */
    private String sendToken = "";

    /**
     * Creates a read session actor. No connection is opened here; the session is
     * established when a {@code Connect} event is processed.
     *
     * @param cluster   cluster name used when querying publisher memory use
     * @param endpoint  endpoint used for node discovery
     * @param settings  read session settings (executor, retry policy, topics, limits)
     * @param publisher sink for messages and partition stream events
     * @param rpcPool   pool that provides rpc instances by address
     */
    public ReadSessionActorImpl(String cluster, String endpoint, ReadSessionSettings settings, EventPublisher publisher, RpcPool rpcPool) {
        // Plain collaborators first.
        this.cluster = cluster;
        this.endpoint = endpoint;
        this.settings = settings;
        this.publisher = publisher;
        this.rpcPool = rpcPool;

        // The event queue wakes the actor up, so the actor has to exist before the queue.
        this.actor = new ActorRunner(this::act, settings.executor);
        this.actorEvents = new EventQueueImpl<>(actor::schedule);

        // A scheduled retry simply asks the actor to connect again.
        this.retryContext = new ReadSessionRetryContext(settings.retry, () -> send(new Connect()));
    }

    /**
     * Single actor step: inbound transport events are drained first,
     * then the queued actor events.
     */
    private void act() {
        processTransportEvents();
        processActorEvents();
    }

    /** Lets the active session handle its inbound messages; does nothing while disconnected. */
    private void processTransportEvents() {
        if (session == null) {
            return;
        }
        session.processInbound();
    }

    /** Drains the actor event queue, dispatching every event to this handler in queue order. */
    private void processActorEvents() {
        for (ActorEvent next = actorEvents.dequeue(); next != null; next = actorEvents.dequeue()) {
            next.dispatch(this);
        }
    }

    /**
     * Starts node discovery unless a session already exists. The discovery outcome
     * (success or failure) is fed back into the actor as an {@link OnNodesDiscovered} event.
     */
    @Override
    public void onConnect(ActorEvents.Connect event) {
        if (session != null) {
            // A session is already established, so there is nothing to reconnect.
            return;
        }

        if (clusterRpc == null) {
            clusterRpc = rpcPool.getRpc(endpoint);
        }

        clusterRpc.discoverNodes().whenComplete((response, failure) -> {
            if (failure != null) {
                // Fold an exceptional completion into an error result.
                response = Result.error(failure);
            }

            send(new OnNodesDiscovered(response.map(endpoints -> endpoints.getEndpointsList()
                    .stream()
                    .map(EndpointInfo::getAddress)
                    .collect(Collectors.toList()))));
        });
    }

    /**
     * Handles the discovery outcome: on failure or an empty node list reports an error,
     * otherwise picks a random node, opens a session to it and sends the init request.
     */
    @Override
    public void onNodeDiscovered(OnNodesDiscovered event) {
        if (session != null) {
            return;
        }

        var discovery = event.result;
        if (!discovery.isSuccess()) {
            onError(discovery.toStatus());
            return;
        }

        var addresses = discovery.expect("onNodeDiscovered()");
        if (addresses.isEmpty()) {
            var issue = Issue.of("Unable resolve host to read for endpoint " + endpoint, Issue.Severity.ERROR);
            onError(Status.of(StatusCode.UNAVAILABLE, issue));
            return;
        }

        // Pick one of the discovered nodes uniformly at random.
        var chosen = addresses.get(ThreadLocalRandom.current().nextInt(addresses.size()));
        nodeRpc = rpcPool.getRpc(chosen);
        session = new ClientServerSession(endpoint, nodeRpc, actor::schedule, this);

        var initRequest = prepareInitReq();
        session.send(ClientRequestProto.init(initRequest, nextToken()));
    }

    /**
     * Returns the token to attach to the next request: the current token of the node rpc
     * when it differs from the last one sent, otherwise an empty string.
     */
    private String nextToken() {
        var current = nodeRpc.nextToken();
        if (!current.equals(sendToken)) {
            // Token changed since the last request: remember it and send it.
            sendToken = current;
            return current;
        }
        return "";
    }

    /**
     * Builds the session init request from the read settings: consumer, lag/start
     * constraints, read parameters and one read-settings entry per configured topic.
     */
    private InitRequest prepareInitReq() {
        var readParams = ReadParams.newBuilder()
                .setMaxReadSize(MAX_BATCH_SIZE) // Max 1 MiB
                .build();

        var builder = InitRequest.newBuilder()
                .setConsumer(settings.consumerName)
                .setReadOnlyOriginal(settings.readOriginal)
                .setMaxLagDurationMs(settings.maxTimeLag.toMillis())
                .setStartFromWrittenAtMs(settings.startFrom.toEpochMilli())
                .setReadParams(readParams);

        for (var topic : settings.topics) {
            var topicSettings = TopicReadSettings.newBuilder()
                    .setTopic(topic.path)
                    .setStartFromWrittenAtMs(topic.startMessageTimestamp.toEpochMilli())
                    .addAllPartitionGroupIds(topic.partitionGroupIds)
                    .build();
            builder.addTopicsReadSettings(topicSettings);
        }

        return builder.build();
    }

    /** Closes the current session (if any) and releases the cluster rpc when it is held. */
    @Override
    public void onDisconnect(ActorEvents.Disconnect event) {
        closeSession();
        if (clusterRpc == null) {
            return;
        }
        clusterRpc.close();
        clusterRpc = null;
    }

    /**
     * Confirms a partition assignment by sending StartRead with the requested offsets,
     * then tries to issue another read. Ignored when the partition stream is no longer tracked.
     */
    @Override
    public void onConfirmAssign(ActorEvents.ConfirmAssign event) {
        var stream = event.partitionStream;
        if (!partitionStreamMap.contains(stream)) {
            return;
        }

        var topic = Path.newBuilder().setPath(stream.getTopicPath()).build();
        var startRead = StartRead.newBuilder()
                .setTopic(topic)
                .setCluster(stream.getClusterName())
                .setPartition(stream.getPartition())
                .setAssignId(stream.getAssignId())
                .setReadOffset(event.readOffset)
                .setCommitOffset(event.commitOffset)
                .build();
        session.send(ClientRequestProto.startRead(startRead, nextToken()));
        continueRead();
    }

    /**
     * Removes the partition stream from the map and tells the server it has been released.
     * Nothing is sent when the stream was not present in the map.
     */
    @Override
    public void onConfirmDestroy(ActorEvents.ConfirmDestroy event) {
        var stream = event.partitionStream;
        boolean removed = partitionStreamMap.remove(stream);
        if (!removed) {
            return;
        }

        var topic = Path.newBuilder().setPath(stream.getTopicPath()).build();
        var released = Released.newBuilder()
                .setTopic(topic)
                .setCluster(stream.getClusterName())
                .setPartition(stream.getPartition())
                .setAssignId(stream.getAssignId())
                .build();
        session.send(ClientRequestProto.release(released, nextToken()));
    }

    /** Sends a partition status request for a tracked partition stream; ignored for unknown streams. */
    @Override
    public void onRequestPartitionStatus(ActorEvents.RequestPartitionStatus event) {
        var stream = event.partitionStream;
        if (!partitionStreamMap.contains(stream)) {
            return;
        }

        var topic = Path.newBuilder().setPath(stream.getTopicPath()).build();
        var statusRequest = YdbPersqueueV1.MigrationStreamingReadClientMessage.Status.newBuilder()
                .setTopic(topic)
                .setCluster(stream.getClusterName())
                .setPartition(stream.getPartition())
                .setAssignId(stream.getAssignId())
                .build();
        session.send(ClientRequestProto.partitionStatus(statusRequest, nextToken()));
    }

    /** Sends a commit for the event's cookie on a tracked partition stream; ignored for unknown streams. */
    @Override
    public void onCommit(ActorEvents.Commit event) {
        var stream = event.partitionStream;
        if (!partitionStreamMap.contains(stream)) {
            return;
        }

        var commitCookie = YdbPersqueueV1.CommitCookie.newBuilder()
                .setPartitionCookie(event.cookie.id)
                .setAssignId(stream.getAssignId())
                .build();
        var commit = Commit.newBuilder()
                .addCookies(commitCookie)
                .build();
        session.send(ClientRequestProto.commit(commit, nextToken()));
    }

    /** A memory chunk was consumed, so there may be room for another read. */
    @Override
    public void onMemoryChunkConsumed(ActorEvents.MemoryChunkConsumed event) {
        continueRead();
    }

    /**
     * Handles the server's init response: stores the session id, marks the retry
     * context as successful and releases the cluster rpc used for discovery.
     *
     * @param init init response received from the server
     * @return always {@code true}
     */
    @Override
    public boolean onInit(InitResponse init) {
        session.setSessionId(init.getSessionId());
        retryContext.success();
        // clusterRpc may already be null: onDisconnect() releases it, yet a discovery result
        // arriving afterwards can still open a session (onNodeDiscovered only checks session).
        if (clusterRpc != null) {
            clusterRpc.close();
            clusterRpc = null;
        }
        return true;
    }

    /**
     * Registers a newly assigned partition stream. If the map already held a stream that
     * this one replaces, a closed event is published for it before the create event.
     *
     * @return always {@code true}
     */
    @Override
    public boolean onAssign(Assigned assigned) {
        var created = new PartitionStreamImpl(this, assigned);
        var replaced = partitionStreamMap.add(created);
        if (replaced != null) {
            publisher.submit(new PartitionStreamClosedEventImpl(replaced));
        }
        publisher.submit(new PartitionStreamCreateEventImpl(created));
        return true;
    }

    /**
     * Handles a data batch: converts each partition's data to messages, records the cookie
     * with its offset range on the partition stream and publishes the messages. Data for an
     * assign id that is no longer tracked is skipped when the partition itself is still known;
     * otherwise the session is aborted.
     *
     * @return {@code false} when the session was aborted, {@code true} otherwise
     */
    @Override
    public boolean onDataBatch(DataBatch dataBatch) {
        for (var chunk : dataBatch.getPartitionDataList()) {
            var chunkCookie = chunk.getCookie();
            var stream = partitionStreamMap.get(chunkCookie.getAssignId());
            if (stream == null) {
                var key = new PartitionStreamKey(chunk.getTopic().getPath(), chunk.getCluster(), chunk.getPartition());
                if (partitionStreamMap.get(key) == null) {
                    var issue = Issue.of("Got unexpected partition stream data message. Key: " + key,
                            Issue.Severity.ERROR);
                    abortSession(Status.of(StatusCode.INTERNAL_ERROR, issue));
                    return false;
                }
                // Data belongs to a previous assignment of a partition we still hold. Ignore.
                continue;
            }

            var messages = ServerResponseProtoConverter.toMessages(stream, chunk);
            if (!messages.isEmpty()) {
                // Offsets covered by this cookie: [first message offset, last message offset + 1).
                long firstOffset = messages.get(0).getOffset();
                long pastLastOffset = messages.get(messages.size() - 1).getOffset() + 1;
                stream.addCookie(chunkCookie.getPartitionCookie(), firstOffset, pastLastOffset);
                for (var message : messages) {
                    publisher.submit(message);
                }
            }
        }

        // One outstanding read has been answered; try to issue the next one.
        readInFlight--;
        continueRead();
        return true;
    }

    /**
     * Handles a release request from the server. A forceful release closes and drops the
     * partition stream immediately; otherwise a destroy event is published for the stream.
     *
     * @return always {@code true}
     */
    @Override
    public boolean onRelease(Release release) {
        var stream = partitionStreamMap.get(release.getAssignId());
        if (stream == null) {
            return true;
        }

        if (!release.getForcefulRelease()) {
            publisher.submit(new PartitionStreamDestroyEventImpl(stream, release.getCommitOffset()));
            return true;
        }

        stream.close();
        partitionStreamMap.remove(stream);
        publisher.submit(new PartitionStreamClosedEventImpl(stream));
        return true;
    }

    /**
     * Handles commit acknowledgements: for every cookie that maps to a tracked partition
     * stream and a known commit, publishes an acknowledgement event.
     *
     * @return always {@code true}
     */
    @Override
    public boolean onCommit(Committed committed) {
        for (var ack : committed.getCookiesList()) {
            var stream = partitionStreamMap.get(ack.getAssignId());
            if (stream == null) {
                continue;
            }

            var detail = stream.committed(ack.getPartitionCookie());
            if (detail != null) {
                publisher.submit(new CommitAcknowledgementEventImpl(stream, detail.endOffset - 1));
            }
        }
        return true;
    }

    /**
     * Publishes a status event for the partition stream the status refers to;
     * statuses for unknown assign ids are dropped.
     *
     * @return always {@code true}
     */
    @Override
    public boolean onPartitionStatus(PartitionStatus status) {
        var stream = partitionStreamMap.get(status.getAssignId());
        if (stream != null) {
            publisher.submit(new PartitionStreamStatusEventImpl(stream, status));
        }
        return true;
    }

    /**
     * Closes the session and asks the retry context to schedule a retry for the status code;
     * when no retry is scheduled the publisher is closed exceptionally with the status.
     */
    @Override
    public void onError(Status status) {
        closeSession();
        boolean retryScheduled = retryContext.scheduleRetry(status.getCode());
        if (retryScheduled) {
            return;
        }
        publisher.closeExceptionally(status);
    }

    /** Transport completed: close the session first, then the publisher. */
    @Override
    public void onComplete() {
        closeSession();
        publisher.close();
    }

    /** Enqueues an event for processing by the actor. */
    @Override
    public void send(ActorEvent event) {
        actorEvents.enqueue(event);
    }

    /** Closes the publisher and enqueues a disconnect event for the actor. */
    @Override
    public void close() {
        publisher.close();
        var disconnect = new ActorEvents.Disconnect();
        actorEvents.enqueue(disconnect);
    }

    /**
     * Releases the node rpc and, when a session exists, closes it, resets the per-session
     * counters and closes every tracked partition stream (publishing a closed event for each).
     */
    private void closeSession() {
        // The node rpc is released even when no session is present.
        if (nodeRpc != null) {
            nodeRpc.close();
            nodeRpc = null;
        }

        if (session != null) {
            session.close();
            session = null;

            // Per-session state starts from scratch on the next connect.
            readInFlight = 0;
            sendToken = "";

            for (var stream : partitionStreamMap.partitionStreams()) {
                stream.close();
                publisher.submit(new PartitionStreamClosedEventImpl(stream));
            }
            partitionStreamMap.clear();
        }
    }

    /**
     * Aborts the current session with the given status, resets the per-session state and
     * closes every tracked partition stream, publishing a closed event for each.
     *
     * @param status status the session is aborted with
     */
    private void abortSession(Status status) {
        session.abort(status);
        session = null;

        // Reset per-session state here: once session is null, closeSession() returns before
        // clearing these, so stale values would otherwise leak into the next session
        // (a stuck in-flight counter, and an empty token on the next init request).
        readInFlight = 0;
        sendToken = "";

        for (var partitionStream : partitionStreamMap.partitionStreams()) {
            partitionStream.close();
            publisher.submit(new PartitionStreamClosedEventImpl(partitionStream));
        }
        partitionStreamMap.clear();
    }

    /**
     * Sends one more Read request when allowed: a session must exist, at least one partition
     * stream must be assigned, fewer than {@code MAX_READ_INFLIGHT} reads may be outstanding,
     * and the publisher's memory use plus the worst-case size of all in-flight reads
     * (including the new one) must stay below the configured memory limit.
     */
    private void continueRead() {
        if (session == null) {
            return;
        }

        if (partitionStreamMap.isEmpty()) {
            // no active assignments to read
            return;
        }

        // ">=" keeps the number of outstanding reads at MAX_READ_INFLIGHT at most;
        // ">" would let one extra read through.
        if (readInFlight >= MAX_READ_INFLIGHT) {
            return;
        }

        // Sum in long so the estimate cannot wrap negative and bypass the limit
        // when the publisher's memory use is close to Integer.MAX_VALUE.
        long memoryInFlight = (readInFlight + 1L) * MAX_BATCH_SIZE;
        long memoryUse = publisher.getMemoryUseBytes(cluster) + memoryInFlight;
        if (memoryUse >= settings.maxMemoryUsageBytes) {
            // not enough memory to read more chunks
            return;
        }

        readInFlight++;
        session.send(ClientRequestProto.read(Read.getDefaultInstance(), nextToken()));
    }
}
