package ru.yandex.intranet.d.services.usage

import ru.yandex.intranet.d.model.units.UnitModel
import ru.yandex.intranet.d.model.usage.HistogramBin
import ru.yandex.intranet.d.util.units.Units
import java.math.BigDecimal
import java.math.BigInteger
import java.math.MathContext
import java.math.RoundingMode
import kotlin.math.ceil
import kotlin.math.log2

/**
 * Converts every sample of a time series from [sourceUnit] to [destinationUnit] and rounds it to an integer
 * (e.g. from Solomon units to integer base units).
 *
 * Keys are kept untouched; each value is converted with [Units.convert] and then rounded to zero
 * decimal places using [RoundingMode.HALF_UP].
 *
 * @param source time series to convert
 * @param sourceUnit unit in which the values of [source] are expressed
 * @param destinationUnit unit in which the resulting values are expressed
 * @return a new map with the same keys and integer values expressed in [destinationUnit]
 */
fun convertTimeSeries(source: Map<Long, BigDecimal>,
                      sourceUnit: UnitModel, destinationUnit: UnitModel): Map<Long, BigInteger> =
    source.mapValues { (_, amount) ->
        val converted = Units.convert(amount, sourceUnit, destinationUnit)
        converted.setScale(0, RoundingMode.HALF_UP).toBigInteger()
    }

/**
 * Sums any number of time series point-wise.
 *
 * A key present in several series gets the sum of all its values; a key present in only one series
 * keeps that value unchanged.
 *
 * @param timeSeries series to sum
 * @return an empty map for no input, the very same map instance for a single input,
 * otherwise a new map holding the point-wise sums
 */
fun sumTimeSeries(timeSeries: Collection<Map<Long, BigInteger>>): Map<Long, BigInteger> =
    when (timeSeries.size) {
        0 -> mapOf()
        1 -> timeSeries.first()
        else -> {
            val totals = LinkedHashMap<Long, BigInteger>()
            for (series in timeSeries) {
                for ((timestamp, sample) in series) {
                    // Either start a new running total or extend the existing one.
                    totals[timestamp] = totals[timestamp]?.add(sample) ?: sample
                }
            }
            totals
        }
    }

/**
 * Sums two time series point-wise.
 *
 * Keys present in both series get the sum of both values; keys present in only one series keep
 * their value unchanged. Neither input is modified.
 *
 * @param left first series
 * @param right second series
 * @return an empty map when both inputs are empty, the non-empty input itself when only one is empty,
 * otherwise a new map holding the point-wise sums
 */
fun sumTimeSeries(left: Map<Long, BigInteger>, right: Map<Long, BigInteger>): Map<Long, BigInteger> {
    when {
        left.isEmpty() && right.isEmpty() -> return mapOf()
        left.isEmpty() -> return right
        right.isEmpty() -> return left
    }
    // Start from a copy of the left series (keeps its iteration order), then fold the right one in.
    val combined = LinkedHashMap<Long, BigInteger>(left)
    right.forEach { (timestamp, sample) ->
        combined[timestamp] = combined[timestamp]?.add(sample) ?: sample
    }
    return combined
}

/**
 * Adds [value] into [accumulator] point-wise, in place.
 *
 * For every key of [value], the accumulator entry is either created with that value or increased by it.
 *
 * @param accumulator mutable series that receives the sums
 * @param value series to add; it is not modified
 */
fun accumulateTimeSeries(accumulator: MutableMap<Long, BigInteger>, value: Map<Long, BigInteger>) {
    for ((timestamp, sample) in value) {
        val previous = accumulator[timestamp]
        accumulator[timestamp] = if (previous == null) sample else previous.add(sample)
    }
}

/**
 * Rounds [value] to the nearest integer, with ties rounded away from zero ([RoundingMode.HALF_UP]).
 *
 * @param value decimal to round
 * @return the rounded integer, e.g. 2.5 -> 3 and -2.5 -> -3
 */
fun roundToIntegerHalfUp(value: BigDecimal): BigInteger =
    // After setScale(0) there is no fractional part left, so the exact conversion never throws.
    value.setScale(0, RoundingMode.HALF_UP).toBigIntegerExact()

