package ru.yandex.solomon.balancer.remote;

import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

import io.grpc.MethodDescriptor;

import ru.yandex.cluster.discovery.ClusterDiscovery;
import ru.yandex.grpc.utils.GrpcTransport;
import ru.yandex.solomon.balancer.AssignmentSeqNo;
import ru.yandex.solomon.balancer.BalancerProto;
import ru.yandex.solomon.balancer.TAssignShardRequest;
import ru.yandex.solomon.balancer.TBalancerServiceGrpc;
import ru.yandex.solomon.balancer.TPingRequest;
import ru.yandex.solomon.balancer.TUnassignShardRequest;
import ru.yandex.solomon.balancer.TotalShardCounter;

/**
 * {@link RemoteNodeClient} that sends balancer RPCs ({@code TBalancerServiceGrpc} methods) to
 * cluster nodes, resolving each node's {@link GrpcTransport} through {@link ClusterDiscovery}.
 *
 * <p>Every RPC method reports failures through the returned future. An exception raised while
 * building the request, resolving the node's transport, or starting the call completes the future
 * exceptionally instead of being thrown to the caller.
 *
 * @author Vladimir Gordiychuk
 */
public class RemoteNodeClientImpl implements RemoteNodeClient {
    private final ClusterDiscovery<GrpcTransport> discovery;
    private final TotalShardCounter shardCounter;

    /**
     * @param discovery    source of the known node set and of per-node gRPC transports
     * @param shardCounter consulted on every ping to report the total shard count, when known
     */
    public RemoteNodeClientImpl(ClusterDiscovery<GrpcTransport> discovery, TotalShardCounter shardCounter) {
        this.discovery = discovery;
        this.shardCounter = shardCounter;
    }

    /** Returns the node set as reported by the discovery. */
    @Override
    public Set<String> getNodes() {
        return discovery.getNodes();
    }

    /** Returns whether the discovery knows the given node. */
    @Override
    public boolean hasNode(String node) {
        return discovery.hasNode(node);
    }

    /**
     * Sends an assign-shard RPC to {@code address}.
     *
     * <p>{@code expiredAt} is both stored in the request and passed to the transport call
     * (presumably an absolute deadline; confirm its units against {@link GrpcTransport}).
     *
     * @return a future that completes when the RPC completes; the response payload is discarded
     */
    @Override
    public CompletableFuture<Void> assignShard(String address, String shardId, AssignmentSeqNo seqNo, long expiredAt) {
        return unaryCall(
                address,
                TBalancerServiceGrpc.getAssignShardMethod(),
                () -> TAssignShardRequest.newBuilder()
                        .setShardId(shardId)
                        .setAssignmentSeqNo(BalancerProto.toProto(seqNo))
                        .setExpiredAt(expiredAt)
                        .build(),
                expiredAt)
                .thenRun(() -> {});
    }

    /**
     * Sends an unassign-shard RPC to {@code address}.
     *
     * <p>{@code graceful} is forwarded unchanged in the request. {@code expiredAt} is both stored
     * in the request and passed to the transport call.
     *
     * @return a future that completes when the RPC completes; the response payload is discarded
     */
    @Override
    public CompletableFuture<Void> unassignShard(String address, String shardId, AssignmentSeqNo seqNo, boolean graceful, long expiredAt) {
        return unaryCall(
                address,
                TBalancerServiceGrpc.getUnassignShardMethod(),
                () -> TUnassignShardRequest.newBuilder()
                        .setShardId(shardId)
                        .setAssignmentSeqNo(BalancerProto.toProto(seqNo))
                        .setGraceful(graceful)
                        .setExpiredAt(expiredAt)
                        .build(),
                expiredAt)
                .thenRun(() -> {});
    }

    /**
     * Pings the node at {@code address} and converts its response via
     * {@code BalancerProto.fromProto}.
     *
     * @return a future with the node state reported in the ping response
     */
    @Override
    public CompletableFuture<RemoteNodeState> ping(String address, long leaderSeqNo, long expiredAt) {
        return unaryCall(
                address,
                TBalancerServiceGrpc.getPingMethod(),
                () -> newPingRequest(leaderSeqNo, expiredAt),
                expiredAt)
                .thenApply(BalancerProto::fromProto);
    }

    /**
     * Builds a ping request.
     *
     * <p>The latest assignment seqNo is set to {@code AssignmentSeqNo(leaderSeqNo, 0)}. The total
     * shard count is set only when the counter reports a value other than
     * {@link TotalShardCounter#SHARD_COUNT_UNKNOWN}; otherwise those fields are left unset.
     */
    private TPingRequest newPingRequest(long leaderSeqNo, long expiredAt) {
        var req = TPingRequest.newBuilder()
                .setLatestAssignmentSeqNo(BalancerProto.toProto(new AssignmentSeqNo(leaderSeqNo, 0)))
                .setExpiredAt(expiredAt);

        var totalShardCount = shardCounter.getTotalShardCount();
        if (totalShardCount != TotalShardCounter.SHARD_COUNT_UNKNOWN) {
            req.setTotalShardCountKnown(true).setTotalShardCount(totalShardCount);
        }
        return req.build();
    }

    /**
     * Builds the request, resolves the target node's transport and starts a unary call.
     *
     * <p>The request is built first and the transport is resolved second, so the first failure in
     * that order is the one reported.
     *
     * @param target         node to call, as known to the discovery
     * @param method         gRPC method to invoke
     * @param requestFactory produces the request; invoked synchronously, exactly once
     * @param expiredAt      forwarded to {@link GrpcTransport#unaryCall}
     * @return the transport's future, or a future completed exceptionally with whatever was thrown
     */
    private <ReqT, RespT> CompletableFuture<RespT> unaryCall(
            String target,
            MethodDescriptor<ReqT, RespT> method,
            Supplier<ReqT> requestFactory,
            long expiredAt)
    {
        try {
            ReqT request = requestFactory.get();
            GrpcTransport transport = discovery.getTransportByNode(target);
            return transport.unaryCall(method, request, expiredAt);
        } catch (Throwable e) {
            // Deliberately broad: callers observe failures only through the returned future.
            return CompletableFuture.failedFuture(e);
        }
    }
}
