package ru.yandex.tours.util

import java.util.concurrent.ConcurrentHashMap

import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.health.HealthCheck
import org.joda.time.{DateTime, Duration}
import ru.yandex.common.monitoring.error._
import ru.yandex.common.monitoring.{CompoundHealthCheckRegistry, HealthChecks, LastErrorChecks, TimeMarker}
import ru.yandex.tours.util.monitoring.StateMonitoring

import scala.concurrent.duration.FiniteDuration
import scala.reflect._

/**
 * Facade for registering health checks and error reservoirs under a common name prefix (`group`).
 *
 * Health checks created through `state` and markers created through the `*WithMarker` methods are
 * cached by full name in the companion object. Asking for the same name again returns the instance
 * created the first time, or throws if a different type was cached under that name.
 *
 * @param registry registry handed to `ErrorReservoirs.register` by `errorCount` / `errorRate`
 *                 and propagated to child instances created by `apply`.
 *                 NOTE(review): `state` and the last-error checks register on the companion's
 *                 `Monitorings.registry`, not on this parameter.
 * @param group    prefix joined with every name via `MetricRegistry.name`
 * @param devOnly  when true, health checks go to the developer checks and error reservoirs are
 *                 registered with `alsoAsOperational = false`
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 25.03.15
 */
class Monitorings(registry: CompoundHealthCheckRegistry, group: String, devOnly: Boolean = false) {

  /** The shared LastErrorChecks instance matching this instance's `devOnly` mode. */
  private def lastErrorChecks = if (devOnly) Monitorings.lastErrorDevChecks else Monitorings.lastErrorChecks
  import LastErrorChecks.{DEFAULT_ERROR_DELAY, DEFAULT_WARNING_DELAY, DEFAULT_ERROR_MAX_SILENCE, DEFAULT_WARNING_MAX_SILENCE}

  /**
   * Returns the health check cached under `name`, creating and registering it on first use.
   *
   * `healthCheck` is evaluated only when nothing is cached yet. Racing threads may each evaluate
   * it, but only the instance that wins `putIfAbsent` is registered and returned to everyone.
   *
   * @throws RuntimeException if an object of a different runtime type is cached under `name`
   */
  private def registerCheck[T <: HealthCheck : ClassTag](name: String, healthCheck: => T): T =
    Option(Monitorings.healthChecks.get(name)) match {
      case Some(existing) => Monitorings.cast[T](name, existing)
      case None =>
        val candidate = healthCheck
        Monitorings.healthChecks.putIfAbsent(name, candidate) match {
          case null =>
            // This thread won the race, so it alone registers the check.
            if (devOnly) Monitorings.registry.registerDeveloper(name, candidate)
            else Monitorings.registry.register(name, candidate)
            candidate
          case winner => Monitorings.cast[T](name, winner)
        }
    }

  /**
   * Returns the object cached under `name`, creating it with `ref` on first use.
   *
   * Unlike `registerCheck`, creation happens while holding the cache's monitor. Once one caller
   * has succeeded, `ref` is therefore not evaluated again for that name. (Presumably the factories
   * have side effects such as registering checks — TODO confirm.)
   *
   * @throws RuntimeException if an object of a different runtime type is cached under `name`
   */
  private def registerRef[T <: AnyRef : ClassTag](name: String, ref: => T): T =
    Option(Monitorings.refs.get(name)) match {
      case Some(existing) => Monitorings.cast[T](name, existing)
      case None =>
        Monitorings.refs.synchronized {
          // Re-check under the lock: another thread may have created it meanwhile.
          Monitorings.refs.get(name) match {
            case null =>
              val created = ref
              Monitorings.refs.put(name, created)
              created
            case existing => Monitorings.cast[T](name, existing)
          }
        }
    }

  /** Full metric name: `group` and `name` joined by `MetricRegistry.name`. */
  private def buildName(name: String): String = MetricRegistry.name(group, name)

