package ru.yandex.tours.indexer.clusterization

import akka.actor.ActorSystem
import ru.yandex.tours.clustering.Clustering.LinkWithConfidence
import ru.yandex.tours.db._
import ru.yandex.tours.db.dao.HotelsDao
import ru.yandex.tours.db.dao.HotelsDao.{SkipDeleted, WithIds}
import ru.yandex.tours.db.tables.{ClusterLink, Clusterization, HotelCoordinates, Hotels}
import ru.yandex.tours.geo
import ru.yandex.tours.hotels.HotelsIndex
import ru.yandex.tours.hotels.clustering.{ClusteringContext, ClusteringModel, HotelContext, LocalContext}
import ru.yandex.tours.indexer.task.{AsyncUpdatable, TaskWeight}
import ru.yandex.tours.model.BaseModel.Point
import ru.yandex.tours.model.hotels.HotelsHolder.PartnerHotel
import ru.yandex.tours.model.hotels.Partners.Partner
import ru.yandex.tours.util.Collections._
import ru.yandex.tours.util.collections.RafBasedMap
import ru.yandex.tours.util.concurrent.{AsyncWorkQueue, BatchExecutor}
import ru.yandex.tours.util.lang.Futures._
import slick.driver.MySQLDriver.api._

import scala.collection.parallel.ParSeq
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Random}

/**
 * Periodic task (registered under the name "cluster_hotels") that links new hotels to hotels nearby.
 *
 * Each run loads a lightweight [[NewHotelClusterizer#HotelRef]] per hotel, buckets the refs by grid
 * point and, for every bucket that contains new hotels, scores each new hotel against the hotels
 * found in `point.nearPoints` with `clusteringModel`. The resulting links are passed through
 * [[HotelLinkCleaner]], sufficiently confident links are stored and the eligible hotels are
 * published through `hotelsDao`.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 24.05.16
 *
 * @param dbWrapper       database access used for reading hotel refs / old links and inserting new links
 * @param hotelsDao       source of full hotel records and the sink for `publish`
 * @param clusteringModel model returning a confidence for a pair of hotel contexts
 * @param updateTime      forwarded to `AsyncUpdatable` (presumably the update period — confirm there)
 */
