package ru.yandex.tours.app

import java.time.Duration
import java.util
import java.util.Optional
import java.util.concurrent.CompletableFuture

import ch.qos.logback.classic.{Level, Logger => LogbackLogger}
import com.codahale.metrics.{Gauge, Meter, Timer}
import com.typesafe.config.Config
import org.slf4j.LoggerFactory
import ru.yandex.tours.util.{Logging, Metrics, UpdatableLongGauge}
import ru.yandex.travel.yt.queries.QueryPart
import ru.yandex.travel.yt.{Factory, ReplicationInfo, YtDao}
import ru.yandex.tours.util.lang.Futures._
import ru.yandex.travel.yt.daos.{MirroredYtDao, SingleYtDao}

import collection.JavaConverters._
import scala.collection.mutable


/**
 * YT storage settings; `YtSupport` builds an instance from the `tours.yt` config section.
 *
 * @param enabled           value of `tours.yt.enabled`
 * @param factory           YT client factory shared by all DAOs created through `YtSupport`
 * @param primaryCluster    value of `tours.yt.primaryMaster`
 * @param secondaryCluster  value of `tours.yt.secondaryMaster`
 * @param basePath          value of `tours.yt.baseDir`; base directory handed to every DAO
 * @param cacheDirections   value of the optional `tours.yt.cacheDirections` flag (false when absent)
 * @param serpCacheClusters entries of the optional comma-separated `tours.yt.serpCacheClusters`
 *
 * NOTE: `serpCacheClusters` is an `Array`, so the generated `equals`/`hashCode` compare it by
 * reference and `toString` does not print its contents.
 */
case class YtSettings(
  enabled: Boolean,
  factory: Factory,
  primaryCluster: String,
  secondaryCluster: String,
  basePath: String,
  cacheDirections: Boolean,
  serpCacheClusters: Array[String]
)

/**
 * Mixin that reads the `tours.yt` config section into [[YtSettings]] and creates
 * metric-instrumented YT DAOs on top of the resulting [[Factory]].
 *
 * Keys read from `tours.yt`: `enabled`, `primaryMaster`, `secondaryMaster`, `baseDir`, `port`,
 * `user`, `token`, and `clusters`. `clusters` is an object of cluster -> data center ->
 * comma-separated values, presumably host names; confirm against the Factory contract.
 * Optional keys: `cacheDirections`, `serpCacheClusters` (comma-separated), and
 * `localTimeout` / `globalTimeout` / `pingTimeout` / `clientTimeout` (milliseconds).
 */
trait YtSupport {
  this: Application with DefaultEnvironment =>

  import scala.util.control.NonFatal

  // Best effort: raise the threshold of the YT balancing RPC client's logger to WARN.
  // Two cases must not abort initialization, so they are only logged:
  //   - a non-Logback SLF4J binding;
  //   - Logback missing from the classpath, which surfaces as a LinkageError.
  // Other fatal errors (OOM, interruption) are no longer swallowed.
  try {
    LoggerFactory.getLogger("ru.yandex.yt.ytclient.rpc.BalancingRpcClient") match {
      case logbackLogger: LogbackLogger =>
        logbackLogger.setLevel(Level.WARN)
      case other =>
        log.warn(s"Unable to change YT logger level: not a Logback logger (${other.getClass.getName})")
    }
  } catch {
    case e @ (NonFatal(_) | _: LinkageError) => log.warn("Unable to change YT logger level", e)
  }

  private val ytConfig: Config = config.getConfig("tours.yt")

  val ytSettings: YtSettings = YtSettings(
    enabled = ytConfig.getBoolean("enabled"),
    factory = buildFactory(),
    primaryCluster = ytConfig.getString("primaryMaster"),
    secondaryCluster = ytConfig.getString("secondaryMaster"),
    basePath = ytConfig.getString("baseDir"),
    cacheDirections = ytConfig.hasPath("cacheDirections") && ytConfig.getBoolean("cacheDirections"),
    serpCacheClusters =
      if (ytConfig.hasPath("serpCacheClusters")) splitCommaList(ytConfig.getString("serpCacheClusters")).toArray
      else Array.empty[String]
  )

  /**
   * Splits a comma-separated config value, trimming each entry and dropping empty ones.
   * `""` yields no entries, where a plain `split(",")` would yield a single empty string.
   * `"a, b"` yields `a` and `b`.
   */
  private def splitCommaList(raw: String): List[String] =
    raw.split(",").iterator.map(_.trim).filter(_.nonEmpty).toList

