package ru.yandex.solomon.auth;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.grpc.Attributes;
import io.grpc.Metadata;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.auth.exceptions.AuthenticationException;
import ru.yandex.solomon.staffOnly.annotations.HideFromManagerUi;


/**
 * @author Sergey Polovko
 */
final class AuthenticatorCache implements Authenticator {

    private static final Logger logger = LoggerFactory.getLogger(AuthorizerCache.class);

    private static final long CACHE_POSITIVE_TTL_MINUTES = 10;
    private static final long CACHE_NEGATIVE_TTL_SECONDS = 20;
    private static final long CACHE_MAX_SIZE = 6000;

    private final Authenticator delegate;
    @HideFromManagerUi
    private final Cache<AuthToken, AuthSubject> positiveCache;
    @HideFromManagerUi
    private final Cache<AuthToken, Throwable> negativeCache;
    private final Rate cacheHitOk;
    private final Rate cacheHitError;
    private final Rate cacheMiss;

    AuthenticatorCache(Authenticator delegate) {
        this.delegate = delegate;
        this.positiveCache = CacheBuilder.newBuilder()
            .expireAfterWrite(CACHE_POSITIVE_TTL_MINUTES, TimeUnit.MINUTES)
            .maximumSize(CACHE_MAX_SIZE)
            .build();
        this.negativeCache = CacheBuilder.newBuilder()
            .expireAfterWrite(CACHE_NEGATIVE_TTL_SECONDS, TimeUnit.SECONDS)
            .maximumSize(CACHE_MAX_SIZE)
            .build();
        this.cacheHitOk = MetricRegistry.root().rate("auth.cacheHitOk");
        this.cacheHitError = MetricRegistry.root().rate("auth.cacheHitError");
        this.cacheMiss = MetricRegistry.root().rate("auth.cacheMiss");
    }

    @Override
    public Optional<AuthToken> getToken(ServerHttpRequest request) {
        return delegate.getToken(request);
    }

    @Override
    public Optional<AuthToken> getToken(Metadata headers, Attributes attributes) {
        return delegate.getToken(headers, attributes);
    }

    @Override
    public CompletableFuture<AuthSubject> authenticate(AuthToken token) {
        AuthSubject subject = positiveCache.getIfPresent(token);
        if (subject != null) {
            cacheHitOk.inc();
            return CompletableFuture.completedFuture(subject);
        }

        Throwable error = negativeCache.getIfPresent(token);
        if (error != null) {
            cacheHitError.inc();
            String msg = "previous authentication error is not yet expired" + (StringUtils.isEmpty(error.getMessage()) ? "" : ": " + error.getMessage());
            return CompletableFuture.failedFuture(new AuthenticationException(msg, error));
        }

        cacheMiss.inc();
        return delegate.authenticate(token)
            .whenComplete((s, t) -> {
                if (t != null) {
                    Throwable cause = CompletableFutures.unwrapCompletionException(t);
                    logger.warn("cannot authenticate {}", token.getType(), cause);
                    if (ignoreForNegativeCache(cause)) {
                        return;
                    }
                    negativeCache.put(token, cause);
                } else {
                    positiveCache.put(token, s);
                }
            });
    }

    private static boolean ignoreForNegativeCache(Throwable cause) {
        if (cause instanceof AuthenticationException) {
            AuthenticationException authenticationException = (AuthenticationException) cause;
            return authenticationException.isNeedToResetSessionId() || !authenticationException.getRedirectTo().isEmpty();
        }
        return false;
    }
}
