package ru.yandex.travel.commons.grpc;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

import com.google.common.collect.ImmutableMap;
import com.google.protobuf.InvalidProtocolBufferException;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;

import ru.yandex.travel.commons.proto.EErrorCode;
import ru.yandex.travel.commons.proto.ErrorException;
import ru.yandex.travel.commons.proto.ProtoUtils;
import ru.yandex.travel.commons.proto.TError;

public class ServerUtils {
    private static ImmutableMap<EErrorCode, Status.Code> ERROR_CODE_TO_STATUS_CODE =
            ImmutableMap.<EErrorCode, Status.Code>builder()
                    .put(EErrorCode.EC_OK, Status.Code.OK)
                    .put(EErrorCode.EC_GENERAL_ERROR, Status.Code.UNKNOWN)
                    .put(EErrorCode.EC_INVALID_ARGUMENT, Status.Code.INVALID_ARGUMENT)
                    .put(EErrorCode.EC_FAILED_PRECONDITION, Status.Code.FAILED_PRECONDITION)
                    .put(EErrorCode.EC_ABORTED, Status.Code.ABORTED)
                    .put(EErrorCode.EC_UNAVAILABLE, Status.Code.UNAVAILABLE)
                    .put(EErrorCode.EC_RESOURCE_EXHAUSTED, Status.Code.RESOURCE_EXHAUSTED)
                    .put(EErrorCode.EC_NOT_FOUND, Status.Code.NOT_FOUND)
                    .put(EErrorCode.EC_ALREADY_EXISTS, Status.Code.ALREADY_EXISTS)
                    .put(EErrorCode.EC_PERMISSION_DENIED, Status.Code.PERMISSION_DENIED)
                    .put(EErrorCode.EC_CALL_TO_IM_OVERLOADED, Status.Code.RESOURCE_EXHAUSTED)
                    .put(EErrorCode.EC_IM_RETRYABLE_ERROR, Status.Code.UNAVAILABLE)
                    .build();

    private ServerUtils() {
    }

    public static Status statusFromError(TError error) {
        return ERROR_CODE_TO_STATUS_CODE.getOrDefault(error.getCode(), Status.Code.UNKNOWN).toStatus()
                .withDescription(error.getMessage())
                .withCause(new ErrorException(error));
    }

    public static Metadata.Key<TError> METADATA_ERROR_KEY = Metadata.Key.of("ya-error-bin",
            new Metadata.BinaryMarshaller<>() {
                @Override
                public byte[] toBytes(TError value) {
                    return value.toByteArray();
                }

                @Override
                public TError parseBytes(byte[] serialized) {
                    try {
                        return TError.parseFrom(serialized);
                    } catch (InvalidProtocolBufferException e) {
                        return null;
                    }
                }
            });

    public static <ReqT, RspT> void synchronously(
            Logger logger,
            ReqT request,
            StreamObserver<RspT> observer,
            Function<ReqT, RspT> handler,
            Function<Throwable, StatusException> errorHandler
    ) {
        try {
            observer.onNext(handler.apply(request));
            observer.onCompleted();
        } catch (Exception e) {
            observer.onError(errorHandler.apply(e));
        }
    }

    public static <ReqT, RspT> void asynchronously(
            Logger logger,
            ReqT request,
            StreamObserver<RspT> observer,
            Function<ReqT, CompletableFuture<RspT>> handler,
            Function<Throwable, StatusException> errorHandler
    ) {
        try {
            handler.apply(request)
                    .handle((rsp, e) -> {
                        if (rsp != null) {
                            observer.onNext(rsp);
                            observer.onCompleted();
                        }
                        if (e != null) {
                            observer.onError(errorHandler.apply(e));
                        }
                        return null;
                    });
        } catch (Exception e) {
            observer.onError(errorHandler.apply(e));
        }
    }

    public static <ReqT, RspT> void synchronously(
            final Logger logger,
            final ReqT request,
            StreamObserver<RspT> observer,
            Function<ReqT, RspT> handler
    ) {
        synchronously(logger, request, observer, handler, ex -> defaultMapException(logger, request, ex));
    }

    /**
     * Grpc uses the NettyServerStream and WriteQueue classes that internally rely on an unbound frames queue.
     * That queue can easily grow beyond available memory and cause OOM exceptions in case of large output data.
     * There is no easy way to flush that queue manually or wait for completion of its async flusher.
     */
    @Deprecated
    public static <ReqT, RspT> void streamingCall(
            Logger logger,
            ReqT request,
            StreamObserver<RspT> observer,
            BiConsumer<ReqT, Consumer<RspT>> handler
    ) {
        try {
            handler.accept(request, observer::onNext);
            observer.onCompleted();
        } catch (Exception e) {
            observer.onError(defaultMapException(logger, request, e));
        }
    }

    public static <ReqT> StatusException defaultMapException(Logger logger, ReqT request, Throwable exception) {
        logger.error("Caught exception {} while handling request {}", exception.getClass().getSimpleName(),
                request.getClass().getSimpleName(), exception);
        // TODO(tivelkov): passing verbose=true causes stack trace to get into headers, thus causing 'Header size
        //  exceeded max allowed size'
        TError error = ProtoUtils.errorFromThrowable(exception, false);
        Status status = statusFromError(error);
        Metadata trailers = new Metadata();
        trailers.put(ServerUtils.METADATA_ERROR_KEY, error);
        return status.asException(trailers);
    }
}
