package ru.yandex.grpc.utils;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.base.Strings;
import com.google.common.net.HostAndPort;
import com.google.protobuf.Message;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.selfmon.failsafe.CircuitBreaker;

import static io.grpc.stub.ClientCalls.asyncBidiStreamingCall;
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
import static io.grpc.stub.ClientCalls.asyncUnaryCall;

/**
 * gRPC client transport bound to a single node {@link #getAddress() address}.
 * <p>
 * Every call is gated by a {@link CircuitBreaker}: calls are rejected up front when the breaker
 * refuses execution, and the outcome of each executed call is reported back to the breaker.
 *
 * @author Vladimir Gordiychuk
 */
@ParametersAreNonnullByDefault
public class GrpcTransport implements AutoCloseable, Transport {
    private static final Logger logger = LoggerFactory.getLogger(GrpcTransport.class);

    private final ManagedChannel channel;
    /** Max outbound message size in bytes; values {@code <= 0} leave the gRPC default untouched. */
    private final int maxOutboundMessageSize;
    private final HostAndPort address;
    /** Timeout applied when a call is made without an explicit deadline; {@code 0} disables it. */
    private final long defaultTimeoutMs;
    private final MetricRegistry registry;
    private final CircuitBreaker circuitBreaker;

    /**
     * Creates a transport for the given node.
     * The channel and circuit breaker are produced by the factories configured in {@code opts},
     * falling back to {@code NettyChannelFactory} and {@code DefaultCircuitBreakerFactory}.
     *
     * @param address node to connect to
     * @param opts    client options (factories, message size limit, default timeout, metric registry)
     */
    public GrpcTransport(HostAndPort address, GrpcClientOptions opts) {
        this.address = address;
        this.channel = opts.getChannelFactory()
                .orElseGet(NettyChannelFactory::new)
                .createChannel(address, opts);
        this.maxOutboundMessageSize = opts.getMaxOutboundMessageSizeInBytes();
        this.defaultTimeoutMs = opts.getDefaultTimeoutMillis();
        this.registry = opts.getMetricRegistry();
        this.circuitBreaker = opts.getCircuitBreakerFactory()
                .orElseGet(DefaultCircuitBreakerFactory::new)
                .create(address, opts);
    }

    /**
     * @return address of the node this transport talks to
     */
    @Override
    public HostAndPort getAddress() {
        return address;
    }

    /**
     * Executes a unary call without an explicit deadline
     * (the configured default timeout, if any, is applied).
     *
     * @param method  method to invoke
     * @param request request message
     * @return future with the response; completed exceptionally on failure
     */
    public <ReqT, RespT> CompletableFuture<RespT> unaryCall(MethodDescriptor<ReqT, RespT> method, ReqT request) {
        return unaryCall(method, request, 0);
    }