  /**
   * Converts a fractional rate (e.g. 0.05) into a whole percentage.
   *
   * It rounds instead of truncating because `rate * 100` is often slightly below the intended
   * integer in binary floating point; for example, `0.29 * 100 == 28.999999999999996`.
   */
  private def toPercent(rate: Double): Int = math.round(rate * 100).toInt

  /** Registers a last-error check for `name`; arguments are forwarded to LastErrorChecks. */
  def lastError(name: String, lastError: () => DateTime,
                errorDelay: Duration = DEFAULT_ERROR_DELAY,
                warningDelay: Duration = DEFAULT_WARNING_DELAY): Unit =
    lastErrorChecks.lastError(buildName(name), lastError, errorDelay, warningDelay)

  /** Returns the cached TimeMarker-backed last-error check for `name`, creating it on first use. */
  def lastErrorWithMarker(name: String,
                          errorDelay: Duration = DEFAULT_ERROR_DELAY,
                          warningDelay: Duration = DEFAULT_WARNING_DELAY): TimeMarker = {
    val fullName = buildName(name)
    registerRef(fullName, lastErrorChecks.lastErrorWithMarker(fullName, errorDelay, warningDelay))
  }

  /** Registers a warning-only last-error check for `name`; arguments are forwarded to LastErrorChecks. */
  def lastErrorAsWarning(name: String, lastError: () => DateTime,
                         warningDelay: Duration = DEFAULT_WARNING_DELAY): Unit =
    lastErrorChecks.lastErrorAsWarning(buildName(name), lastError, warningDelay)

  /** Returns the cached TimeMarker-backed warning-only last-error check for `name`, creating it on first use. */
  def lastErrorAsWarningWithMarker(name: String,
                                   warningDelay: Duration = DEFAULT_WARNING_DELAY): TimeMarker = {
    val fullName = buildName(name)
    registerRef(fullName, lastErrorChecks.lastErrorAsWarningWithMarker(fullName, warningDelay))
  }

  /** Registers a last-event (silence) check for `name`; arguments are forwarded to LastErrorChecks. */
  def lastEvent(name: String,
                lastEvent: () => DateTime,
                warningMaxSilence: Duration = DEFAULT_WARNING_MAX_SILENCE,
                errorMaxSilence: Duration = DEFAULT_ERROR_MAX_SILENCE): Unit =
    lastErrorChecks.lastEvent(buildName(name), lastEvent, warningMaxSilence, errorMaxSilence)

  /** Returns the cached TimeMarker-backed last-event check for `name`, creating it on first use. */
  def lastEventWithMarker(name: String,
                          warningMaxSilence: Duration = DEFAULT_WARNING_MAX_SILENCE,
                          errorMaxSilence: Duration = DEFAULT_ERROR_MAX_SILENCE): TimeMarker = {
    val fullName = buildName(name)
    registerRef(fullName, lastErrorChecks.lastEventWithMarker(fullName, warningMaxSilence, errorMaxSilence))
  }

  /** Registers a warning-only last-event check for `name`; arguments are forwarded to LastErrorChecks. */
  def lastEventAsWarning(name: String,
                         lastEvent: () => DateTime,
                         warningMaxSilence: Duration = DEFAULT_WARNING_MAX_SILENCE): Unit =
    lastErrorChecks.lastEventAsWarning(buildName(name), lastEvent, warningMaxSilence)

  /** Returns the cached TimeMarker-backed warning-only last-event check for `name`, creating it on first use. */
  def lastEventAsWarningWithMarker(name: String,
                                   warningMaxSilence: Duration = DEFAULT_WARNING_MAX_SILENCE): TimeMarker = {
    val fullName = buildName(name)
    registerRef(fullName, lastErrorChecks.lastEventAsWarningWithMarker(fullName, warningMaxSilence))
  }

  /** Registers an AlwaysWarningErrorCounterReservoir(maxErrors, size) under `name`. */
  def errorCount(name: String, maxErrors: Int, size: Int = 1000): ErrorReservoir = {
    ErrorReservoirs.register(buildName(name),
      new AlwaysWarningErrorCounterReservoir(maxErrors, size),
      alsoAsOperational = !devOnly,
      registry
    )
  }

