package ru.yandex.intranet.d.services.integration.jns

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.ObjectReader
import io.netty.channel.ChannelOption
import io.netty.handler.timeout.ReadTimeoutHandler
import io.netty.handler.timeout.WriteTimeoutHandler
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.awaitSingleOrNull
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Profile
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.reactive.function.client.ExchangeFilterFunctions
import org.springframework.web.reactive.function.client.WebClient
import org.springframework.web.reactive.function.client.WebClientResponseException
import org.springframework.web.reactive.function.client.createExceptionAndAwait
import org.springframework.web.util.UriComponentsBuilder
import reactor.core.Exceptions
import reactor.core.publisher.Mono
import reactor.netty.http.client.HttpClient
import reactor.util.retry.Retry
import reactor.util.retry.RetrySpec
import ru.yandex.intranet.d.kotlin.mono
import ru.yandex.intranet.d.util.ResolverHolder
import ru.yandex.intranet.d.util.http.YaReactorClientHttpConnector
import ru.yandex.intranet.d.web.security.tvm.ServiceTicketsCache
import ru.yandex.intranet.d.web.security.tvm.TvmClient
import java.io.IOException
import java.time.Duration
import java.util.*
import java.util.concurrent.TimeUnit

/**
 * Jns client implementation.
 *
 * Sends messages to the `api/messages/send_to_channel_json` endpoint of [jnsHost], authenticating with a TVM
 * service ticket issued for [jnsTvmId]. Outcomes are reported as [JnsResult] values rather than exceptions:
 * - HTTP 200 and 409 become [JnsResult.success]. 409 presumably means JNS already accepted this `requestId`;
 *   the same id is reused by every retry of one [send] call. TODO confirm against the JNS API docs.
 * - Any other HTTP status becomes [JnsResult.error], carrying the JSON ([JnsResponse]) or plain-text error
 *   body when one is present.
 * - Transport errors and timeouts become [JnsResult.failure].
 *
 * Each attempt is bounded by `notifications.jns.client.timeoutMs`. HTTP 429, 5xx and non-HTTP errors are
 * retried with exponential backoff (`maxAttempts`, `minBackoffMs`).
 *
 * Errors raised while obtaining the TVM ticket are not converted into a [JnsResult]; they propagate from [send].
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile("dev", "testing", "production")
class JnsClientImpl(
    private val tvmClient: TvmClient,
    private val objectMapper: ObjectMapper,
    @Value("\${notifications.jns.tvmId}") private val jnsTvmId: Long,
    @Value("\${tvm.ownId}") private val ownTvmId: Long,
    @Value("\${notifications.jns.host}") private val jnsHost: String,
    @Value("\${notifications.jns.client.connectTimeoutMs}") connectTimeoutMillis: Int,
    @Value("\${notifications.jns.client.readTimeoutMs}") readTimeoutMillis: Long,
    @Value("\${notifications.jns.client.writeTimeoutMs}") writeTimeoutMillis: Long,
    @Value("\${notifications.jns.client.timeoutMs}") private val timeoutMillis: Long,
    @Value("\${notifications.jns.client.maxAttempts}") private val maxAttempts: Long,
    @Value("\${notifications.jns.client.minBackoffMs}") private val minBackoffMillis: Long,
    @Value("\${http.client.userAgent}") private val userAgent: String
): JnsClient {

    private val serviceTicketCache: ServiceTicketsCache = ServiceTicketsCache(ownTvmId, Duration.ofHours(1),
        10, 1, Duration.ofSeconds(3)) { source, destinations -> tvmClient.tickets(source, destinations) }
    private val errorReader: ObjectReader = objectMapper.readerFor(JnsResponse::class.java)

    // The endpoint depends only on configuration, so it is built once instead of on every send.
    private val sendUri: String = UriComponentsBuilder.fromHttpUrl(jnsHost)
        .pathSegment("api", "messages", "send_to_channel_json")
        .toUriString()

    private val webClient: WebClient = WebClient.builder()
        .clientConnector(YaReactorClientHttpConnector(
            HttpClient.create()
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis)
                .resolver(ResolverHolder.RESOLVER_INSTANCE)
                .doOnConnected { connection ->
                    connection.addHandlerLast(ReadTimeoutHandler(readTimeoutMillis, TimeUnit.MILLISECONDS))
                    connection.addHandlerLast(WriteTimeoutHandler(writeTimeoutMillis, TimeUnit.MILLISECONDS))
                }
        ))
        .filter(ExchangeFilterFunctions.limitResponseSize(MAX_RESPONSE_SIZE_BYTES))
        .codecs { configurer -> configurer.defaultCodecs().maxInMemorySize(Math.toIntExact(MAX_RESPONSE_SIZE_BYTES)) }
        .build()

    override suspend fun send(message: JnsMessage): JnsResult {
        // Fetched outside the reactive chain below, so a ticket error is thrown rather than mapped to a JnsResult.
        val tvmTicket = getTvmTicket()
        // Generated once per send, so every retry of this message carries the same request id.
        val requestId = UUID.randomUUID().toString()
        return webClient.post()
            .uri(sendUri)
            .header("X-Ya-Service-Ticket", tvmTicket)
            .accept(MediaType.APPLICATION_JSON)
            .header(HttpHeaders.USER_AGENT, userAgent)
            .contentType(MediaType.APPLICATION_JSON)
            .bodyValue(JnsRequest(
                project = message.project,
                template = message.template,
                targetProject = message.targetProject,
                channel = message.channel,
                requestId = requestId,
                params = message.parameters
            ))
            .exchangeToMono { response ->
                mono {
                    val status = response.rawStatusCode()
                    if (status == HttpStatus.OK.value() || status == HttpStatus.CONFLICT.value()) {
                        // The success body is not used; release it explicitly.
                        response.releaseBody().awaitSingleOrNull()
                        return@mono JnsResult.success()
                    } else {
                        throw response.createExceptionAndAwait()
                    }
                }
            }
            // Placed before retryWhen, so the timeout bounds each attempt rather than the whole retry sequence.
            .timeout(Duration.ofMillis(timeoutMillis))
            .retryWhen(retryRequest())
            .onErrorResume { e -> Mono.just(toResult(e)) }
            .awaitSingle()
    }

    /**
     * Converts an error that escaped the request chain into a [JnsResult].
     *
     * When retries are exhausted, Reactor wraps the last failure; that cause is unwrapped here, so the result
     * describes the real error rather than the retry wrapper.
     */
    private fun toResult(error: Throwable): JnsResult {
        val cause = error.cause
        return when {
            error is WebClientResponseException -> toErrorResult(error)
            cause is WebClientResponseException -> toErrorResult(cause)
            Exceptions.isRetryExhausted(error) && cause != null -> JnsResult.failure(cause)
            else -> JnsResult.failure(error)
        }
    }

    /**
     * Builds a [JnsResult.error] from a non-success HTTP response.
     *
     * The body is attached as a parsed [JnsResponse] when it is JSON, or as text when it is `text/plain`.
     * Other content types carry no body.
     */
    private fun toErrorResult(error: WebClientResponseException): JnsResult {
        val contentType = error.headers.contentType
        val (objectBody, textBody) = when {
            contentType?.equalsTypeAndSubtype(MediaType.APPLICATION_JSON) == true -> parseJsonErrorBody(error)
            contentType?.equalsTypeAndSubtype(MediaType.TEXT_PLAIN) == true -> Pair(null, error.responseBodyAsString)
            else -> Pair(null, null)
        }
        return JnsResult.error(error.rawStatusCode, objectBody, textBody)
    }

    /**
     * Parses a JSON error body into a [JnsResponse].
     *
     * The raw bytes are handed to Jackson, which detects the Unicode encoding itself. Decoding through
     * `responseBodyAsString` would use Spring's fallback charset when the response declares none, which can
     * garble non-ASCII text.
     *
     * An empty body yields no body. A body that is not valid [JnsResponse] JSON is returned as raw text. Neither
     * case throws, because a throw here would escape `onErrorResume` and make [send] throw instead of returning
     * a result.
     */
    private fun parseJsonErrorBody(error: WebClientResponseException): Pair<JnsResponse?, String?> {
        val bytes = error.responseBodyAsByteArray
        if (bytes.isEmpty()) {
            return Pair(null, null)
        }
        return try {
            Pair(errorReader.readValue<JnsResponse>(bytes), null)
        } catch (e: IOException) {
            Pair(null, error.responseBodyAsString)
        }
    }

    /**
     * Retry policy for a single [send].
     *
     * Uses Reactor's exponential backoff starting at `minBackoffMs`, with at most `maxAttempts` retries.
     * HTTP 429 and 5xx responses are retried; any other HTTP status is returned immediately. Non-HTTP errors
     * (for example timeouts or connection failures) are retried unless they already signal retry exhaustion.
     */
    private fun retryRequest(): Retry =
        RetrySpec.backoff(maxAttempts, Duration.ofMillis(minBackoffMillis)).filter { e: Throwable ->
            if (e is WebClientResponseException) {
                e.rawStatusCode == HttpStatus.TOO_MANY_REQUESTS.value() || e.rawStatusCode in 500..599
            } else {
                !Exceptions.isRetryExhausted(e)
            }
        }

    /** Returns a (cached) TVM service ticket for calling JNS. */
    private suspend fun getTvmTicket(): String {
        return serviceTicketCache.getServiceTicket(jnsTvmId).awaitSingle()
    }

    private companion object {
        /** 10 MiB limit, applied both by the response-size filter and as the codecs' in-memory buffer limit. */
        const val MAX_RESPONSE_SIZE_BYTES: Long = 10L * 1024 * 1024
    }

}