/**
 * Arithmetic (sample) mean of [values].
 *
 * @param values samples
 * @return the sum divided by the count with 34 fractional digits ([RoundingMode.HALF_UP]),
 * or [BigDecimal.ZERO] when [values] is empty
 */
fun mean(values: Collection<BigInteger>): BigDecimal {
    if (values.isEmpty()) return BigDecimal.ZERO
    val total = values.fold(BigInteger.ZERO) { acc, item -> acc.add(item) }
    val count = BigDecimal.valueOf(values.size.toLong())
    return BigDecimal(total).divide(count, 34, RoundingMode.HALF_UP)
}

/**
 * Unbiased sample variance of [values] around a precomputed [mean].
 *
 * Sums squared deviations from [mean] and divides by `n - 1` (Bessel's correction).
 *
 * @param values samples
 * @param mean mean of [values], as produced by e.g. the `mean` function
 * @return the variance with 34 fractional digits ([RoundingMode.HALF_UP]),
 * or [BigDecimal.ZERO] when there are fewer than two values
 */
fun variance(values: Collection<BigInteger>, mean: BigDecimal): BigDecimal {
    if (values.size < 2) return BigDecimal.ZERO
    var sumOfSquares = BigDecimal.ZERO
    for (v in values) {
        val deviation = BigDecimal(v) - mean
        sumOfSquares += deviation * deviation
    }
    val degreesOfFreedom = BigDecimal.valueOf((values.size - 1).toLong())
    return sumOfSquares.divide(degreesOfFreedom, 34, RoundingMode.HALF_UP)
}

/**
 * Standard deviation derived from an integer [variance], rounded to an integer.
 *
 * The square root is taken with 34 significant digits ([RoundingMode.HALF_UP]) and then rounded
 * half-up to zero decimal places.
 *
 * @param variance variance value; a negative value makes [BigDecimal.sqrt] throw [ArithmeticException]
 * @return the rounded square root of [variance]
 */
fun standardDeviation(variance: BigInteger): BigInteger {
    val precision = MathContext(34, RoundingMode.HALF_UP)
    val root = BigDecimal(variance).sqrt(precision)
    return roundToIntegerHalfUp(root)
}

/**
 * Minimum, median and maximum of [values].
 *
 * For an odd count the median is the middle element (scale 0); for an even count it is the mean of
 * the two middle elements with 34 fractional digits ([RoundingMode.HALF_UP]).
 *
 * @param values samples, in any order
 * @return `Triple(min, median, max)`; all zeros when [values] is empty
 */
fun minMedianMax(values: Collection<BigInteger>): Triple<BigInteger, BigDecimal, BigInteger> {
    when (values.size) {
        0 -> return Triple(BigInteger.ZERO, BigDecimal.ZERO, BigInteger.ZERO)
        1 -> {
            val single = values.first()
            return Triple(single, single.toBigDecimal(), single)
        }
    }
    val ordered = values.sorted()
    val half = ordered.size / 2
    val median = if (ordered.size % 2 == 1) {
        ordered[half].toBigDecimal()
    } else {
        (ordered[half - 1] + ordered[half]).toBigDecimal()
            .divide(BigDecimal.valueOf(2L), 34, RoundingMode.HALF_UP)
    }
    return Triple(ordered.first(), median, ordered.last())
}

/**
 * Builds a histogram of [values] over the range `[min, max]`.
 *
 * - no values: empty list;
 * - a single value `v`: one bin `[v, v]` with count 1;
 * - `min == max`: one bin `[min, max]` holding all values;
 * - otherwise the bin count is `ceil(log2(n)) + 1` (Sturges' rule), capped at `max - min + 1`.
 *   The bin width is `ceil((max - min) / binCount)`.
 *
 * Each value goes to bin `(value - min) / width` (integer division). A value whose index equals the
 * bin count, i.e. `max` when the range is an exact multiple of the width, is put into the last bin.
 * Empty bins are omitted, and bins are returned sorted by their upper bound.
 *
 * NOTE(review): assumes every value lies within `[min, max]` (e.g. min/max from `minMedianMax`);
 * values outside that range would yield bins outside it — confirm with callers.
 *
 * @param values samples
 * @param min smallest sample
 * @param max largest sample
 * @return non-empty bins ordered by upper bound
 */
