package ru.yandex.tours.indexer.clusterization

import java.util.concurrent.Callable
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{Sink, Source}
import com.google.common.cache.{CacheBuilder, CacheLoader}
import ru.yandex.tours.clustering.Clustering.LinkWithConfidence
import ru.yandex.tours.geo
import ru.yandex.tours.hotels.clustering.{ClusteringContext, ClusteringModel, HotelContext, LocalContext}
import ru.yandex.tours.model.hotels.HotelsHolder.PartnerHotel
import ru.yandex.tours.util.Logging
import ru.yandex.tours.util.akka.Streams

import scala.concurrent.{ExecutionContext, Future}

/**
 * Links hotels to similar nearby hotels using a [[ClusteringModel]].
 *
 * For every incoming hotel the 3x3 block of grid cells around it is loaded (masters and slaves
 * alike), hotels within `CONTEXT_DISTANCE` km become candidates, and every candidate the model
 * scores at `MinConfidence` or above yields a [[LinkWithConfidence]].
 *
 * Hotels are processed concurrently, at most `parallelism` at a time, on the supplied
 * execution context. Progress counters and the start timestamp are kept per instance.
 *
 * NOTE(review): the materializer created here is never shut down by this class - confirm
 * that whoever owns the actor system cleans it up.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 06.05.16
 *
 * @param masters         grid of master hotels
 * @param slaves          grid of slave hotels
 * @param parallelism     maximum number of hotels processed at the same time
 * @param clusteringModel model scoring a (slave, candidate) pair
 */
