package ru.yandex.intranet.d.services.integration.solomon

import com.google.protobuf.InvalidProtocolBufferException
import com.google.rpc.BadRequest
import io.grpc.CallOptions
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto
import io.grpc.stub.StreamObserver
import kotlinx.coroutines.reactor.awaitSingle
import mu.KotlinLogging
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Component
import reactor.core.Exceptions
import reactor.core.publisher.Mono
import reactor.util.retry.Retry
import reactor.util.retry.RetrySpec
import ru.yandex.intranet.d.services.integration.providers.ProvidersIntegrationService
import ru.yandex.intranet.d.services.integration.providers.RequestIdSupplier
import ru.yandex.intranet.d.services.integration.providers.grpc.GrpcClient
import ru.yandex.intranet.d.services.integration.providers.grpc.RequestIdHolder
import ru.yandex.intranet.d.services.integration.providers.grpc.RequestIdInterceptor
import ru.yandex.intranet.d.services.integration.providers.grpc.TvmCallCredentials
import ru.yandex.intranet.d.util.AsyncMetrics
import ru.yandex.intranet.d.web.security.tvm.ServiceTicketsCache
import ru.yandex.intranet.d.web.security.tvm.TvmClient
import ru.yandex.monitoring.api.v3.MetricsDataServiceGrpc
import ru.yandex.monitoring.api.v3.ReadMetricsDataRequest
import ru.yandex.monitoring.api.v3.ReadMetricsDataResponse
import ru.yandex.monlib.metrics.histogram.Histograms
import ru.yandex.monlib.metrics.labels.Labels
import ru.yandex.monlib.metrics.registry.MetricRegistry
import java.io.UncheckedIOException
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.function.Consumer

// File-private logger; KotlinLogging derives its name from the enclosing file/class via the empty lambda.
private val logger = KotlinLogging.logger {}

