package ru.yandex.tours.backend.search

import ru.yandex.tours.model.Source
import ru.yandex.tours.model.search.SearchResults.SearchProgress
import ru.yandex.tours.partners.PartnerProtocol._
import ru.yandex.tours.util.collections.SimpleBitSet

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * Accumulates search results per source and compares them with the number of
  * responses each source is expected to deliver.
  *
  * A source is finished once the number of its non-`Partial` results equals the
  * count registered through [[updateWaitMap]]; the whole search is finished when
  * every source in `allSources` is finished.
  *
  * NOTE(review): state is kept in plain mutable maps with no synchronization in this
  * class; presumably instances are confined to a single thread/actor -- confirm.
  *
  * @param allSources every source that takes part in the search
  * @tparam I concrete search result type
  */
class SearcherProgress[I <: SearchResult[_]](allSources: Set[_ <: Source]) {
  // Mapping from source to number of expected requests
  private val source2wait = mutable.HashMap.empty[Source, Int]

  // Results received so far, grouped by source.
  // Intentionally created WITHOUT `withDefaultValue(mutable.Buffer.empty)`: that default
  // is a single buffer instance shared by every missing key, so any in-place append via
  // `source2results(source)` would leak results between sources. Missing keys are
  // handled explicitly at the read sites instead.
  private val source2results = mutable.HashMap.empty[Source, mutable.Buffer[I]]

  /** Records one more result under the source it came from. */
  def add(searchResult: I): Unit = {
    val results = source2results.getOrElseUpdate(searchResult.searchSource, mutable.Buffer.empty[I])
    results.append(searchResult)
  }

  /** True when every source from `allSources` is finished (see [[isSourceFinished]]). */
  def isFinished: Boolean = allSources.forall(isSourceFinished)

  /** All results added so far, across all sources. */
  def getResults: Iterable[I] = source2results.flatMap(_._2)

  /**
    * Builds a `SearchProgress` message from the current state.
    *
    * The "complete" figures cover every finished source that has results, which
    * includes the skipped and the failed ones.
    */
  def currentProgress: SearchProgress = {
    val (done, skipped, failed) = collectStatistic
    SearchProgress.newBuilder()
      .setOperatorTotalCount(allSources.size)
      .setOperatorCompleteCount(done.size)
      .setOperatorFailedCount(failed.size)
      .setOperatorSkippedCount(skipped.size)
      .setIsFinished(isFinished)
      .addAllOBSOLETEFailedOperators(asJavaIterable(failed.map(_.id).map(Int.box)))
      .setOperatorCompleteSet(SimpleBitSet(done.map(_.id)).packed)
      .setOperatorFailedSet(SimpleBitSet(failed.map(_.id)).packed)
      .setOperatorSkippedSet(SimpleBitSet(skipped.map(_.id)).packed)
      .build()
  }

  /**
    * Increases the number of expected responses per source.
    *
    * Counts are added to whatever was registered before; they do not replace it.
    *
    * @param map additional expected response count for each source
    */
  def updateWaitMap(map: Map[Source, Int]): Unit = {
    map.foreach {
      case (source, count) => source2wait.update(source, source2wait.getOrElse(source, 0) + count)
    }
  }

  // (done, skipped, failed) -- in that order.
  private def collectStatistic = (collectDoneSources, collectSkippedSources, collectFailedSources)

  // Finished sources all of whose results satisfy `fits`.
  // Only sources that have at least one added result are considered.
  private def collectStatisticForSource(fits: Results[_] => Boolean): Iterable[Source] = {
    source2results.filter {
      case (source, results) => isSourceFinished(source) && results.map(_.result).forall(fits)
    }.keys
  }

  /**
    * Tells whether `source` has delivered everything it was expected to.
    *
    * `Partial` results are not counted. The remaining count must match the registered
    * expectation exactly; a source with no registered expectation is never finished.
    */
  def isSourceFinished(source: Source): Boolean = {
    // A source without any result yet simply has zero completed responses.
    val done = source2results.get(source).fold(0) { results =>
      results.map(_.result).count {
        case Partial(_) => false
        case _ => true
      }
    }
    source2wait.get(source).contains(done)
  }

  /** Sources for which at least one result has been added. */
  def getSourcesWithResult: collection.Set[Source] = source2results.keySet

  // Finished sources whose results are all `Failed`.
  private def collectFailedSources = collectStatisticForSource({
    case Failed(_) => true
    case _ => false
  })

  // Finished sources whose results are all `Skipped`.
  private def collectSkippedSources = collectStatisticForSource({
    case Skipped => true
    case _ => false
  })

  // Every finished source that has results, whatever those results are.
  private def collectDoneSources = collectStatisticForSource(_ => true)

  /** Multi-line dump of the current state, intended for logs and debugging. */
  override def toString: String = {
    val sb = new StringBuilder("SearcherProgress:\n")
    val finished = isFinished
    sb append "  isFinished = " append finished append "\n"
    if (!finished) {
      allSources.filterNot(isSourceFinished).addString(sb, "  waitingFor = ", ", ", "\n")
      source2wait.addString(sb, "  waitMap = ", ", ", "\n")
    }
    collectDoneSources.addString(sb, "  done = ", ", ", "\n")
    collectSkippedSources.addString(sb, "  skipped = ", ", ", "\n")
    collectFailedSources.addString(sb, "  failed = ", ", ", "\n")
    sb.toString()
  }
}
