package ru.yandex.direct.jobs.monitoring.system.source.tracelog

import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.collect.ImmutableList
import org.apache.commons.collections4.map.LinkedMap
import org.slf4j.LoggerFactory
import ru.yandex.direct.binlogbroker.logbroker_utils.reader.impl.LogbrokerBatchReaderImpl
import ru.yandex.direct.common.db.PpcPropertiesSupport
import ru.yandex.direct.common.db.PpcPropertyNames
import ru.yandex.direct.jobs.monitoring.system.processor.REAL_TIME_SERVICES
import ru.yandex.direct.tracing.data.TraceData
import ru.yandex.kikimr.persqueue.consumer.SyncConsumer
import ru.yandex.kikimr.persqueue.consumer.transport.message.inbound.data.MessageBatch
import java.time.Duration
import java.time.LocalDateTime
import java.time.ZoneOffset
import java.util.ArrayList
import java.util.Base64
import java.util.function.BiPredicate
import java.util.function.Supplier
import java.util.zip.ZipException

// Chunk-count cap for a single span: when a merged chunk has an index above this value,
// the accumulated trace log is treated as final even if the chunk is not flagged as the last one.
private const val MAX_TRACE_LOG_PARTS = 20

// Services whose multi-chunk (not yet complete) trace logs are collected and merged.
// For example, the spanId of ess-router changes only with a new release.
// Add other services only after verifying them.
private val ALLOWED_NOT_FULL_TRACE_LOG_SERVICES = ImmutableList.builder<String>()
    .addAll(REAL_TIME_SERVICES)
    .build()

// A trace is accepted only when its method name consists entirely of these characters.
private val ALLOWED_METHOD_CHARS_REGEXP = Regex("[A-Za-z_./]+")

// Default for the "not full" map size property: the maximum number of incomplete trace logs kept.
private const val DEFAULT_NOT_FULL_MAP_SIZE = 10_000
// Default for the message limit property: the row count used by the batching threshold.
private const val DEFAULT_TRACE_LOG_LIMIT = 3_000

/**
 * Logbroker batch reader that turns trace-log messages into [TraceLogMonitoringData] rows.
 *
 * Every message is split into lines, each non-empty line is parsed as a JSON trace.
 * For services listed in [ALLOWED_NOT_FULL_TRACE_LOG_SERVICES] traces that come in several chunks
 * are accumulated by spanId in [notFullTraceLog] and emitted only after the last chunk was merged;
 * rows of all other services are passed through as is.
 *
 * The instance keeps mutable state ([notFullTraceLog]) between [batchDeserialize] calls.
 */