/**
 * Solomon client implementation backed by the gRPC `MetricsDataService`.
 *
 * For every request a TVM service ticket for Solomon is obtained first, then the gRPC call is issued with
 * the ticket, request-id tracking, a gRPC deadline and an inbound message size limit attached to the stub.
 * Each attempt is additionally bounded by a Reactor timeout, transient failures are retried with backoff,
 * and a failed call is returned as a [SolomonResponse] rather than propagated as an exception.
 * Failures to obtain the TVM ticket itself are not converted and are thrown to the caller.
 *
 * @param tvmClient TVM client used by the service ticket cache to issue tickets.
 * @param solomonTvmId TVM id of the Solomon API (destination of the service ticket).
 * @param ownTvmId TVM id of this service (source of the service ticket).
 * @param deadlineAfterMillis gRPC deadline set on each call, in milliseconds.
 * @param responseSizeLimitBytes maximum inbound gRPC message size, in bytes.
 * @param timeoutMillis Reactor timeout applied to each individual attempt, in milliseconds.
 * @param maxAttempts passed to Reactor `RetrySpec.backoff`, which treats it as the maximum number of retries.
 * @param minBackoffMillis minimum backoff between retries, in milliseconds.
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile("dev", "testing", "production")
class SolomonClientImpl(
    private val tvmClient: TvmClient,
    @Value("\${solomon.api.tvmId}") private val solomonTvmId: Long,
    @Value("\${tvm.ownId}") private val ownTvmId: Long,
    @Value("\${solomon.api.client.deadlineAfterMs}") private val deadlineAfterMillis: Long,
    @Value("\${solomon.api.client.responseSizeLimitBytes}") private val responseSizeLimitBytes: Int,
    @Value("\${solomon.api.client.timeoutMs}") private val timeoutMillis: Long,
    @Value("\${solomon.api.client.maxAttempts}") private val maxAttempts: Long,
    @Value("\${solomon.api.client.minBackoffMs}") private val minBackoffMillis: Long
) : SolomonClient {

    // Caches service tickets issued by tvmClient for this service (ownTvmId).
    // NOTE(review): meaning of the numeric cache settings is defined by ServiceTicketsCache — confirm there.
    private val serviceTicketCache: ServiceTicketsCache = ServiceTicketsCache(ownTvmId, Duration.ofHours(1),
        10, 1, Duration.ofSeconds(3)) { source, destinations -> tvmClient.tickets(source, destinations) }

    // Injected by grpc-spring-boot; the RequestIdInterceptor consumes the request id call options set below.
    @net.devh.boot.grpc.client.inject.GrpcClient(value = "solomonClient", interceptors = [RequestIdInterceptor::class])
    private lateinit var metricsDataService: MetricsDataServiceGrpc.MetricsDataServiceStub

    // readData metrics, updated from the callback handed to AsyncMetrics.metric.
    private val readDataRate = MetricRegistry.root().rate("solomon.api.requests.rate",
        Labels.of("solomon_op", "readData", "solomon_result", "any"))
    private val readDataErrorRate = MetricRegistry.root().rate("solomon.api.requests.rate",
        Labels.of("solomon_op", "readData", "solomon_result", "failure"))
    private val readDataDuration = MetricRegistry.root().histogramRate("solomon.api.requests.duration_millis",
        Labels.of("solomon_op", "readData"), Histograms.exponential(22, 2.0, 1.0))

    /**
     * Reads metrics data from Solomon.
     *
     * @param request the read request forwarded as-is to `MetricsDataService.read`.
     * @return a success response with the gRPC result, or an error/failure response describing why it failed.
     */
    override suspend fun readData(request: ReadMetricsDataRequest): SolomonResponse<ReadMetricsDataResponse> {
        val tvmTicket = getTvmTicket()
        return doGrpcCall(tvmTicket,
            { stub, observer: StreamObserver<ReadMetricsDataResponse> -> stub.read(request, observer) },
            { duration, success ->
                readDataRate.inc()
                if (!success) {
                    readDataErrorRate.inc()
                }
                readDataDuration.record(duration)
            })
    }

    /**
     * Executes a single unary gRPC call with timeout, retries and error-to-response conversion.
     *
     * @param tvmTicket service ticket attached as call credentials.
     * @param call issues the actual RPC on the prepared stub, delivering the result to the observer.
     * @param metricConsumer receives the duration and success flag reported by AsyncMetrics.metric.
     * @return the call result wrapped into [SolomonResponse]; gRPC errors become error responses.
     */
    private suspend fun <T> doGrpcCall(
        tvmTicket: String,
        call: (stub: MetricsDataServiceGrpc.MetricsDataServiceStub, observer: StreamObserver<T>) -> Unit,
        metricConsumer: (duration: Long, success: Boolean) -> Unit
    ): SolomonResponse<T> {
        return Mono.fromSupplier { RequestIdSupplier() }
            .flatMap { requestIdSupplier ->
                // One attempt: fresh request id holder per subscription, bounded by the Reactor timeout.
                val attempt = Mono.fromSupplier { RequestIdHolder() }
                    .flatMap { requestIdHolder ->
                        GrpcClient.oneToOne(Mono.just(metricsDataService),
                            { stub, observer: StreamObserver<T> ->
                                call(prepareCallOptions(stub, tvmTicket, requestIdSupplier, requestIdHolder), observer)
                            }, CallOptions.DEFAULT)
                            // Logged inside the retried section, so every failed attempt is logged.
                            .doOnError { error ->
                                val response = toResponse<Unit>(error, requestIdSupplier)
                                logger.error(error) { "Solomon GRPC request error: $response" }
                            }
                            .map { r -> SolomonResponse.success(r, requestIdSupplier.lastId, requestIdHolder.requestId) }
                    }
                    .timeout(Duration.ofMillis(timeoutMillis))
                AsyncMetrics.metric(attempt, metricConsumer)
                    .retryWhen(retryRequest())
                    // Whatever is left after retries is turned into a response instead of an exception.
                    .onErrorResume { error -> Mono.just(toResponse(error, requestIdSupplier)) }
            }.awaitSingle()
    }

    /** Fetches (possibly cached) TVM service ticket for Solomon; failures propagate to the caller. */
    private suspend fun getTvmTicket(): String {
        return serviceTicketCache.getServiceTicket(solomonTvmId).awaitSingle()
    }

    /** Returns a copy of [stub] configured with size limit, deadline, TVM credentials and request id options. */
    private fun prepareCallOptions(stub: MetricsDataServiceGrpc.MetricsDataServiceStub,
                                   tvmTicket: String,
                                   requestIdSupplier: RequestIdSupplier,
                                   requestIdHolder: RequestIdHolder): MetricsDataServiceGrpc.MetricsDataServiceStub {
        return stub
            .withMaxInboundMessageSize(responseSizeLimitBytes)
            .withDeadlineAfter(deadlineAfterMillis, TimeUnit.MILLISECONDS)
            .withCallCredentials(TvmCallCredentials(tvmTicket))
            .withOption(RequestIdInterceptor.REQUEST_ID_KEY, requestIdSupplier)
            .withOption(RequestIdInterceptor.REQUEST_ID_HOLDER_KEY, requestIdHolder)
    }

    /**
     * Backoff retry policy: gRPC errors are retried only for transient status codes (see [isRetryableCode]);
     * any other error (e.g. the per-attempt timeout) is retried unless it is already a retry-exhausted error.
     */
    private fun retryRequest(): Retry {
        return RetrySpec.backoff(maxAttempts, Duration.ofMillis(minBackoffMillis)).filter { e ->
            when (e) {
                is StatusRuntimeException -> isRetryableCode(e.status.code)
                is StatusException -> isRetryableCode(e.status.code)
                else -> !Exceptions.isRetryExhausted(e)
            }
        }
    }

    /** Whether a gRPC status code is considered transient and worth retrying. */
    private fun isRetryableCode(statusCode: Status.Code): Boolean = when (statusCode) {
        Status.Code.INTERNAL,
        Status.Code.RESOURCE_EXHAUSTED,
        Status.Code.ABORTED,
        Status.Code.UNAVAILABLE,
        Status.Code.DEADLINE_EXCEEDED,
        Status.Code.UNKNOWN -> true
        else -> false
    }

    /**
     * Converts an arbitrary error into a [SolomonResponse].
     *
     * gRPC status exceptions (either the error itself or its direct cause) become error responses; a
     * retry-exhausted wrapper is unwrapped to its cause; everything else becomes a failure response.
     */
    private fun <T> toResponse(ex: Throwable, requestIdSupplier: RequestIdSupplier): SolomonResponse<T> {
        val outgoingRequestId = requestIdSupplier.lastId
        val cause = ex.cause
        return when {
            ex is StatusRuntimeException -> toResponse(ex, outgoingRequestId)
            ex is StatusException -> toResponse(ex, outgoingRequestId)
            cause is StatusRuntimeException -> toResponse(cause, outgoingRequestId)
            cause is StatusException -> toResponse(cause, outgoingRequestId)
            cause != null && Exceptions.isRetryExhausted(ex) -> SolomonResponse.failure(cause, outgoingRequestId)
            else -> SolomonResponse.failure(ex, outgoingRequestId)
        }
    }

    /** Error response for a [StatusRuntimeException], including the incoming request id from trailers. */
    private fun <T> toResponse(ex: StatusRuntimeException, outgoingRequestId: String?): SolomonResponse<T> =
        toResponse(ex.status, StatusProto.fromThrowable(ex), outgoingRequestId,
            ex.trailers?.get(ProvidersIntegrationService.REQUEST_ID_KEY))

    /** Error response for a [StatusException], including the incoming request id from trailers. */
    private fun <T> toResponse(ex: StatusException, outgoingRequestId: String?): SolomonResponse<T> =
        toResponse(ex.status, StatusProto.fromThrowable(ex), outgoingRequestId,
            ex.trailers?.get(ProvidersIntegrationService.REQUEST_ID_KEY))

    /**
     * Builds an error response from a gRPC [status] and optional rich status details.
     *
     * Field violations of all `BadRequest` details are collected into a field -> description map (a later
     * violation for the same field overwrites an earlier one); the map is `null` when there is no `BadRequest`.
     */
    private fun <T> toResponse(status: Status, statusProto: com.google.rpc.Status?,
                               outgoingRequestId: String?, incomingRequestId: String?): SolomonResponse<T> {
        val badRequests = statusProto?.detailsList.orEmpty()
            .filter { detail -> detail.`is`(BadRequest::class.java) }
            .map { detail -> unpackBadRequest(detail) }
        val badRequestDetails = if (badRequests.isEmpty()) {
            null
        } else {
            badRequests
                .flatMap { badRequest -> badRequest.fieldViolationsList }
                .associateTo(mutableMapOf<String, String>()) { violation -> violation.field to violation.description }
        }
        return SolomonResponse.error(status.code, status.description, badRequestDetails,
            outgoingRequestId, incomingRequestId)
    }

    /** Unpacks a `BadRequest` detail, rethrowing protobuf parse errors as [UncheckedIOException]. */
    private fun unpackBadRequest(any: com.google.protobuf.Any): BadRequest = try {
        any.unpack(BadRequest::class.java)
    } catch (e: InvalidProtocolBufferException) {
        throw UncheckedIOException(e)
    }

}
