package ru.yandex.intranet.d.grpc.interceptors;

import java.util.Locale;

import com.google.rpc.Code;
import com.google.rpc.Status;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.StatusProto;
import net.devh.boot.grpc.common.util.InterceptorOrder;
import net.devh.boot.grpc.server.interceptor.GrpcGlobalServerInterceptor;
import net.devh.boot.grpc.server.security.interceptors.AuthenticatingServerInterceptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.MessageSource;
import org.springframework.core.annotation.Order;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import reactor.core.Disposable;
import reactor.core.publisher.Mono;

import ru.yandex.intranet.d.grpc.security.GrpcAuthenticationConverter;
import ru.yandex.intranet.d.grpc.security.PermissionDeniedGrpcRequestHandler;
import ru.yandex.intranet.d.grpc.security.UnauthenticatedGrpcRequestHandler;
import ru.yandex.intranet.d.i18n.Locales;
import ru.yandex.intranet.d.web.log.AccessLogAttributesProducer;

/**
 * GRPC authentication interceptor.
 * Only authenticated users are allowed.
 * Method-based security is not supported because global method security is required
 * which somehow conflicts with reactive method security
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@GrpcGlobalServerInterceptor
@Order(InterceptorOrder.ORDER_SECURITY_AUTHENTICATION)
public class YaAuthenticatingInterceptor implements AuthenticatingServerInterceptor {

    private static final Logger LOG = LoggerFactory.getLogger(YaAuthenticatingInterceptor.class);

    private final GrpcAuthenticationConverter authenticationConverter;
    private final ReactiveAuthenticationManager reactiveAuthenticationManager;
    private final UnauthenticatedGrpcRequestHandler unauthenticatedGrpcRequestHandler;
    private final PermissionDeniedGrpcRequestHandler permissionDeniedGrpcRequestHandler;
    private final MessageSource messages;

    public YaAuthenticatingInterceptor(
            @Qualifier("yaGrpcAuthenticationConverter") GrpcAuthenticationConverter authenticationConverter,
            ReactiveAuthenticationManager reactiveAuthenticationManager,
            UnauthenticatedGrpcRequestHandler unauthenticatedGrpcRequestHandler,
            PermissionDeniedGrpcRequestHandler permissionDeniedGrpcRequestHandler,
            @Qualifier("messageSource") MessageSource messages) {
        this.authenticationConverter = authenticationConverter;
        this.reactiveAuthenticationManager = reactiveAuthenticationManager;
        this.unauthenticatedGrpcRequestHandler = unauthenticatedGrpcRequestHandler;
        this.permissionDeniedGrpcRequestHandler = permissionDeniedGrpcRequestHandler;
        this.messages = messages;
    }

    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
                                                                 ServerCallHandler<ReqT, RespT> next) {
        Context context = Context.current();
        try {
            DelayedServerCallListener<ReqT> listener = new DelayedServerCallListener<>();
            Disposable subscription = authenticationConverter.readAuthentication(call, headers)
                    .switchIfEmpty(Mono.error(() -> new BadCredentialsException("No credentials found in the request")))
                    .flatMap(reactiveAuthenticationManager::authenticate)
                    .contextWrite(buildContextWithLogId(context))
                    .subscribe(
                            r -> forwardRequest(r, call, headers, next, listener, context),
                            e -> context.run(() -> {
                                try {
                                    if (e instanceof AuthenticationException) {
                                        unauthenticatedGrpcRequestHandler.onRequest(e, call, headers, next);
                                    } else if (e instanceof AccessDeniedException) {
                                        permissionDeniedGrpcRequestHandler.onRequest(e, call, headers, next);
                                    } else {
                                        onUnexpectedError(e, call);
                                    }
                                } catch (Exception ex) {
                                    LOG.error("Unexpected error in unauthenticated request handler", ex);
                                } finally {
                                    listener.setListener(new ServerCall.Listener<>() { });
                                }
                            })
                    );
            listener.setCancelCallback(subscription::dispose);
            return listener;
        } catch (Exception e) {
            onUnexpectedError(e, call);
            return new ServerCall.Listener<>() { };
        }
    }

    private reactor.util.context.Context buildContextWithLogId(Context context) {
        return reactor.util.context.Context.of(AccessLogAttributesProducer.LOG_ID,
                AccessLogInterceptor.LOG_ID_KEY.get(context));
    }

    private <ReqT, RespT> void forwardRequest(Authentication authentication, ServerCall<ReqT, RespT> call,
                                              Metadata headers, ServerCallHandler<ReqT, RespT> next,
                                              DelayedServerCallListener<ReqT> listener, Context context) {
        // TODO Maybe initialize SecurityContextHolder here? It's not immediately useful though
        //  because spring security annotations are not used for GRPC endpoints anyway...
        context.run(() -> listener.setListener(Contexts.interceptCall(
                Context.current().withValue(SECURITY_CONTEXT_KEY, new YaSecurityContext(authentication)),
                call, headers, next)));
    }

    private <ReqT, RespT> void onUnexpectedError(Throwable throwable, ServerCall<ReqT, RespT> call) {
        LOG.error("Unexpected error in authentication", throwable);
        Locale locale = Locales.ENGLISH;
        String message = messages.getMessage("errors.unexpected.service.error", null, locale);
        Status status = Status.newBuilder()
                .setCode(Code.UNKNOWN.getNumber())
                .setMessage(message)
                .build();
        final StatusRuntimeException error = StatusProto.toStatusRuntimeException(status);
        call.close(error.getStatus(), error.getTrailers() != null ? error.getTrailers() : new Metadata());
    }

}
