package ru.yandex.intranet.d.grpc;

import java.util.List;
import java.util.Locale;
import java.util.function.Function;

import com.salesforce.reactorgrpc.stub.ServerCalls;
import io.grpc.CallOptions;
import io.grpc.Context;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.springframework.context.MessageSource;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import ru.yandex.intranet.d.datasource.model.BadSessionException;
import ru.yandex.intranet.d.datasource.model.SessionPoolDepletedException;
import ru.yandex.intranet.d.datasource.model.TransactionLocksInvalidatedException;
import ru.yandex.intranet.d.datasource.model.TransportUnavailableException;
import ru.yandex.intranet.d.grpc.interceptors.AccessLogInterceptor;
import ru.yandex.intranet.d.grpc.interceptors.ExceptionLoggingResponseObserver;
import ru.yandex.intranet.d.i18n.Locales;
import ru.yandex.intranet.d.web.log.AccessLogAttributesProducer;

/**
 * Reactive GRPC helpers.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public final class Grpc {

    private Grpc() {
    }

    public static <TRequest, TResponse> void oneToOne(
            TRequest request, StreamObserver<TResponse> responseObserver,
            Function<Mono<TRequest>, Mono<TResponse>> delegate, MessageSource messages) {
        Context context = Context.current();
        ServerCalls.oneToOne(request, exceptionLoggingObserver(responseObserver, context), delegate
                .andThen(r -> r.onErrorMap(e -> mapError(e, messages)))
                .andThen(r -> r.contextWrite(buildContextWithLogId(context)))
        );
    }

    public static <TRequest, TResponse> void oneToMany(
            TRequest request, StreamObserver<TResponse> responseObserver,
            Function<Mono<TRequest>, Flux<TResponse>> delegate, MessageSource messages) {
        Context context = Context.current();
        ServerCalls.oneToMany(request, exceptionLoggingObserver(responseObserver, context), delegate
                .andThen(r -> r.onErrorMap(e -> mapError(e, messages)))
                .andThen(r -> r.contextWrite(buildContextWithLogId(context)))
        );
    }

    public static <TRequest, TResponse> StreamObserver<TRequest> manyToOne(
            StreamObserver<TResponse> responseObserver,
            Function<Flux<TRequest>, Mono<TResponse>> delegate,
            CallOptions options, MessageSource messages) {
        Context context = Context.current();
        return ServerCalls.manyToOne(exceptionLoggingObserver(responseObserver, context), delegate
                .andThen(r -> r.onErrorMap(e -> mapError(e, messages)))
                .andThen(r -> r.contextWrite(buildContextWithLogId(context))), options);
    }

    public static <TRequest, TResponse> StreamObserver<TRequest> manyToMany(
            StreamObserver<TResponse> responseObserver,
            Function<Flux<TRequest>, Flux<TResponse>> delegate,
            CallOptions options, MessageSource messages) {
        Context context = Context.current();
        return ServerCalls.manyToMany(exceptionLoggingObserver(responseObserver, context), delegate
                .andThen(r -> r.onErrorMap(e -> mapError(e, messages)))
                .andThen(r -> r.contextWrite(buildContextWithLogId(context))), options);
    }

    private static <TResponse> ExceptionLoggingResponseObserver<TResponse> exceptionLoggingObserver(
            StreamObserver<TResponse> responseObserver, Context context) {
        return new ExceptionLoggingResponseObserver<>(responseObserver, AccessLogInterceptor.LOG_ID_KEY.get(context));
    }

    private static reactor.util.context.Context buildContextWithLogId(Context context) {
        return reactor.util.context.Context.of(AccessLogAttributesProducer.LOG_ID,
                AccessLogInterceptor.LOG_ID_KEY.get(context));
    }

    private static Throwable mapError(Throwable error, MessageSource messages) {
        Locale locale = Locales.grpcLocale();
        if (error instanceof StatusException || error instanceof StatusRuntimeException) {
            return error;
        } else if (!hasStatus(error) && TransactionLocksInvalidatedException.isTransactionLocksInvalidated(error)) {
            String message = messages
                    .getMessage("errors.too.many.requests.transaction.locks.invalidated", null, locale);
            return Status.RESOURCE_EXHAUSTED.withCause(error).withDescription(message).asException();
        } else if (!hasStatus(error) && SessionPoolDepletedException.isSessionPoolDepleted(error)) {
            String message = messages
                    .getMessage("errors.too.many.requests", null, locale);
            return Status.RESOURCE_EXHAUSTED.withCause(error).withDescription(message).asException();
        } else if (!hasStatus(error) && TransportUnavailableException.isTransportUnavailable(error)) {
            String message = messages
                    .getMessage("errors.unavailable", null, locale);
            return Status.UNAVAILABLE.withCause(error).withDescription(message).asException();
        } else if (!hasStatus(error) && BadSessionException.isBadSession(error)) {
            String message = messages
                    .getMessage("errors.unavailable", null, locale);
            return Status.UNAVAILABLE.withCause(error).withDescription(message).asException();
        } else {
            Status status = Status.fromThrowable(error);
            if (Status.Code.UNKNOWN.equals(status.getCode())
                    || status.getDescription() == null || status.getDescription().isEmpty()) {
                String message = codeToMessage(status.getCode(), messages, locale);
                return status.withDescription(message).asException();
            } else {
                return status.asException();
            }
        }
    }

    private static boolean hasStatus(Throwable error) {
        List<Throwable> throwableList = ExceptionUtils.getThrowableList(error);
        return !throwableList.isEmpty() && throwableList.stream()
                .anyMatch(t -> t instanceof StatusException || t instanceof StatusRuntimeException);
    }

    private static String codeToMessage(Status.Code grpcCode, MessageSource messages, Locale locale) {
        return switch (grpcCode) {
            case OK -> messages.getMessage("errors.grpc.code.ok", null, locale);
            case CANCELLED -> messages.getMessage("errors.grpc.code.cancelled", null, locale);
            case UNKNOWN -> messages.getMessage("errors.grpc.code.unknown", null, locale);
            case INVALID_ARGUMENT -> messages.getMessage("errors.grpc.code.invalid.argument", null, locale);
            case DEADLINE_EXCEEDED -> messages.getMessage("errors.grpc.code.deadline.exceeded", null, locale);
            case NOT_FOUND -> messages.getMessage("errors.grpc.code.not.found", null, locale);
            case ALREADY_EXISTS -> messages.getMessage("errors.grpc.code.already.exists", null, locale);
            case PERMISSION_DENIED -> messages.getMessage("errors.grpc.code.permission.denied", null, locale);
            case UNAUTHENTICATED -> messages.getMessage("errors.grpc.code.unauthenticated", null, locale);
            case RESOURCE_EXHAUSTED -> messages.getMessage("errors.grpc.code.resource.exhausted", null, locale);
            case FAILED_PRECONDITION -> messages.getMessage("errors.grpc.code.failed.precondition", null, locale);
            case ABORTED -> messages.getMessage("errors.grpc.code.aborted", null, locale);
            case OUT_OF_RANGE -> messages.getMessage("errors.grpc.code.out.of.range", null, locale);
            case UNIMPLEMENTED -> messages.getMessage("errors.grpc.code.unimplemented", null, locale);
            case INTERNAL -> messages.getMessage("errors.grpc.code.internal", null, locale);
            case UNAVAILABLE -> messages.getMessage("errors.grpc.code.unavailable", null, locale);
            case DATA_LOSS -> messages.getMessage("errors.grpc.code.data.loss", null, locale);
        };
    }

}