class MatrixnetClusterizer(masters: Grid[PartnerHotel],
                           slaves: Grid[PartnerHotel],
                           parallelism: Int,
                           clusteringModel: ClusteringModel)
                          (implicit ec: ExecutionContext, as: ActorSystem)
  extends Clusterizer with Logging {

  /** Radius in km used both for picking candidates and for building a hotel's local context. */
  val CONTEXT_DISTANCE = 3d

  /** Minimal model score for a candidate to be linked. */
  private val MinConfidence = 0.5d

  // Explicit type: implicit definitions should not rely on inference.
  private implicit val materializer: ActorMaterializer = ActorMaterializer()

  private val started = System.nanoTime()
  private val processed = new AtomicInteger()
  private val processedPairs = new AtomicLong()
  private val links = new AtomicInteger()
  private val merged = new AtomicInteger()
  private val notMerged = new AtomicInteger()

  /** Hotel contexts keyed by hotel id, so a candidate shared by several hotels is built once. */
  private val contextCache = CacheBuilder.newBuilder().maximumSize(10000).build[Integer, HotelContext]()

  /** Hotels (masters and slaves) of the 3x3 block of grid cells centred on the key cell. */
  private val cache = CacheBuilder.newBuilder().maximumSize(3).build(new CacheLoader[(Int, Int), Seq[PartnerHotel]] {
    override def load(key: (Int, Int)): Seq[PartnerHotel] = {
      val (lonIndex, latIndex) = key
      for {
        lon <- lonIndex - 1 to lonIndex + 1
        lat <- latIndex - 1 to latIndex + 1
        hotel <- masters.get(lon, lat).toArray ++ slaves.get(lon, lat)
      } yield hotel
    }
  })

  /**
   * Computes links for all given hotels.
   *
   * @param slaves hotels to find matches for; the iterator is consumed once
   * @return all produced links; their order is not tied to the input order
   */
  def getMergeResult(slaves: Iterator[PartnerHotel]): Future[Seq[LinkWithConfidence]] = {
    Source.fromIterator(() => slaves)
      .mapAsyncUnordered(parallelism)(processSlave)
      .via(Streams.flatten)
      .runWith(Sink.seq)
  }

  /** Finds matches for one hotel, updates the progress counters and returns the resulting links. */
  private def processSlave(slave: PartnerHotel): Future[Seq[LinkWithConfidence]] = Future {
    val similar = getSimilarHotels(slave)
    links.addAndGet(similar.size)
    // Record the outcome before reporting so the logged totals include this hotel.
    if (similar.isEmpty) notMerged.incrementAndGet() else merged.incrementAndGet()
    // Use the value returned by the increment: re-reading the counter may observe
    // increments made by other threads and print a number that is not a multiple of 500.
    val done = processed.incrementAndGet()
    if (done % 500 == 0) {
      log.info(s"Clusterization done for $done. " +
        s"Merged: ${merged.get}, not merged: ${notMerged.get()}, links: ${links.get()}")
    }
    buildLink(slave, similar)
  }

  /** Turns scored matches into links with the match as parent and `slave` as child. */
  private def buildLink(slave: PartnerHotel, masters: Seq[(PartnerHotel, Double)]) = {
    masters.map {
      case (m, conf) => LinkWithConfidence(parent = m.getId, child = slave.getId, conf)
    }
  }

  /** Returns the hotels near `slave` that the model scores at `MinConfidence` or above. */
  private def getSimilarHotels(slave: PartnerHotel): Seq[(PartnerHotel, Double)] = {
    val (lonIndex, latIndex) = Grid.getIndex(slave.getRawHotel.getPoint)
    val inPoint = cache.get((lonIndex, latIndex))
    // Twice the context radius: covers the local context of every candidate,
    // since a candidate lies within one radius of the slave.
    val nearHotels = getNear(slave, inPoint, distance = CONTEXT_DISTANCE * 2)
    val candidates = getNear(slave, nearHotels, distance = CONTEXT_DISTANCE)
    val localContext = new LocalContext(getNear(slave, inPoint).iterator)
    val slaveContext = HotelContext(slave, localContext)
    candidates
      .filter(_.getId != slave.getId)
      .map(hotel => hotel -> similarity(slaveContext, hotel, nearHotels))
      .filter(_._2 >= MinConfidence)
  }

  /**
   * Scores one (slave, candidate) pair with the clustering model.
   *
   * @param slave  prepared context of the hotel being processed
   * @param master candidate hotel
   * @param near   hotels the candidate's local context is picked from (on a cache miss)
   */
  private def similarity(slave: HotelContext, master: PartnerHotel, near: Seq[PartnerHotel]): Double = {
    val masterContext = contextCache.get(master.getId, new Callable[HotelContext] {
      override def call(): HotelContext = {
        val localContext = new LocalContext(getNear(master, near).iterator)
        HotelContext(master, localContext)
      }
    })
    val score = clusteringModel.apply(ClusteringContext(slave, masterContext))
    logPair()
    score
  }

  /** Counts a scored pair and reports throughput on every 5000th one. */
  private def logPair(): Unit = {
    val pairs = processedPairs.incrementAndGet()
    val hotels = processed.get()
    if (pairs % 5000 == 0) {
      val elapsed = System.nanoTime() - started
      log.info(s"Processed $pairs pairs in ${elapsed / 1e9.toLong} seconds")
      val perPairMs = elapsed.toDouble / pairs / 1e6
      // The hotel counter is bumped only after a hotel is fully processed, so it can still be
      // zero here; avoid printing "Infinity" in that case.
      val perHotel = if (hotels > 0) f"${elapsed.toDouble / hotels / 1e6}%.2f" else "n/a"
      log.info(f"Average processing time = $perPairMs%.2f ms. per pair, $perHotel%s ms. per hotel")
    }
  }

  /** Keeps the hotels located at most `distance` km away from `hotel`. */
  private def getNear(hotel: PartnerHotel, hotels: Seq[PartnerHotel], distance: Double = CONTEXT_DISTANCE) = {
    hotels.filter(h => geo.distanceInKm(h.getRawHotel.getPoint, hotel.getRawHotel.getPoint) <= distance)
  }
}

/**
 * Creates [[MatrixnetClusterizer]] instances that share one clustering model and run on the
 * dispatcher looked up under "akka.actor.hotels-clustering-dispatcher".
 *
 * @param clusteringModel model handed to every created clusterizer
 */
class MatrixnetClusterizerFactory(clusteringModel: ClusteringModel)
                                 (implicit akkaSystem: ActorSystem) extends ClusterizerFactory {

  /** Builds a clusterizer over the given grids with a fixed parallelism of 4. */
  override def apply(master: Grid[PartnerHotel], slave: Grid[PartnerHotel]): Clusterizer = {
    val clusteringDispatcher = akkaSystem.dispatchers.lookup("akka.actor.hotels-clustering-dispatcher")
    new MatrixnetClusterizer(
      masters = master,
      slaves = slave,
      parallelism = 4,
      clusteringModel = clusteringModel
    )(clusteringDispatcher, akkaSystem)
  }
}