fun histogram(values: Collection<BigInteger>, min: BigInteger, max: BigInteger): List<HistogramBin> {
    when {
        values.isEmpty() -> return listOf()
        values.size == 1 -> {
            val only = values.first()
            return listOf(HistogramBin(only, only, 1L))
        }
        min.compareTo(max) == 0 -> return listOf(HistogramBin(min, max, values.size.toLong()))
    }
    val range = max.subtract(min)
    val distinctIntegers = range.plus(BigInteger.ONE)
    val sturgesCount = ceil(log2(values.size.toDouble())).toLong() + 1L
    // Never use more bins than there are distinct integers in the range.
    val binsCount = if (distinctIntegers < BigInteger.valueOf(sturgesCount)) {
        distinctIntegers.toLong()
    } else {
        sturgesCount
    }
    val binWidth = range.toBigDecimal()
        .divide(BigDecimal.valueOf(binsCount), 0, RoundingMode.UP)
        .toBigInteger()
    val lastBin = binsCount - 1
    val counts = values.groupingBy { sample ->
        val rawIndex = sample.subtract(min).divide(binWidth).toLong()
        if (rawIndex == binsCount) lastBin else rawIndex
    }.eachCount()
    return counts.map { (binIndex, count) ->
        val lower = min.add(BigInteger.valueOf(binIndex).multiply(binWidth))
        val upper = min.add(BigInteger.valueOf(binIndex + 1L).multiply(binWidth))
        HistogramBin(lower, upper, count.toLong())
    }.sortedBy { it.to }
}

/**
 * Relative usage in percent, with two decimal places.
 *
 * Computes `accumulated * 100 / (provision * duration)`, rounded [RoundingMode.HALF_UP] to 2 digits.
 *
 * @param accumulated accumulated usage over [duration]
 * @param provision provided amount
 * @param duration length of the accumulation period
 * @return `null` when [provision] is zero or [accumulated] is negative; `0.0` when [accumulated]
 * or [duration] is zero; otherwise the percentage
 */
fun relativeUsage(accumulated: BigInteger, provision: BigInteger, duration: Long): Double? {
    when {
        provision.signum() == 0 || accumulated.signum() < 0 -> return null
        accumulated.signum() == 0 || duration == 0L -> return 0.0
    }
    val capacity = BigDecimal(provision * BigInteger.valueOf(duration))
    val scaledUsage = BigDecimal(accumulated * BigInteger.valueOf(100))
    return scaledUsage.divide(capacity, 2, RoundingMode.HALF_UP).toDouble()
}

/**
 * Average unused part of [provision] over [duration].
 *
 * Computes `(provision * duration - accumulated) / duration`, rounded [RoundingMode.HALF_UP] to an
 * integer. The result is negative when [accumulated] exceeds `provision * duration`.
 *
 * @param accumulated accumulated usage over [duration]
 * @param provision provided amount
 * @param duration length of the accumulation period
 * @return the average under-utilization, or [provision] itself when [duration] is zero
 */
fun underutilized(accumulated: BigInteger, provision: BigInteger, duration: Long): BigInteger =
    if (duration == 0L) {
        provision
    } else {
        val totalDuration = BigInteger.valueOf(duration)
        val unused = provision * totalDuration - accumulated
        BigDecimal(unused).divide(BigDecimal(totalDuration), 0, RoundingMode.HALF_UP).toBigInteger()
    }

/**
 * Coefficient of variation in percent, with two decimal places.
 *
 * Computes `standardDeviation * 100 / mean`, rounded [RoundingMode.HALF_UP] to 2 digits.
 *
 * @param standardDeviation standard deviation of the samples
 * @param mean mean of the samples
 * @return the coefficient, or `null` when [mean] is zero or negative
 */