    /**
     * Executes a unary call.
     * <p>
     * If the circuit breaker refuses execution the returned future fails immediately with
     * {@link Status#UNAVAILABLE}. Otherwise the call outcome is logged (on failure) and reported
     * to the circuit breaker once the call finishes.
     *
     * @param method   method to invoke
     * @param request  request message
     * @param deadline absolute deadline in epoch milliseconds (it is compared against
     *                 {@link System#currentTimeMillis()}); {@code 0} means "use the default timeout"
     * @return future with the response; completed exceptionally on failure
     */
    public <ReqT, RespT> CompletableFuture<RespT> unaryCall(MethodDescriptor<ReqT, RespT> method, ReqT request, long deadline) {
        try {
            if (!circuitBreaker.attemptExecution()) {
                incCircuitBreakerFail(method.getFullMethodName());
                return CompletableFuture.failedFuture(circuitBreakerException());
            }

            CompletableFuture<RespT> future = new CompletableFuture<>();
            CallOptions callOptions = fillDeadline(CallOptions.DEFAULT, deadline);
            if (maxOutboundMessageSize > 0) {
                callOptions = callOptions.withMaxOutboundMessageSize(maxOutboundMessageSize);
            }
            asyncUnaryCall(channel.newCall(method, callOptions), request, new SingleResponseStreamObserver<>(future));
            return future.whenComplete((response, throwable) -> {
                Status status;
                if (throwable != null) {
                    status = statusFromThrowable(throwable);
                    logError(method, status, throwable);
                } else {
                    status = Status.OK;
                }

                reportStatusToCircuitBreaker(status);
            });
        } catch (Throwable e) {
            // Failure while starting the call itself (before any response could arrive).
            Status status = statusFromThrowable(e);
            reportStatusToCircuitBreaker(status);
            registry.rate("grpc.client.call.status", Labels.of("code", status.getCode().name())).inc();
            logger.error("{} - {} on node {}", status.getCode(), method.getFullMethodName(), address, e);
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * @return exception handed to callers when the circuit breaker refuses a call
     */
    private Throwable circuitBreakerException() {
        return new StatusRuntimeExceptionNoStackTrace(Status.UNAVAILABLE.withDescription("CircuitBreaker#OPEN " + address));
    }

    /**
     * Logs a failed call.
     * {@code DEADLINE_EXCEEDED} and {@code UNAVAILABLE} are logged at WARN level without a stack trace;
     * everything else is logged at ERROR level, with a stack trace only when the unwrapped cause
     * is not a {@link StatusRuntimeException}.
     *
     * @param method method that failed
     * @param status status derived from {@code error}
     * @param error  failure as reported by the call
     */
    public <ReqT, RespT> void logError(MethodDescriptor<ReqT, RespT> method, Status status, Throwable error) {
        if (isWarn(status)) {
            logger.warn("{} - {} on node {}", status.getCode(), method.getFullMethodName(), address);
        } else {
            Throwable cause = CompletableFutures.unwrapCompletionException(error);
            if (cause instanceof StatusRuntimeException) {
                logger.error("{} - {} on node {}", status.getCode(), method.getFullMethodName(), address);
            } else {
                logger.error("{} - {} on node {}", status.getCode(), method.getFullMethodName(), address, cause);
            }
        }
    }

    /**
     * @return {@code true} for status codes that are logged at WARN level instead of ERROR
     */
    private static boolean isWarn(Status status) {
        switch (status.getCode()) {
            case DEADLINE_EXCEEDED:
            case UNAVAILABLE:
                return true;
            default:
                return false;
        }
    }

    /**
     * @return {@code false} when the circuit breaker is open, otherwise the result of {@link #isConnected()}
     */
    @Override
    public boolean isReady() {
        if (circuitBreaker.getStatus() == CircuitBreaker.Status.OPEN) {
            return false;
        }

        return isConnected();
    }

    /**
     * @return {@code true} when the channel state is {@code READY} or {@code IDLE}
     */
    @Override
    public boolean isConnected() {
        // getState(true) also asks the channel to start connecting if it is currently idle.
        ConnectivityState connectivity = channel.getState(true);
        return connectivity == ConnectivityState.READY || connectivity == ConnectivityState.IDLE;
    }

    /**
     * Converts a throwable into a {@link Status} whose description mentions the node address.
     */
    private Status statusFromThrowable(Throwable throwable) {
        Status status = Status.fromThrowable(throwable);
        if (Strings.isNullOrEmpty(status.getDescription())) {
            return status.withDescription("Node " + address);
        } else {
            return status.withDescription(status.getDescription() + " node " + address);
        }
    }

    /**
     * Counts a call rejected by the circuit breaker.
     */
    private void incCircuitBreakerFail(String endpoint) {
        registry.rate("grpc.client.call.status",
                Labels.of("endpoint", endpoint,
                        "code", "OPEN_CIRCUIT_BREAKER"))
                .inc();
    }

    /**
     * Counts a call rejected because of an in-flight limit.
     * NOTE(review): not referenced anywhere in this class at the moment.
     */
    private void incInFlightLimiterFail(String endpoint) {
        registry.rate("grpc.client.call.status",
                Labels.of("endpoint", endpoint,
                        "code", "IN_FLIGHT_LIMIT_REACHED"))
                .inc();
    }

    /**
     * Executes a server-streaming call, forwarding responses to {@code observer}.
     * <p>
     * The subscription handed to the observer supports only {@code cancel()};
     * {@code request(n)} is a no-op. Unlike {@link #unaryCall}, a {@code deadline} of {@code 0}
     * means no deadline at all: the default timeout is not applied here.
     *
     * @param method   method to invoke
     * @param request  request message
     * @param observer receiver of responses, completion and errors
     * @param deadline absolute deadline in epoch milliseconds, or {@code 0} for none
     */
    public <Request extends Message, Response extends Message> void serverStreamingCall(
            MethodDescriptor<Request, Response> method,
            Request request,
            Flow.Subscriber<Response> observer,
            long deadline)
    {
        if (!circuitBreaker.attemptExecution()) {
            incCircuitBreakerFail(method.getFullMethodName());
            observer.onError(circuitBreakerException());
            return;
        }

        try {
            CallOptions opts = deadline != 0
                    ? fillDeadline(CallOptions.DEFAULT, deadline)
                    : CallOptions.DEFAULT;
            if (maxOutboundMessageSize > 0) {
                opts = opts.withMaxOutboundMessageSize(maxOutboundMessageSize);
            }
            ClientCall<Request, Response> call = channel.newCall(method, opts);
            observer.onSubscribe(new Flow.Subscription() {
                @Override
                public void request(long n) {
                    // unsupported
                }

                @Override
                public void cancel() {
                    call.cancel("Manual cancellation", null);
                }
            });
            asyncServerStreamingCall(call, request, new StreamObserver<Response>() {
                @Override
                public void onNext(Response value) {
                    circuitBreaker.markSuccess();
                    observer.onNext(value);
                }

                @Override
                public void onError(Throwable t) {
                    reportStatusToCircuitBreaker(Status.fromThrowable(t));
                    observer.onError(t);
                }

                @Override
                public void onCompleted() {
                    circuitBreaker.markSuccess();
                    observer.onComplete();
                }
            });

        } catch (Throwable e) {
            Status status = Status.fromThrowable(e);
            reportStatusToCircuitBreaker(status);
            registry.rate("grpc.client.call.status", Labels.of("code", status.getCode().name())).inc();
            logger.error("{} - {} on node {} for request {}", status.getCode(), method.getFullMethodName(), address, request, e);
            observer.onError(e);
        }
    }

    /**
     * Opens a bidirectional streaming call. No deadline is set on the call.
     * <p>
     * When the circuit breaker refuses execution, or starting the call fails, {@code observer}
     * receives the error and a no-op request observer is returned.
     *
     * @param method   method to invoke
     * @param observer receiver of responses, completion and errors
     * @return observer used to send request messages
     */
    public <Request extends Message, Response extends Message> StreamObserver<Request> bidiStreamingCall(
            MethodDescriptor<Request, Response> method,
            StreamObserver<Response> observer)
    {
        if (!circuitBreaker.attemptExecution()) {
            incCircuitBreakerFail(method.getFullMethodName());
            observer.onError(circuitBreakerException());
            return StreamObservers.noop();
        }

        try {
            CallOptions opts = CallOptions.DEFAULT;
            if (maxOutboundMessageSize > 0) {
                opts = opts.withMaxOutboundMessageSize(maxOutboundMessageSize);
            }
            ClientCall<Request, Response> call = channel.newCall(method, opts);
            return asyncBidiStreamingCall(call, new StreamObserver<>() {
                @Override
                public void onNext(Response value) {
                    circuitBreaker.markSuccess();
                    observer.onNext(value);
                }

                @Override
                public void onError(Throwable t) {
                    reportStatusToCircuitBreaker(Status.fromThrowable(t));
                    observer.onError(t);
                }

                @Override
                public void onCompleted() {
                    circuitBreaker.markSuccess();
                    observer.onCompleted();
                }
            });
        } catch (Throwable e) {
            Status status = Status.fromThrowable(e);
            reportStatusToCircuitBreaker(status);
            registry.rate("grpc.client.call.status", Labels.of("code", status.getCode().name())).inc();
            // There is no request object at this point, so the message has exactly three
            // placeholders; the trailing throwable is logged as the exception.
            logger.error("{} - {} on node {}", status.getCode(), method.getFullMethodName(), address, e);
            observer.onError(e);
            return StreamObservers.noop();
        }
    }

    /**
     * Reports a call outcome to the circuit breaker.
     * {@code UNAVAILABLE}, {@code DEADLINE_EXCEEDED}, {@code INTERNAL}, {@code UNAUTHENTICATED}
     * and {@code UNKNOWN} count as failures; every other code counts as success.
     */
    private void reportStatusToCircuitBreaker(Status status) {
        switch (status.getCode()) {
            case UNAVAILABLE:
            case DEADLINE_EXCEEDED:
            case INTERNAL:
            case UNAUTHENTICATED:
            case UNKNOWN:
                circuitBreaker.markFailure();
                return;
            default:
                circuitBreaker.markSuccess();
        }
    }

    /**
     * Applies a deadline to the call options.
     *
     * @param opts     options to extend
     * @param deadline absolute deadline in epoch milliseconds; {@code 0} selects the default
     *                 timeout, or no deadline at all when the default timeout is {@code 0}
     * @return options with the deadline applied
     */
    private CallOptions fillDeadline(CallOptions opts, long deadline) {
        if (deadline == 0) {
            if (defaultTimeoutMs == 0) {
                return opts;
            } else {
                return opts.withDeadlineAfter(defaultTimeoutMs, TimeUnit.MILLISECONDS);
            }
        }

        // Convert the absolute deadline into a duration relative to now.
        long remainingMillis = deadline - System.currentTimeMillis();
        return opts.withDeadlineAfter(remainingMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Shuts the underlying channel down via {@link ManagedChannel#shutdownNow()}.
     */
    @Override
    public void close() {
        channel.shutdownNow();
    }

    @Override
    public String toString() {
        return "GrpcTransport{" +
                "address=" + address +
                '}';
    }
}
