package ru.yandex.solomon.auth.grpc;

import java.util.Collection;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.ParametersAreNonnullByDefault;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.auth.AnonymousAuthSubject;
import ru.yandex.solomon.auth.AuthSubject;
import ru.yandex.solomon.auth.AuthToken;
import ru.yandex.solomon.auth.Authenticator;

/**
 * gRPC {@link ServerInterceptor} that authenticates every incoming call and publishes the resulting
 * {@link AuthSubject} in the gRPC {@link Context} under {@link #AUTH_CONTEXT_KEY}.
 *
 * <p>Methods registered as "optional auth" skip token extraction entirely and run as
 * {@link AnonymousAuthSubject#INSTANCE}. Any failure to obtain or validate a token is reported to
 * the client as {@link Status.Code#UNAUTHENTICATED}.
 *
 * @author Oleg Baryshnikov
 */
@ParametersAreNonnullByDefault
public class AuthenticationInterceptor implements ServerInterceptor {
    /** Context key under which the authenticated subject of the current call is stored. */
    public static final Context.Key<AuthSubject> AUTH_CONTEXT_KEY = Context.key("AuthSubject");
    /** Status used when a request carries no token or no subject is bound to the context. */
    public static final Status UNAUTHENTICATED_STATUS =
            Status.UNAUTHENTICATED.withDescription("cannot authenticate request");

    private final Authenticator authenticator;

    // Auth must be used everywhere. For legacy projects (yasm) we support a whitelist of methods,
    // keyed by full gRPC method name, that may be called without credentials.
    private final Set<String> optionalAuthMethods;

    /**
     * Creates an interceptor that requires authentication for every method.
     *
     * @param authenticator used to extract tokens from requests and resolve them into subjects
     */
    public AuthenticationInterceptor(Authenticator authenticator) {
        this(authenticator, Stream.empty());
    }

    /**
     * @param authenticator       used to extract tokens from requests and resolve them into subjects
     * @param optionalAuthMethods methods that run as the anonymous subject without any token check
     */
    public AuthenticationInterceptor(Authenticator authenticator, Stream<MethodDescriptor<?, ?>> optionalAuthMethods) {
        this.authenticator = authenticator;
        this.optionalAuthMethods = optionalAuthMethods
                .map(MethodDescriptor::getFullMethodName)
                .collect(Collectors.toUnmodifiableSet());
    }

    /**
     * @param authenticator        used to extract tokens from requests and resolve them into subjects
     * @param optionalAuthServices services whose every method runs as the anonymous subject
     */
    public AuthenticationInterceptor(Authenticator authenticator, Collection<ServiceDescriptor> optionalAuthServices) {
        this(authenticator, optionalAuthServices.stream().flatMap(service -> service.getMethods().stream()));
    }

    /**
     * Resolves the caller's {@link AuthSubject} and continues the call with it bound in the context.
     *
     * @throws StatusRuntimeException with code {@link Status.Code#UNAUTHENTICATED} if the caller
     *                                cannot be authenticated
     */
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata headers,
        ServerCallHandler<ReqT, RespT> next)
    {
        AuthSubject authSubject = resolveAuthSubject(call, headers);
        // Deliberately outside any auth-related try/catch: failures while starting the downstream
        // handler must not be reported to the client as UNAUTHENTICATED.
        return callImpl(call, headers, next, authSubject);
    }

    /**
     * Determines the subject on whose behalf the call runs.
     *
     * <p>NOTE(review): {@code join()} blocks the thread invoking the interceptor until
     * authentication completes.
     */
    private AuthSubject resolveAuthSubject(ServerCall<?, ?> call, Metadata headers) {
        String methodName = call.getMethodDescriptor().getFullMethodName();
        if (optionalAuthMethods.contains(methodName)) {
            return AnonymousAuthSubject.INSTANCE;
        }

        Optional<AuthToken> token;
        try {
            token = authenticator.getToken(headers, call.getAttributes());
        } catch (Exception e) {
            throw unauthenticated(e);
        }

        // Thrown outside any catch block so the canonical status reaches the client unchanged.
        if (token.isEmpty()) {
            throw new StatusRuntimeException(UNAUTHENTICATED_STATUS);
        }

        try {
            return authenticator.authenticate(token.get()).join();
        } catch (Exception e) {
            throw unauthenticated(e);
        }
    }

    /**
     * Converts an authentication failure into an UNAUTHENTICATED status. Any completion wrapper is
     * unwrapped first, and the root cause is kept on the status for server-side diagnostics.
     */
    private static StatusRuntimeException unauthenticated(Exception e) {
        Throwable cause = CompletableFutures.unwrapCompletionException(e);
        return new StatusRuntimeException(
            Status.UNAUTHENTICATED.withDescription(cause.getMessage()).withCause(cause));
    }

    /** Runs the rest of the call chain with {@code authSubject} bound to {@link #AUTH_CONTEXT_KEY}. */
    private <ReqT, RespT> ServerCall.Listener<ReqT> callImpl(
        ServerCall<ReqT, RespT> call,
        Metadata headers,
        ServerCallHandler<ReqT, RespT> next,
        AuthSubject authSubject)
    {
        Context authContext =
            Context.current().withValue(AUTH_CONTEXT_KEY, authSubject);

        return Contexts.interceptCall(authContext, call, headers, next);
    }

    /**
     * Returns the subject bound to {@code context} by this interceptor.
     *
     * @throws StatusRuntimeException with {@link #UNAUTHENTICATED_STATUS} if no subject is bound
     */
    public static AuthSubject getAuthSubject(Context context) {
        AuthSubject authSubject = AUTH_CONTEXT_KEY.get(context);
        if (authSubject == null) {
            throw new StatusRuntimeException(UNAUTHENTICATED_STATUS);
        }
        return authSubject;
    }

    /** Returns the subject bound to {@code context}, or {@code null} if none is bound. */
    public static AuthSubject getAuthSubjectOrNull(Context context) {
        return AUTH_CONTEXT_KEY.get(context);
    }
}