  /** Builds the YT client [[Factory]] from `tours.yt.clusters` and the connection settings. */
  private def buildFactory(): Factory = {
    val clusters = ytConfig.getObject("clusters")
    // cluster name -> (data center -> list of values)
    val clusterMap = clusters.keySet.asScala.iterator.map { clusterName =>
      val clusterObj = clusters.toConfig.getObject(clusterName)
      val valuesByDc = clusterObj.keySet.asScala.iterator.map { dc =>
        dc -> splitCommaList(clusterObj.toConfig.getString(dc)).asJava
      }.toMap
      clusterName -> valuesByDc.asJava
    }.toMap

    val factory = new Factory(
      clusterMap.asJava,
      ytConfig.getInt("port"),
      ytConfig.getString("user"),
      ytConfig.getString("token"),
      this.dataCenter)
    // Optional timeouts, configured in milliseconds; a setter is only called when its key is present.
    if (ytConfig.hasPath("localTimeout")) factory.setLocalTimeout(Duration.ofMillis(ytConfig.getLong("localTimeout")))
    if (ytConfig.hasPath("globalTimeout")) factory.setGlobalTimeout(Duration.ofMillis(ytConfig.getLong("globalTimeout")))
    if (ytConfig.hasPath("pingTimeout")) factory.setPingTimeout(Duration.ofMillis(ytConfig.getLong("pingTimeout")))
    if (ytConfig.hasPath("clientTimeout")) factory.setClientTimeout(Duration.ofMillis(ytConfig.getLong("clientTimeout")))
    factory
  }

  /**
   * Creates a [[MirroredYtDao]] over two [[SingleYtDao]]s, one on `primaryCluster` and one on
   * `secondaryCluster`. Both the mirrored DAO and each underlying DAO are metric-instrumented.
   *
   * @param tableName forwarded to [[SingleYtDao]]; `null` by default. Presumably the DAO then
   *                  derives the name itself (TODO confirm).
   * @param ttl       forwarded to [[SingleYtDao]] unchanged; `-1` by default
   */
  def createMirroredYtDao[T](daoClass: Class[T], primaryCluster: String, secondaryCluster: String,
                             tableName: String = null, ttl: Long = -1L): YtDao[T] = {
    val primary = createYtDao(daoClass, primaryCluster, tableName, ttl)
    val secondary = createYtDao(daoClass, secondaryCluster, tableName, ttl)
    new MirroredYtDao[T](primary, secondary) with InstrumentedMirroredYtDao[T]
  }

  /**
   * Creates a metric-instrumented [[SingleYtDao]] on `cluster`. It uses the shared factory and
   * the configured base path.
   *
   * @param tableName forwarded to [[SingleYtDao]]; `null` by default (TODO confirm its meaning)
   * @param ttl       forwarded to [[SingleYtDao]] unchanged; `-1` by default
   */
  def createYtDao[T](daoClass: Class[T], cluster: String, tableName: String = null, ttl: Long = -1L): YtDao[T] =
    new SingleYtDao[T](
      ytSettings.factory,
      cluster,
      ytSettings.basePath,
      daoClass,
      tableName,
      ttl) with InstrumentedSingleYtDao[T]
}

/**
 * Mixin for [[SingleYtDao]] that times every select/put/delete/get call and meters its
 * failures.
 *
 * Metrics are registered under `storage.yt.<cluster>.<table>`: timers `select`, `put`,
 * `putMany`, `delete`, `get`, and meters `<operation>-error`. The varargs overloads share the
 * metrics of their `java.util.List` counterparts.
 */
trait InstrumentedSingleYtDao[T] extends SingleYtDao[T] with Logging {
  private val metrics = Metrics(s"storage.yt.${getClusterName}.${getTableName}")
  private val selectTimer = metrics.getTimer("select")
  private val selectError = metrics.getMeter("select-error")
  private val deleteTimer = metrics.getTimer("delete")
  private val deleteError = metrics.getMeter("delete-error")
  private val putTimer = metrics.getTimer("put")
  private val putError = metrics.getMeter("put-error")
  private val putManyTimer = metrics.getTimer("putMany")
  private val putManyError = metrics.getMeter("putMany-error")
  private val getTimer = metrics.getTimer("get")
  private val getError = metrics.getMeter("get-error")

  /**
   * Starts `timer` and then issues `call`. The timer context and the `errors` meter are
   * attached to the resulting future, which is returned as a [[CompletableFuture]] again.
   */
  private def timed[R](timer: Timer, errors: Meter)(call: => CompletableFuture[R]): CompletableFuture[R] = {
    // Started before `call` (a by-name argument) is evaluated, matching a straight-line call.
    val ctx = timer.time
    call.asScalaFuture
      .withTimerContext(ctx)
      .withErrorRateMeter(errors)
      .asCompletableFuture
  }

  override def select(queryParts: util.List[QueryPart]): CompletableFuture[util.List[T]] =
    timed(selectTimer, selectError)(super.select(queryParts))