  /** Registers an ExpiringWarningErrorCounterReservoir under `name`; arguments are forwarded unchanged. */
  def errorCount(name: String, warningCount: Int, maxCount: Int, expire: FiniteDuration, size: Int): ErrorReservoir = {
    ErrorReservoirs.register(buildName(name),
      new ExpiringWarningErrorCounterReservoir(warningCount, maxCount, expire, size),
      alsoAsOperational = !devOnly,
      registry
    )
  }

  /**
   * Registers an ExpiringPercentileReservoir under `name`.
   *
   * @param warningRate fractional rate (0.05 = 5%), converted to a rounded whole percentage
   * @param maxRate     fractional rate, converted to a rounded whole percentage
   */
  def errorRate(name: String, warningRate: Double, maxRate: Double, size: Int = 1000, minErrorCount: Int = 1): ErrorReservoir = {
    ErrorReservoirs.register(buildName(name),
      new ExpiringPercentileReservoir(toPercent(warningRate), toPercent(maxRate), windowSize = size, minErrorCount = minErrorCount),
      alsoAsOperational = !devOnly,
      registry
    )
  }

  /**
   * Registers an AlwaysWarningErrorPercentileTimeWindowReservoir under `name`.
   *
   * It honours `devOnly` the same way as the other reservoir methods.
   *
   * @param maxRate fractional rate (0.05 = 5%), converted to a rounded whole percentage
   */
  def errorRate(name: String, maxRate: Double, window: FiniteDuration): ErrorReservoir = {
    ErrorReservoirs.register(buildName(name),
      new AlwaysWarningErrorPercentileTimeWindowReservoir(toPercent(maxRate), window),
      alsoAsOperational = !devOnly,
      registry
    )
  }

  /** Returns the cached StateMonitoring check for `name`, creating and registering it on first use. */
  def state(name: String): StateMonitoring = {
    val fullName = buildName(name)
    registerCheck(fullName, new StateMonitoring(fullName))
  }

  /**
   * Creates a child Monitorings that shares this registry.
   *
   * Its group is this group joined with `name`; an empty `name` keeps the same group.
   *
   * @param devOnly defaults to this instance's mode
   */
  def apply(name: String = "", devOnly: Boolean = devOnly): Monitorings =
    new Monitorings(registry, buildName(name), devOnly)
}

/**
 * Root Monitorings instance (empty group, operational mode).
 *
 * It also holds the process-wide state shared by every Monitorings instance:
 * - the health-check cache and the marker cache;
 * - the LastErrorChecks instances for operational and developer checks.
 */
object Monitorings extends Monitorings(HealthChecks.compoundRegistry(), "", false) {
  // NOTE(review): the superclass constructor receives a separate HealthChecks.compoundRegistry()
  // call; this assumes compoundRegistry() hands out the same shared registry each time — confirm.
  val registry = HealthChecks.compoundRegistry()

  // Health checks created via registerCheck, keyed by full metric name.
  private val healthChecks = new ConcurrentHashMap[String, HealthCheck]()
  // Marker objects created via registerRef, keyed by full metric name.
  private val refs = new ConcurrentHashMap[String, AnyRef]()
  private val lastErrorChecks = new LastErrorChecks(registry)
  private val lastErrorDevChecks = new LastErrorChecks(registry.getDeveloperChecks)

  /**
   * Narrows a cached value (a health check or a marker) to the type the caller asked for.
   *
   * @throws RuntimeException if `hc` is not an instance of `T`, i.e. a different kind of object
   *                          is already registered under `name`
   */
  private def cast[T <: AnyRef : ClassTag](name: String, hc: AnyRef): T = hc match {
    // The ClassTag context bound makes this typed pattern a real runtime class check.
    case expected: T => expected
    case other =>
      throw new RuntimeException(
        s"Object with different type already registered. Name = $name, Class = ${other.getClass}, " +
          s"Expected = ${classTag[T].runtimeClass}")
  }
}