class NewHotelClusterizer(dbWrapper: DBWrapper,
                          hotelsDao: HotelsDao,
                          clusteringModel: ClusteringModel,
                          updateTime: FiniteDuration)
                         (implicit akka: ActorSystem, ec: ExecutionContext)
  extends AsyncUpdatable(updateTime, "cluster_hotels") with TaskWeight.Unique {

  /** When defined, only new hotels of this partner are clustered; `None` applies no restriction. */
  protected val partnerFilter: Option[Partner] = None

  /** Maximum distance, in kilometres, between two hotels for the pair to be scored at all. */
  private val MaxDistance = 3d
  /** Batch size handed to `BatchExecutor.executeInBatch` for link inserts and hotel publishing. */
  private val BatchSize = 1000
  /** Minimum model confidence for a pair to be emitted as a link by `generateLinks`. */
  private val ConfidenceThreshold = 0.5
  /** Argument of `AsyncWorkQueue` (presumably the number of grid points processed concurrently). */
  private val Parallelism = 4
  /** Above this number of ids the hotels are loaded as a RAF-based map instead of an in-memory one. */
  private val RafMapThreshold = 1000
  /** Number of candidates whose contexts are built and scored at once. */
  private val CandidateBatchSize = 1000
  /** Number of new hotels whose contexts are built and scored at once. */
  private val NewHotelBatchSize = 100

  /**
   * Minimal hotel description needed to decide which pairs are worth scoring.
   *
   * @param id      hotel id
   * @param isNew   value of the hotel's `isNew` column
   * @param partner partner taken from the stable-ids table
   * @param point   hotel coordinates
   */
  case class HotelRef(id: Int, isNew: Boolean, partner: Partner, point: Point) {
    /** Grid point this hotel falls into, as computed by `Grid.getGeoPoint`. */
    def gridPoint: GridPoint = Grid.getGeoPoint(point)
    /** Distance to `ref` in kilometres. */
    def distanceTo(ref: HotelRef): Double = geo.distanceInKm(point, ref.point)
  }

  /**
   * Loads a [[HotelRef]] for every hotel present in the hotels, coordinates and stable-ids tables
   * (inner joins on the hotel id) and drops the refs whose point is considered empty by
   * `HotelsIndex.isEmptyPoint`.
   */
  private def retrieveData(): Future[Seq[HotelRef]] = {
    val query = Hotels.table
      .join(HotelCoordinates.table).on(_.id === _.id)
      .join(HotelStableIds.table).on(_._1.id === _.id)
      .map { case ((hotel, coord), stableIds) ⇒ (hotel.id, hotel.isNew, stableIds.partnerObj, coord.point) }

    dbWrapper.run(query.result.transactionally)
      .map { rows ⇒
        rows
          .map((HotelRef.apply _).tupled)
          .filterNot(ref ⇒ HotelsIndex.isEmptyPoint(ref.point))
      }
      .logTiming("HotelRef retrieving")
  }


  /**
   * Builds a [[HotelContext]] for every ref that has a record in `hotelsMap`; refs without a record
   * are silently skipped. Lookups happen sequentially, context construction runs in parallel.
   *
   * @return contexts keyed by the id reported by the hotel record (`hotel.getId`)
   */
  protected def buildHotelContexts(hotelsToBuild: Seq[HotelRef],
                                   localContext: LocalContext,
                                   hotelsMap: collection.Map[Int, PartnerHotel]): Map[Int, HotelContext] = {
    hotelsToBuild
      .flatMap(ref ⇒ hotelsMap.get(ref.id))
      .par
      .map(hotel ⇒ hotel.getId → HotelContext(hotel, localContext))
      .seq
      .toMap
  }

  /**
   * Scores every (hotel to cluster, candidate) pair that consists of two different hotels lying
   * within `MaxDistance` km of each other and having a context, and keeps the pairs whose
   * confidence reaches `ConfidenceThreshold`.
   *
   * Each scored pair is counted through `metrics.addPair()`. The link always goes from the smaller
   * id to the larger one.
   *
   * NOTE(review): when both hotels of a pair are present in `hotelsToCluster` and in `candidates`,
   * the pair appears to be scored twice (once in each order), yielding two links with the same
   * endpoints — confirm that the link cleaner / storage tolerates that.
   */
  protected def generateLinks(hotelsToCluster: Seq[HotelRef],
                              candidates: Seq[HotelRef],
                              contexts: Map[Int, HotelContext],
                              metrics: ClusteringMetrics): ParSeq[LinkWithConfidence] = {
    for {
      ref1 ← hotelsToCluster.par
      ref2 ← candidates.par
      if ref1.id != ref2.id
      if ref2.distanceTo(ref1) <= MaxDistance
      ctx1 ← contexts.get(ref1.id)
      ctx2 ← contexts.get(ref2.id)
      conf = {
        // Counted here so that only pairs that are actually passed to the model are reported.
        metrics.addPair()
        clusteringModel.apply(ClusteringContext(ctx1, ctx2))
      }
      if conf >= ConfidenceThreshold
    } yield {
      LinkWithConfidence(ref1.id min ref2.id, ref1.id max ref2.id, conf)
    }
  }

  /**
   * Decides whether hotel `id` may be published.
   *
   * @param linkConf confidences of all links the hotel takes part in, keyed by hotel id
   * @return `true` when the hotel has no links at all, or at least one of its links reaches
   *         `Clusterization.defaultMinConfidence`
   */
  protected def shouldPublish(id: Int, linkConf: Map[Int, List[Double]]): Boolean = {
    val confs = linkConf.getOrElse(id, Seq.empty)
    confs.isEmpty || confs.exists(_ >= Clusterization.defaultMinConfidence)
  }

  /**
   * Inserts the links whose confidence reaches `Clusterization.defaultMinConfidence` (tagged with
   * `transaction.id`) and then publishes the hotels accepted by [[shouldPublish]].
   *
   * Publishing starts only after all link batches have completed. Link and hotel counts are
   * reported to `metrics` before any database work begins.
   */
  protected def saveLinksAndPublish(hotelsToPublish: Seq[HotelRef],
                                    links: Iterable[LinkWithConfidence],
                                    transaction: Transaction,
                                    metrics: ClusteringMetrics): Future[Unit] = {
    val linksToAdd = links
      .filter(_.confidence >= Clusterization.defaultMinConfidence)
      .map { l ⇒ ClusterLink(0, l.parent, l.child, transaction.id, l.confidence) }
    // Every link contributes its confidence to both of its endpoints, including links that are
    // too weak to be stored: such links are what keeps a hotel from being published.
    val linkConf = links.flatMap(l ⇒ Seq(l.parent → l.confidence, l.child → l.confidence)).toMultiMap
    val (idsToPublish, ignored) = hotelsToPublish.map(_.id).toSet.partition(shouldPublish(_, linkConf))

    metrics.addLinks(linksToAdd.size)
    metrics.addHotels(idsToPublish.size, ignored.size)

    for {
      linksCount <- BatchExecutor.executeInBatch[ClusterLink](linksToAdd.iterator,
        "cluster links inserted",
        BatchSize,
        tables.Clusterization.insert(dbWrapper, _))
      hotelsAdded <- BatchExecutor.executeInBatch[Int](idsToPublish.iterator,
        "hotels published",
        BatchSize,
        hotelsDao.publish)
    } yield {
      log.info(s"$linksCount links added to db. $hotelsAdded hotels published.")
    }
  }

  /**
   * Loads the non-deleted hotel records for `ids`.
   *
   * For more than `RafMapThreshold` ids the result comes from `hotelsDao.retrieveRafMap`; such a
   * map may be a [[RafBasedMap]] that the caller has to close. Smaller requests produce a plain
   * in-memory map.
   */
  protected def getHotelMap(ids: Set[Int]): Future[collection.Map[Int, PartnerHotel]] = {
    if (ids.size > RafMapThreshold) {
      hotelsDao.retrieveRafMap(WithIds(ids), SkipDeleted)
    } else {
      hotelsDao.get(ids).map { hotels ⇒ hotels.filterNot(_.isDeleted).map(h ⇒ h.id → h.hotel).toMap }
    }
  }

  /**
   * Clusters `hotelsToCluster` against `candidates`: loads the hotel records, generates links in
   * batches, cleans them with `cleaner`, stores the result and publishes the hotels.
   *
   * @return a future completed after the links are saved, the hotels are published and
   *         `metrics.doLogging()` has been called
   */
  protected def doClustering(hotelsToCluster: Seq[HotelRef],
                             candidates: Seq[HotelRef],
                             cleaner: HotelLinkCleaner,
                             metrics: ClusteringMetrics,
                             transaction: Transaction): Future[Unit] = {

    val idsToLoad = candidates.map(_.id).toSet ++ hotelsToCluster.map(_.id).toSet
    val hotelsMapFuture = getHotelMap(idsToLoad)

    val rawLinks: Future[Seq[LinkWithConfidence]] = {
      for (hotelsMap ← hotelsMapFuture) yield {
        val localContext = new LocalContext(hotelsMap.valuesIterator)

        // Contexts are built batch by batch so that only the contexts of one candidate batch and
        // one new-hotel batch are alive at the same time.
        (for {
          candidateBatch ← candidates.grouped(CandidateBatchSize)
          candidateContexts = buildHotelContexts(candidateBatch, localContext, hotelsMap)
          newHotelBatch ← hotelsToCluster.grouped(NewHotelBatchSize)
          newHotelContexts = buildHotelContexts(newHotelBatch, localContext, hotelsMap)
          link ← generateLinks(newHotelBatch, candidateBatch, candidateContexts ++ newHotelContexts, metrics).seq
        } yield link).toVector
      }
    }
    // The hotel map is needed only while links are generated, so a RAF-based map is released as
    // soon as that step finishes, whether it succeeded or failed. If loading the map itself
    // failed there is nothing to close and `foreach` does nothing.
    rawLinks.onComplete { _ ⇒
      hotelsMapFuture.foreach {
        case rafMap: RafBasedMap[_, _] ⇒ rafMap.close()
        case _ ⇒
      }
    }

    for {
      links ← rawLinks
      cleaned = cleaner.removeExcessLinksAndSetConfidence(links)
      _ ← saveLinksAndPublish(hotelsToCluster, cleaned, transaction, metrics)
    } yield metrics.doLogging()
  }

  /**
   * Clusters the new hotels of a single grid point against the hotels of `point.nearPoints`.
   *
   * Only hotels flagged as new and, if `partnerFilter` is defined, belonging to that partner are
   * clustered. Completes immediately when the grid point has no such hotels.
   */
  private def cluster(point: GridPoint,
                      gridMap: Map[GridPoint, Seq[HotelRef]],
                      cleaner: HotelLinkCleaner,
                      metrics: ClusteringMetrics,
                      transaction: Transaction): Future[Unit] = {

    val newHotels = gridMap.getOrElse(point, Seq.empty)
      .filter(ref ⇒ ref.isNew && partnerFilter.forall(partner ⇒ ref.partner == partner))

    if (newHotels.isEmpty) {
      Future.successful(())
    } else {
      val candidates = for {
        nearPoint ← point.nearPoints
        hotel ← gridMap.getOrElse(nearPoint, Seq.empty)
      } yield hotel

      doClustering(newHotels, candidates, cleaner, metrics, transaction)
        .logTiming(s"Clustering at $point, new hotels = ${newHotels.size}, candidates = ${candidates.size}")
    }
  }


  /**
   * Runs one clustering pass over all grid points.
   *
   * Grid points are submitted in random order to a work queue created with `Parallelism`; all of
   * them share one transaction, one metrics object and one link cleaner built from the previously
   * stored links. If the pass fails, the queue is cleared.
   */
  override protected def update: Future[_] = {
    val queue = new AsyncWorkQueue(Parallelism)(akka, ec)
    retrieveData().flatMap { allHotelRefs ⇒
      val gridMap = allHotelRefs.groupBy(_.gridPoint)
      val metrics = new ClusteringMetrics

      for {
        transaction ← Transactions.newTransaction(dbWrapper)
        oldLinks ← Clusterization.retrieveClusterLinks(Clusterization.defaultMinConfidence, dbWrapper, "Clustering")
        cleaner = new HotelLinkCleaner(oldLinks, allHotelRefs.map(h ⇒ h.id → h.partner).toMap)
        _ ← Future.traverse(Random.shuffle(gridMap.keys.toSeq)) { point ⇒
          queue.submit(cluster(point, gridMap, cleaner, metrics, transaction))
        }
      } yield ()
    }.andThen {
      case Failure(_) ⇒ queue.clear()
    }
  }
}