class TraceLogLogbrokerReader(
    logbrokerConsumerSupplier: Supplier<SyncConsumer>,
    logbrokerNoCommit: Boolean,
    ppcPropertiesSupport: PpcPropertiesSupport,
) : LogbrokerBatchReaderImpl<TraceLogMonitoringData>(logbrokerConsumerSupplier, logbrokerNoCommit) {

    // Storage for incomplete trace logs, keyed by spanId; insertion order is used for eviction.
    private var notFullTraceLog: LinkedMap<Long, TraceLogMonitoringData> = LinkedMap()

    private val traceLogNotFullMapSizeProperty = ppcPropertiesSupport
        .get(PpcPropertyNames.TRACE_LOG_LOGBROKER_MONITORING_NOT_FULL_SIZE, Duration.ofSeconds(10))
    private val messageLimitProperty = ppcPropertiesSupport
        .get(PpcPropertyNames.TRACE_LOG_LOGBROKER_MONITORING_LIMIT, Duration.ofSeconds(10))

    companion object {
        private val logger = LoggerFactory.getLogger(TraceLogLogbrokerReader::class.java)
    }

    /**
     * Parses all messages of [messageBatch] and returns the rows that are ready to be processed.
     *
     * A message that fails to decompress because of a [ZipException] is logged (with its raw data
     * in Base64) and skipped; any other decompression failure is rethrown. A message whose parsing
     * throws a [RuntimeException] is logged and skipped as a whole.
     *
     * @return complete trace-log rows sorted by log time, see [filterAndSaveNotFullTraceLogs]
     */
    override fun batchDeserialize(messageBatch: MessageBatch): List<TraceLogMonitoringData> {
        val data = mutableListOf<TraceLogMonitoringData>()
        for (messageData in messageBatch.messageData) {
            val message = try {
                String(messageData.decompressedData)
            } catch (e: RuntimeException) {
                if (e.cause is ZipException) {
                    val encodedRawData = Base64.getEncoder().encodeToString(messageData.rawData)
                    logger.error("Couldn't decompress message from " +
                        "topic:${messageBatch.topic}, partition:${messageBatch.partition}, " +
                        "messageData.offset:${messageData.offset}, messageData.rawData.size:${messageData.rawData.size}, " +
                        "encodeBase64(messageData.rawData):$encodedRawData, skipping", e)
                    continue
                }
                throw e
            }
            try {
                // Rows are collected into a list first, so a failure in the middle of a message
                // drops the whole message instead of adding a part of it.
                val traceLogRows = message.split('\n').asSequence()
                    .filter { it.isNotEmpty() }
                    .mapNotNull { mapToTraceLogMonitoringData(it) }
                    .toList()
                data.addAll(traceLogRows)
            } catch (e: RuntimeException) {
                logger.error("wrong message $message")
                logger.error("Failed to parse message", e)
            }
        }
        return filterAndSaveNotFullTraceLogs(data)
    }

    /** Returns the number of rows in [e]. */
    override fun count(e: MutableList<TraceLogMonitoringData>): Int {
        return e.size
    }

    /**
     * Builds a predicate that is true while the number of rows is below the configured message limit
     * (or [DEFAULT_TRACE_LOG_LIMIT]) and the elapsed time is below `iterationTime`.
     */
    override fun batchingThreshold(): BiPredicate<Long, Duration> {
        val messageLimit = messageLimitProperty
            .getOrDefault(DEFAULT_TRACE_LOG_LIMIT)
        return BiPredicate { rows: Long, time: Duration -> rows < messageLimit && time < iterationTime }
    }

    /** Checks that [method] consists only of characters allowed by [ALLOWED_METHOD_CHARS_REGEXP]. */
    private fun isValidMethod(method: String): Boolean = ALLOWED_METHOD_CHARS_REGEXP.matches(method)

    /**
     * Converts one JSON trace [row] into [TraceLogMonitoringData].
     *
     * @return `null` when the row cannot be parsed or its method name is not valid
     */
    private fun mapToTraceLogMonitoringData(row: String): TraceLogMonitoringData? {
        val trace = parseTrace(row)
        if (trace == null || !isValidMethod(trace.method)) {
            return null
        }
        val functionsProfile = trace.profiles
            .map { profile ->
                TraceLogFunctionRow(name = profile.func,
                    tags = profile.tags, ela = profile.allEla,
                    calls = profile.calls.toInt(), objectNum = profile.objCount.toInt())
            }
        // The log time is stored as a UTC local date-time.
        val dateTime = LocalDateTime.ofInstant(trace.logTime, ZoneOffset.UTC)
        val cpu = trace.times.cpuSystemTime + trace.times.cpuUserTime

        return TraceLogMonitoringData(dateTime,
            service = trace.service, method = trace.method, chunkIndex = trace.chunkIndex,
            isLastChunk = trace.isChunkFinal, traceId = trace.traceId, spanId = trace.spanId,
            host = trace.host, timeSpent = trace.allEla, cpu = cpu,
            memory = trace.times.mem, functionsProfile = functionsProfile)
    }

    /** Parses [row] as a JSON trace; logs the row and returns `null` on a JSON processing error. */
    private fun parseTrace(row: String): TraceData? {
        return try {
            TraceData.fromJson(row)
        } catch (ex: JsonProcessingException) {
            logger.error("wrong row $row")
            logger.error("Failed to parse message", ex)
            null
        }
    }

    /**
     * Separates complete rows from chunks of multi-chunk traces.
     *
     * Rows of services outside [ALLOWED_NOT_FULL_TRACE_LOG_SERVICES] and single-chunk rows go straight
     * to the result. Other chunks are merged into [notFullTraceLog] by spanId; entries that became
     * complete are moved to the result, the rest stay in the map for the next batches. When the map
     * grows above the configured size, the oldest entries are dropped.
     *
     * @return complete rows sorted by log time
     */
    private fun filterAndSaveNotFullTraceLogs(dataList: List<TraceLogMonitoringData>): List<TraceLogMonitoringData> {
        val notFullMapSize = traceLogNotFullMapSizeProperty
            .getOrDefault(DEFAULT_NOT_FULL_MAP_SIZE)

        val resultList = mutableListOf<TraceLogMonitoringData>()
        for (row in dataList) {
            if (!ALLOWED_NOT_FULL_TRACE_LOG_SERVICES.contains(row.service)) {
                resultList.add(row)
            } else if (row.chunkIndex == 1 && row.isLastChunk) {
                // The whole trace fits into a single chunk, nothing to merge.
                resultList.add(row)
            } else if (row.chunkIndex >= 1) {
                notFullTraceLog.compute(row.spanId) { _, oldRow ->
                    if (oldRow == null) row else mergeTraceLogs(oldRow, row)
                }
            }
            // Rows with a chunk index below 1 are ignored.
        }

        // logTime is a UTC date-time, so the current time has to be taken in UTC as well.
        val nowUtc = LocalDateTime.now(ZoneOffset.UTC)
        notFullTraceLog.forEach { (_, log) ->
            if (log.isLastChunk) {
                resultList.add(log)
                logger.info("slow trace log in ${log.service} method ${log.method} spanId ${log.spanId}" +
                    " with spend time ${log.timeSpent}")
            } else if (nowUtc > log.logTime.plusHours(4)) {
                // NOTE(review): such entries are only reported here; they leave the map
                // only through the size-based eviction below.
                logger.info("slow trace log not found last chunk $log")
            }
        }
        // Keep only the entries that still wait for their last chunk.
        notFullTraceLog = notFullTraceLog
            .filter { (_, log) -> !log.isLastChunk }
            .toMap(LinkedMap())

        if (notFullTraceLog.size > notFullMapSize) {
            logger.error("cannot save more in notFullTraceLogs, limit exceeded")
            // The emptiness check protects from an exception in firstKey() when the limit is negative.
            while (notFullTraceLog.size > notFullMapSize && notFullTraceLog.isNotEmpty()) {
                notFullTraceLog.remove(notFullTraceLog.firstKey())?.let { deletedElement ->
                    logger.info("deleting first element in not full trace logs, ${deletedElement.spanId} " +
                        "chunk index ${deletedElement.chunkIndex}")
                }
            }
        }

        return resultList.sortedBy { r -> r.logTime }
    }

    /**
     * Merges the next chunk [second] into the accumulated trace log [first].
     *
     * A chunk whose index is not greater than the accumulated one is skipped and [first] is returned.
     * The result is marked as the last chunk when [second] is the last one or its index exceeds
     * [MAX_TRACE_LOG_PARTS].
     */
    private fun mergeTraceLogs(first: TraceLogMonitoringData, second: TraceLogMonitoringData): TraceLogMonitoringData {
        if (first.chunkIndex >= second.chunkIndex) {
            logger.info("skip row $second index in first ${first.chunkIndex} more than second")
            return first
        }
        return first.copy(
            logTime = second.logTime,
            chunkIndex = second.chunkIndex,
            isLastChunk = second.isLastChunk || second.chunkIndex > MAX_TRACE_LOG_PARTS,
            // The time in subsequent chunks comes as a total for the whole spanId,
            // while the other fields have to be summed up.
            timeSpent = second.timeSpent,
            cpu = first.cpu + second.cpu,
            memory = first.memory + second.memory,
            functionsProfile = mergeFunctions(first.functionsProfile, second.functionsProfile)
        )
    }

    /**
     * Merges function profiles of two chunks by the (name, tags) key.
     *
     * Rows of [first] keep their order and are summed with the matching row of [second];
     * rows that are present only in [second] are appended to the end, so functions that
     * first appear in a later chunk are not lost.
     */
    private fun mergeFunctions(first: List<TraceLogFunctionRow>,
                               second: List<TraceLogFunctionRow>): List<TraceLogFunctionRow> {
        val secondByKey = second.associateBy { it.name to it.tags }
        val mergedFirst = first.map { row ->
            secondByKey[row.name to row.tags]?.let { mergeFunction(row, it) } ?: row
        }
        val firstKeys = first.map { it.name to it.tags }.toSet()
        val onlyInSecond = secondByKey.filterKeys { it !in firstKeys }.values
        return mergedFirst + onlyInSecond
    }

    /** Sums the counters of two profile rows; name and tags are taken from [first]. */
    private fun mergeFunction(first: TraceLogFunctionRow, second: TraceLogFunctionRow): TraceLogFunctionRow {
        return TraceLogFunctionRow(name = first.name, tags = first.tags, ela = first.ela + second.ela,
            calls = first.calls + second.calls, objectNum = first.objectNum + second.objectNum)
    }
}