  override def select(queryParts: QueryPart*): CompletableFuture[util.List[T]] =
    timed(selectTimer, selectError)(super.select(queryParts.toList.asJava))

  override def put(obj: T): CompletableFuture[Void] =
    timed(putTimer, putError)(super.put(obj))

  override def put(objects: util.List[T]): CompletableFuture[Void] =
    timed(putManyTimer, putManyError)(super.put(objects))

  override def delete(keyValues: util.List[AnyRef]): CompletableFuture[Void] =
    timed(deleteTimer, deleteError)(super.delete(keyValues))

  override def delete(keyValues: AnyRef*): CompletableFuture[Void] =
    timed(deleteTimer, deleteError)(super.delete(keyValues.toList.asJava))

  override def get(keyValues: util.List[AnyRef]): CompletableFuture[Optional[T]] =
    timed(getTimer, getError)(super.get(keyValues))

  override def get(keyValues: AnyRef*): CompletableFuture[Optional[T]] =
    timed(getTimer, getError)(super.get(keyValues.toList.asJava))
}

/**
 * Mixin for [[MirroredYtDao]] that times every select/put/delete/get call and meters its
 * failures.
 *
 * Each call is charged to the replica that `shouldUseSecondary()` selects. Primary metrics live
 * under `storage.yt-mirrored.primary.<table>` and secondary metrics under
 * `storage.yt-mirrored.secondary.<table>`. Timers are named `select`, `put`, `putMany`,
 * `delete`, `get`; meters are named `<operation>-error`.
 */
trait InstrumentedMirroredYtDao[T] extends MirroredYtDao[T] with Logging {
  private val primaryMetrics = Metrics(s"storage.yt-mirrored.primary.${getTableName}")
  private val secondaryMetrics = Metrics(s"storage.yt-mirrored.secondary.${getTableName}")

  /** (timer, error meter) for operation `name` on the primary replica, then on the secondary one. */
  private def getMetricsSet(name: String): ((Timer, Meter), (Timer, Meter)) =
    ((primaryMetrics.getTimer(name), primaryMetrics.getMeter(s"$name-error")),
      (secondaryMetrics.getTimer(name), secondaryMetrics.getMeter(s"$name-error")))

  private val metricsForSelect = getMetricsSet("select")
  private val metricsForDelete = getMetricsSet("delete")
  private val metricsForPut = getMetricsSet("put")
  private val metricsForPutMany = getMetricsSet("putMany")
  private val metricsForGet = getMetricsSet("get")

  // Delegates to the parent unchanged. Presumably kept to re-declare the method with public
  // visibility; confirm against MirroredYtDao before removing.
  override def shouldUseSecondary(): Boolean = super.shouldUseSecondary()

  /**
   * Picks the (timer, error meter) pair of the replica currently selected by
   * `shouldUseSecondary()`, starts the timer, then issues `call`. The returned future is
   * wrapped with timer and error metering.
   */
  private def timedOnActiveReplica[R](metricsSet: ((Timer, Meter), (Timer, Meter)))
                                     (call: => CompletableFuture[R]): CompletableFuture[R] = {
    // Read the routing flag once. Reading it separately for the timer and the meter could
    // charge them to different replicas if the flag flipped in between.
    val (timer, errors) = if (shouldUseSecondary()) metricsSet._2 else metricsSet._1
    val ctx = timer.time
    call.asScalaFuture
      .withTimerContext(ctx)
      .withErrorRateMeter(errors)
      .asCompletableFuture
  }

  override def select(queryParts: util.List[QueryPart]): CompletableFuture[util.List[T]] =
    timedOnActiveReplica(metricsForSelect)(super.select(queryParts))

  override def select(queryParts: QueryPart*): CompletableFuture[util.List[T]] =
    timedOnActiveReplica(metricsForSelect)(super.select(queryParts.toList.asJava))

  override def put(obj: T): CompletableFuture[Void] =
    timedOnActiveReplica(metricsForPut)(super.put(obj))

  override def put(objects: util.List[T]): CompletableFuture[Void] =
    timedOnActiveReplica(metricsForPutMany)(super.put(objects))

  override def delete(keyValues: util.List[AnyRef]): CompletableFuture[Void] =
    timedOnActiveReplica(metricsForDelete)(super.delete(keyValues))

  override def delete(keyValues: AnyRef*): CompletableFuture[Void] =
    timedOnActiveReplica(metricsForDelete)(super.delete(keyValues.toList.asJava))

  override def get(keyValues: util.List[AnyRef]): CompletableFuture[Optional[T]] =
    timedOnActiveReplica(metricsForGet)(super.get(keyValues))

  override def get(keyValues: AnyRef*): CompletableFuture[Optional[T]] =
    timedOnActiveReplica(metricsForGet)(super.get(keyValues.toList.asJava))
}