package ru.yandex.intranet.d.grpc.interceptors

import com.google.rpc.Code
import com.google.rpc.Status
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.protobuf.StatusProto
import net.devh.boot.grpc.common.util.InterceptorOrder
import net.devh.boot.grpc.server.interceptor.GrpcGlobalServerInterceptor
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.context.MessageSource
import org.springframework.core.annotation.Order
import ru.yandex.intranet.d.i18n.Locales
import ru.yandex.intranet.d.util.bucket.RateLimiter
import ru.yandex.intranet.d.util.bucket.RateLimiterBreaker

/**
 * Global gRPC server interceptor that applies the public API rate limiter to every incoming call.
 *
 * A call is passed straight to the next handler when it targets the gRPC health check method or
 * when [RateLimiterBreaker.rateLimiterEnabled] reports that rate limiting is switched off.
 * Otherwise a token is requested from the rate limiter. If none is available, the call is closed
 * immediately with `RESOURCE_EXHAUSTED` and an English message from the message source, and the
 * downstream handler is never started.
 */
@GrpcGlobalServerInterceptor
@Order(InterceptorOrder.ORDER_FIRST + 1)
class RateLimitingInterceptor(
    @Qualifier("publicGrpcApiRateLimiter") private val rateLimiter: RateLimiter,
    @Qualifier("messageSource") private val messages: MessageSource,
    private val rateLimiterBreaker: RateLimiterBreaker
) : ServerInterceptor {

    override fun <ReqT, RespT> interceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>
    ): ServerCall.Listener<ReqT> {
        // Health checks are never throttled; a disabled limiter (breaker off) means no throttling at all.
        // Order matters: the breaker is only consulted for non-health-check calls.
        val bypassLimiter = call.methodDescriptor.fullMethodName == HEALTH_CHECK_METHOD ||
            !rateLimiterBreaker.rateLimiterEnabled()
        if (bypassLimiter || rateLimiter.tryConsume()) {
            return next.startCall(call, headers)
        }
        rejectResourceExhausted(call)
        // The call is already closed, so hand gRPC a no-op listener instead of starting the handler.
        return object : ServerCall.Listener<ReqT>() {}
    }

    /**
     * Closes [call] with a `RESOURCE_EXHAUSTED` status whose message is resolved from the
     * `errors.grpc.code.resource.exhausted` key in English.
     */
    private fun <ReqT, RespT> rejectResourceExhausted(call: ServerCall<ReqT, RespT>) {
        val message = messages.getMessage("errors.grpc.code.resource.exhausted", null, Locales.ENGLISH)
        val status = Status.newBuilder()
            .setCode(Code.RESOURCE_EXHAUSTED.number)
            .setMessage(message)
            .build()
        val error = StatusProto.toStatusRuntimeException(status)
        // The trailers getter may return null; fall back to empty metadata in that case.
        call.close(error.status, error.trailers ?: Metadata())
    }

    companion object {
        /** Fully-qualified name of the gRPC health check method, which is exempt from rate limiting. */
        private const val HEALTH_CHECK_METHOD = "grpc.health.v1.Health/Check"
    }
}
