package ru.yandex.mail.micronaut.tvm.auth;

import io.micronaut.http.HttpRequest;
import io.micronaut.security.rules.SecurityRule;
import io.micronaut.security.rules.SecurityRuleResult;
import io.micronaut.web.router.MethodBasedRouteMatch;
import io.micronaut.web.router.RouteMatch;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import one.util.streamex.EntryStream;
import one.util.streamex.StreamEx;
import ru.yandex.mail.micronaut.tvm.TvmConfiguration;
import ru.yandex.mail.micronaut.tvm.client.TvmTicket.ServiceTvmTicket;

import javax.annotation.Nullable;
import javax.inject.Singleton;
import java.util.Map;
import java.util.Set;

/**
 * Micronaut {@link SecurityRule} that enforces TVM credentials on routes whose handler method is
 * annotated with {@link TvmSecured}.
 *
 * <p>For a {@link TvmSecured} route the request is allowed only when:
 * <ul>
 *     <li>the claims contain a {@link ServiceTvmTicket} under
 *     {@link TvmAuthentication#SERVICE_TICKET_PROPERTY} whose TVM id is one of the values of
 *     {@link TvmConfiguration#getAllowedClients()}, and</li>
 *     <li>if the handler declares a {@link TvmUid} argument, the claims also contain a non-null value
 *     under {@link TvmAuthentication#USER_TICKET_PROPERTY}.</li>
 * </ul>
 * Otherwise it is rejected. Routes that are not method based yield {@link SecurityRuleResult#UNKNOWN};
 * method routes without {@link TvmSecured} yield {@link SecurityRuleResult#ALLOWED}.
 */
@Slf4j
@Singleton
public class TvmAnnotationRule implements SecurityRule {
    /** TVM ids permitted to call {@link TvmSecured} routes (the values of the configured allowed-clients map). */
    private final Set<Long> allowedClientsIds;

    /**
     * @param tvmConfiguration configuration whose {@code allowedClients} map values are the permitted TVM ids
     */
    public TvmAnnotationRule(TvmConfiguration tvmConfiguration) {
        // Only the map values (the ids) take part in authorization; the keys are not used here.
        allowedClientsIds = EntryStream.of(tvmConfiguration.getAllowedClients())
            .values()
            .toImmutableSet();
    }

    /**
     * Looks up a claim and returns it only if it has the expected type.
     *
     * <p>The type is checked explicitly rather than cast blindly, so a malformed claim is treated as absent
     * instead of surfacing as a {@link ClassCastException} from inside the security rule.
     *
     * @param claims authentication claims
     * @param name   claim key
     * @param type   expected claim type
     * @return the claim value, or {@code null} if it is missing or not an instance of {@code type}
     */
    private static <T> T getClaim(Map<String, Object> claims, String name, Class<T> type) {
        final Object value = claims.get(name);
        return type.isInstance(value) ? type.cast(value) : null;
    }

    /**
     * @return {@code true} if the route's handler declares an argument of type {@link TvmUid}, in which case
     *         a user ticket is required in addition to the service ticket
     */
    private static boolean isUserTicketRequired(MethodBasedRouteMatch methodRoute) {
        return StreamEx.of(methodRoute.getArguments())
            .anyMatch(arg -> arg.getType() == TvmUid.class);
    }

    /**
     * @return {@code true} if a non-null value is stored under {@link TvmAuthentication#USER_TICKET_PROPERTY}
     */
    private static boolean isUserTicketPresent(Map<String, Object> claims) {
        return claims.get(TvmAuthentication.USER_TICKET_PROPERTY) != null;
    }

    /**
     * Decides access for {@link TvmSecured} routes based on the TVM tickets found in {@code claims}.
     *
     * @param request    incoming request (used for its URI in log messages)
     * @param routeMatch matched route; anything other than a {@link MethodBasedRouteMatch} yields UNKNOWN
     * @param claims     authentication claims; {@code null} means no credentials were provided
     * @return UNKNOWN for non-method routes, ALLOWED for routes without {@link TvmSecured} or with valid
     *         TVM credentials, REJECTED otherwise
     */
    @Override
    public SecurityRuleResult check(HttpRequest request, @Nullable RouteMatch routeMatch,
                                    @Nullable Map<String, Object> claims) {
        if (!(routeMatch instanceof MethodBasedRouteMatch)) {
            log.warn("Can't check TVM credentials for '{}' request", request.getUri());
            return SecurityRuleResult.UNKNOWN;
        }

        val methodRoute = (MethodBasedRouteMatch) routeMatch;
        if (methodRoute.getAnnotation(TvmSecured.class) == null) {
            // NOTE(review): returning ALLOWED (rather than UNKNOWN) decides the request here, so security rules
            // evaluated after this one are not consulted for routes without @TvmSecured. Confirm this is intended.
            return SecurityRuleResult.ALLOWED;
        }

        if (claims == null) {
            log.info("TVM credentials not found for '{}' request", request.getUri());
            return SecurityRuleResult.REJECTED;
        }

        final ServiceTvmTicket serviceTicket =
            getClaim(claims, TvmAuthentication.SERVICE_TICKET_PROPERTY, ServiceTvmTicket.class);
        if (serviceTicket == null) {
            log.info("TVM service ticket not found for '{}' request", request.getUri());
            return SecurityRuleResult.REJECTED;
        }

        val tvmId = serviceTicket.getTvmId();
        if (!allowedClientsIds.contains(tvmId)) {
            log.info("Service tvm id '{}' is forbidden for '{}' request", tvmId, request.getUri());
            return SecurityRuleResult.REJECTED;
        }

        // A handler that takes a TvmUid argument needs the end-user identity, hence a user ticket.
        if (isUserTicketRequired(methodRoute) && !isUserTicketPresent(claims)) {
            log.info("Service tvm id '{}' is forbidden due to lack of user ticket for '{}' request", tvmId, request.getUri());
            return SecurityRuleResult.REJECTED;
        }

        return SecurityRuleResult.ALLOWED;
    }
}