fun variationCoefficient(standardDeviation: BigInteger, mean: BigInteger): Double? {
    if (mean.signum() <= 0) return null
    val scaledDeviation = BigDecimal(standardDeviation * BigInteger.valueOf(100))
    return scaledDeviation.divide(BigDecimal(mean), 2, RoundingMode.HALF_UP).toDouble()
}

/**
 * Numerical integration of the time series, time is in seconds.
 *
 * The series is split into contiguous segments: a new segment starts whenever the gap between two
 * consecutive timestamps is at least `2 * gridStep`, i.e. at least one grid sample is missing.
 * Each segment is integrated separately by `accumulateSubArray`, and the results are summed.
 * Time is integrated the same piecewise way, so gaps do not count towards the duration.
 *
 * @param timeSeries samples keyed by timestamp
 * @param gridStep expected distance between consecutive samples
 * @return `Pair(integral, coveredDuration)`; `(0, 0)` for an empty series and
 * `(value * gridStep, gridStep)` for a single sample
 */
fun accumulate(timeSeries: Map<Long, BigInteger>, gridStep: Long): Pair<BigDecimal, Long> {
    if (timeSeries.isEmpty()) {
        return Pair(BigDecimal.ZERO, 0L)
    }
    if (timeSeries.size == 1) {
        val single = timeSeries.values.first()
        return Pair(single.multiply(BigInteger.valueOf(gridStep)).toBigDecimal(), gridStep)
    }
    val timestamps = timeSeries.keys.sorted()
    var total = BigDecimal.ZERO
    var totalDuration = 0L
    var segmentStart = 0
    // `idx` is the exclusive end of a candidate segment; the last index closes the final segment.
    for (idx in 1..timestamps.size) {
        val segmentEnds = idx == timestamps.size || timestamps[idx] >= timestamps[idx - 1] + gridStep * 2
        if (segmentEnds) {
            val (segmentIntegral, segmentDuration) =
                accumulateSubArray(timeSeries, timestamps.subList(segmentStart, idx), gridStep)
            total = total.add(segmentIntegral)
            totalDuration += segmentDuration
            segmentStart = idx
        }
    }
    return Pair(total, totalDuration)
}

/**
 * Integrates one contiguous segment of a time series with the trapezoidal rule.
 *
 * Each pair of neighbouring timestamps contributes `(v[i] + v[i - 1]) * dx / 2`, computed with
 * 34 fractional digits, and `dx` to the duration. A one-point segment contributes `value * gridStep`
 * with duration `gridStep`; an empty segment contributes nothing.
 *
 * @param timeSeries samples keyed by timestamp; must contain every entry of [timestamps]
 * @param timestamps sorted timestamps of the segment
 * @param gridStep duration assigned to a one-point segment
 * @return `Pair(integral, duration)` of the segment
 * @throws NoSuchElementException if a timestamp in [timestamps] is missing from [timeSeries]
 */
private fun accumulateSubArray(timeSeries: Map<Long, BigInteger>, timestamps: List<Long>,
                               gridStep: Long): Pair<BigDecimal, Long> {
    if (timestamps.isEmpty()) {
        return Pair(BigDecimal.ZERO, 0L)
    }
    if (timestamps.size == 1) {
        val single = timeSeries.getValue(timestamps[0])
        return Pair(single.multiply(BigInteger.valueOf(gridStep)).toBigDecimal(), gridStep)
    }
    var accumulated = BigDecimal.ZERO
    var accumulatedDuration = 0L
    for ((previous, current) in timestamps.zipWithNext()) {
        val dx = current - previous
        accumulatedDuration += dx
        // Twice the trapezoid area; halved below at a fixed scale.
        val doubledArea = timeSeries.getValue(current).add(timeSeries.getValue(previous))
            .multiply(BigInteger.valueOf(dx))
        accumulated = accumulated.add(
            doubledArea.toBigDecimal().divide(BigDecimal.valueOf(2), 34, RoundingMode.HALF_UP)
        )
    }
    return Pair(accumulated, accumulatedDuration)
}
