package ru.yandex.intranet.d.web.util

import org.springframework.beans.factory.ObjectProvider
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.context.MessageSource
import org.springframework.core.Ordered
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.http.codec.HttpMessageWriter
import org.springframework.http.codec.ServerCodecConfigurer
import org.springframework.stereotype.Component
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.reactive.function.server.ServerResponse
import org.springframework.web.reactive.result.view.ViewResolver
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono
import ru.yandex.intranet.d.i18n.Locales
import ru.yandex.intranet.d.util.bucket.RateLimiter
import ru.yandex.intranet.d.util.bucket.RateLimiterBreaker
import java.util.stream.Collectors

/**
 * Reactive [WebFilter] that applies request rate limiting before the rest of the filter chain runs.
 *
 * Limiter selection:
 * - Requests whose path starts with `/front` are charged against `frontApiRateLimiter`.
 * - Every other request is charged against `publicApiRateLimiter`.
 *
 * Exemptions:
 * - The health-check and metrics endpoints listed in [pathsToExclude] (and their sub-paths) are never limited.
 * - The filter is bypassed entirely while [RateLimiterBreaker.rateLimiterEnabled] returns `false`.
 *
 * When the selected limiter rejects a request, the chain is short-circuited. The filter then writes an
 * HTTP 429 JSON response with `status`, `error` and a localized `message` attribute.
 */
@Component
class RateLimitingFilter(
    @Qualifier("publicRestApiRateLimiter") private val publicApiRateLimiter: RateLimiter,
    @Qualifier("frontApiRateLimiter") private val frontApiRateLimiter: RateLimiter,
    @Qualifier("messageSource") private val messages: MessageSource,
    viewResolvers: ObjectProvider<ViewResolver>,
    serverCodecConfigurer: ServerCodecConfigurer,
    private val rateLimiterBreaker: RateLimiterBreaker
): WebFilter, Ordered {
    private val messageWriters: List<HttpMessageWriter<*>> = serverCodecConfigurer.writers
    private val viewResolvers: List<ViewResolver> = viewResolvers.orderedStream().collect(Collectors.toList())
    private val pathsToExclude: Set<String> = setOf("/ping", "/local/liveness", "/local/readiness", "/sensors/metrics")

    /**
     * Lets the request through when it is exempt or a permit is available.
     *
     * Otherwise it answers with HTTP 429 without invoking the rest of [chain].
     */
    override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
        val path = exchange.request.path.value()
        if (pathsToExclude.any { isUnderPath(path, it) }) {
            // Exclude ping, readiness, liveness and metrics endpoints from rate limiting.
            // Segment-aware matching so that e.g. "/pingxyz" is NOT treated as "/ping".
            return chain.filter(exchange)
        }
        if (!rateLimiterBreaker.rateLimiterEnabled()) {
            // Skip rate limiter if disabled
            return chain.filter(exchange)
        }
        // Separate limiter for front API; public API limiter for all the rest endpoints
        val limiter = if (path.startsWith("/front")) frontApiRateLimiter else publicApiRateLimiter
        return if (limiter.tryConsume()) {
            chain.filter(exchange)
        } else {
            writeTooManyRequests(exchange)
        }
    }

    override fun getOrder(): Int {
        return Ordered.HIGHEST_PRECEDENCE + 1
    }

    /**
     * Writes an HTTP 429 JSON error body with `status`, `error` and a localized `message` into [exchange].
     */
    private fun writeTooManyRequests(exchange: ServerWebExchange): Mono<Void> {
        val errorStatus = HttpStatus.TOO_MANY_REQUESTS
        // Fall back to English when the exchange carries no locale
        val locale = exchange.localeContext.locale ?: Locales.ENGLISH
        val errorAttributes = linkedMapOf<String, Any>(
            "status" to errorStatus.value(),
            "error" to errorStatus.reasonPhrase,
            "message" to messages.getMessage("errors.too.many.requests", null, locale)
        )
        return ServerResponse.status(errorStatus).contentType(MediaType.APPLICATION_JSON)
            .body(BodyInserters.fromValue<Map<String, Any>>(errorAttributes)).flatMap { response ->
                exchange.response.headers.contentType = response.headers().contentType
                response.writeTo(exchange, ResponseContext(messageWriters, viewResolvers))
            }
    }

    /**
     * Minimal [ServerResponse.Context] that exposes the application's message writers and view resolvers.
     *
     * This allows a [ServerResponse] to be written directly from inside a filter.
     */
    private class ResponseContext(
        private val messageWriters: List<HttpMessageWriter<*>>,
        private val viewResolvers: List<ViewResolver>
    ): ServerResponse.Context {
        override fun messageWriters(): List<HttpMessageWriter<*>> {
            return messageWriters
        }
        override fun viewResolvers(): List<ViewResolver> {
            return viewResolvers
        }
    }

    private companion object {
        /**
         * Returns `true` when [path] equals [prefix] or is a sub-path of it (the next character is `/`).
         *
         * A bare string prefix is not enough: `/pingxyz` is not under `/ping`.
         */
        fun isUnderPath(path: String, prefix: String): Boolean =
            path.startsWith(prefix) && (path.length == prefix.length || path[prefix.length] == '/')
    }

}
