package ru.yandex.tours.indexer.hotels

import java.io.{File, FileInputStream, OutputStream}
import java.util.concurrent.atomic.AtomicInteger
import java.util.zip.GZIPOutputStream

import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Source
import ru.yandex.extdata.loader.engine.DataPersistenceManager
import ru.yandex.tours.backa.BackaPermalinks
import ru.yandex.tours.db.DBWrapper
import ru.yandex.tours.db.dao.HotelsDao
import ru.yandex.tours.db.dao.HotelsDao.IsNew
import ru.yandex.tours.db.tables.Clusterization.LinkWithInfo
import ru.yandex.tours.db.tables.{Clusterization, HotelAmendments, LinkType}
import ru.yandex.tours.extdata.DataTypes
import ru.yandex.tours.geo.base.region.Tree
import ru.yandex.tours.hotels.amendings.HotelAmending
import ru.yandex.tours.hotels.enrichers.GeoIdByPartnerHotelSetter
import ru.yandex.tours.hotels._
import ru.yandex.tours.indexer.task.{AsyncUpdatable, TaskWeight}
import ru.yandex.tours.model.hotels.HotelsHolder.{PartnerHotel, ProtoHotelAmendments, TravelHotel}
import ru.yandex.tours.util.collections.Graph
import ru.yandex.tours.util.{IO, Logging, ProtoIO, Statistics}

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

/**
 * Periodically rebuilds the sharded hotels index.
 *
 * One update cycle:
 *  1. loads the ids of hotels that are not new (`IsNew(false)`) and the cluster links
 *     returned by `Clusterization.retrieveClusterLinks` for `confidence`;
 *  2. replays the links (MERGE adds an edge, UNMERGE removes it) into a graph whose
 *     connected components are the non-single clusters; every remaining hotel id forms
 *     its own singleton cluster;
 *  3. builds at most one [[TravelHotel]] per cluster (with hotel amendments applied),
 *     keeps the indexable ones and writes them length-delimited to a temporary file;
 *  4. for every shard, builds a gzipped [[YoctoHotelsIndex]] from that shard's hotels
 *     and hands it to the [[DataPersistenceManager]].
 */
class HotelsIndexer(db: DBWrapper,
                    hotelsDao: HotelsDao,
                    geoIdSetter: GeoIdByPartnerHotelSetter,
                    tree: Tree,
                    hotelRatings: HotelRatings,
                    backaPermalinks: BackaPermalinks,
                    dataPersistenceManager: DataPersistenceManager,
                    updateTime: FiniteDuration)
                   (implicit ec: ExecutionContext, akkaSystem: ActorSystem)
  extends AsyncUpdatable(updateTime, "index_hotels") with TaskWeight.Heavy {

  /** Number of cluster groups whose DB data is fetched and processed concurrently. */
  private val parallelism = 4
  /** Confidence value passed to `Clusterization.retrieveClusterLinks` (presumably a lower bound). */
  private val confidence = 0.7
  /** A cluster group is closed as soon as it holds more than this many hotel ids. */
  private val maxHotelsPerGroup = 5000

  private implicit val actorMaterializer: ActorMaterializer = ActorMaterializer()

  /**
   * Runs one indexing cycle: builds the travel hotels file, then builds and stores every
   * shard from it. The travel hotels file is deleted afterwards whether or not indexing
   * succeeded; the outcome is logged.
   */
  override protected def update: Future[_] = retrieveTravelHotels().map { file =>
    try {
      for (shard <- 0 until ShardedYoctoHotelsIndex.SHARDS_COUNT) indexShard(file, shard)
    } finally IO.deleteFile(file)
  } andThen {
    case Success(_) => log.info("HotIdx: Successfully indexed hotels!")
    case Failure(e) => log.warn("HotIdx: Can not index hotels!", e)
  }

  /**
   * Builds the gzipped index of `shard` from the travel hotels stored in `file`
   * and stores it under `DataTypes.shardedHotels(shard)`. The temporary index file
   * is always deleted.
   */
  private def indexShard(file: File, shard: Int): Unit = {
    // TODO get rid of re-reading the whole file once per shard (SHARDS_COUNT times).
    val toIndex = ProtoIO.loadFromFile(file, TravelHotel.PARSER).filter(getShardId(_, tree) == shard)
    val index = IO.usingTmp(s"hotel_shard_$shard") { os =>
      IO.using(new GZIPOutputStream(os)) { gzipOs =>
        YoctoHotelsIndex.build(toIndex, tree, hotelRatings, gzipOs)
      }
    }
    try {
      // Close the stream ourselves: nothing here guarantees checkAndStore does it
      // (closing an already closed FileInputStream is a no-op).
      IO.using(new FileInputStream(index)) { is =>
        dataPersistenceManager.checkAndStore(DataTypes.shardedHotels(shard), is)
      }
    } finally IO.deleteFile(index)
  }

  /**
   * Shard of `hotel`, computed from `tree.pathToRoot(hotel.getGeoId)`:
   *  - no country on the path: key 0 (a debug message is logged);
   *  - the country is the first element of the path: the country id;
   *  - otherwise: the id of the element immediately preceding the country on the path
   *    (presumably the country's direct sub-region, assuming the path runs from the
   *    hotel's region up to the root — TODO confirm).
   * The key is taken modulo `ShardedYoctoHotelsIndex.SHARDS_COUNT`.
   */
  def getShardId(hotel: TravelHotel, tree: Tree): Int = {
    val path = tree.pathToRoot(hotel.getGeoId).zipWithIndex
    val key = path.find(_._1.isCountry) match {
      case None =>
        log.debug(s"Hotel ${hotel.getId} don't have country in path to root!")
        0
      case Some((country, 0)) => country.id
      case Some((_, index)) => path(index - 1)._1.id
    }
    key % ShardedYoctoHotelsIndex.SHARDS_COUNT
  }

  /**
   * Loads hotel ids and cluster links, computes the clusters and writes the resulting
   * travel hotels to a temporary file, which is returned.
   */
  def retrieveTravelHotels(): Future[File] = {
    for {
      hotelIds <- getHotelsIds
      clusterLinks <- Clusterization.retrieveClusterLinks(confidence, db, "HotIdx")
      clusters = buildGraph(clusterLinks).getConnectedComponents
      result <- buildTravelHotels(hotelIds, clusters)
    } yield result
  }

  /** Ids of all hotels with `IsNew(false)`; retrieval time is logged. */
  private def getHotelsIds: Future[Seq[Int]] = {
    val result = hotelsDao.getHotelIds(IsNew(false))
    Statistics.asyncLogTime("HotIdx: Hotels ids retrieving", result)
  }

  /**
   * Replays the cluster links into an edge set: links are ordered by
   * `(isManual, timestamp)`, so automatic links come first and manual ones are applied
   * afterwards, each later link overriding earlier ones. MERGE adds the `from -> to`
   * edge, UNMERGE removes it.
   */
  private def buildGraph(clusters: Seq[LinkWithInfo]) = {
    Statistics.logTime("HotIdx: Build graph") {
      val edges = mutable.Set.empty[(Int, Int)]
      clusters.sortBy(l => (l.isManual, l.timestamp)).foreach {
        case LinkWithInfo(from, to, _, linkType, _, _) =>
          linkType match {
            case LinkType.MERGE => edges += from -> to
            case LinkType.UNMERGE => edges -= from -> to
          }
      }
      new Graph(edges)
    }
  }

  /**
   * Completes the non-single `clusters` with a singleton cluster for every id in
   * `hotelIds` not covered by them, then writes the travel hotels built from all
   * clusters to a temporary file.
   */
  private def buildTravelHotels(hotelIds: Seq[Int], clusters: Seq[Set[Int]]): Future[File] = {
    Statistics.asyncLogTime("HotIdx: Building travel hotels", {
      log.info(s"HotIdx: Have ${hotelIds.size} hotelIds and ${clusters.size} NonSingle clusters")

      val nonSingleHotelIds = clusters.flatten.toSet

      log.info(s"HotIdx: Have ${nonSingleHotelIds.size} hotels in NonSingle clusters")

      // Must be a strict collection: it is traversed several times below (size, ++).
      // A one-shot Iterator would be exhausted by the first `size` call, and
      // `clusters ++ singleClusters` would then silently drop every single cluster.
      val singleClusters: Seq[Set[Int]] =
        hotelIds.filterNot(nonSingleHotelIds.contains).map(id => Set(id))

      log.info(s"HotIdx: Have ${singleClusters.size} Single clusters")

      val allClusters = clusters ++ singleClusters

      log.info(s"HotIdx: Have ${hotelIds.size} hotelIds, ${allClusters.size} clusters = (${clusters.size} NonSingle + ${singleClusters.size} Single)")
      IO.usingAsyncTmp("travel_hotel_retriever") { os => processAllClusters(allClusters, os) }
    })
  }

  /**
   * Splits `allClusters` into groups, processes up to `parallelism` groups concurrently
   * (results keep the group order) and writes every produced travel hotel
   * length-delimited to `os`. Progress is logged after each group.
   */
  private def processAllClusters(allClusters: Seq[Set[Int]], os: OutputStream): Future[Unit] = {
    val clusterGroups = groupClusters(allClusters)
    val doneClusters = new AtomicInteger()
    val doneTravelHotels = new AtomicInteger()
    Source.fromIterator(() => clusterGroups.iterator)
      .mapAsync(parallelism)(group => processClusterGroup(group, doneClusters))
      .runForeach { hotels =>
        hotels.foreach { h =>
          h.writeDelimitedTo(os)
          doneTravelHotels.incrementAndGet()
        }
        log.info(s"HotIdx: Processed clusters: $doneClusters/${allClusters.size}, produced $doneTravelHotels TravelHotels")
      }
  }

  /**
   * Packs consecutive clusters into groups; a group is closed right after adding the
   * cluster that makes its hotel count exceed `maxHotelsPerGroup`. Clusters are never
   * split and their order is preserved.
   */
  private def groupClusters(allClusters: Seq[Set[Int]]): Seq[Seq[Set[Int]]] = {
    val groups = Vector.newBuilder[Seq[Set[Int]]]
    val current = mutable.ArrayBuffer.empty[Set[Int]]
    var hotelCount = 0
    def flush(): Unit = {
      if (current.nonEmpty) {
        groups += current.toVector
        current.clear()
        hotelCount = 0
      }
    }
    for (cluster <- allClusters) {
      current += cluster
      hotelCount += cluster.size
      if (hotelCount > maxHotelsPerGroup) flush()
    }
    flush()
    groups.result()
  }

  /**
   * Fetches the DB hotels and amendments of all hotels in `clustersGroup`, builds one
   * travel hotel per cluster and keeps those accepted by `HotelsIndex.isIndexable`.
   * `doneClusters` is incremented once per processed cluster.
   */
  private def processClusterGroup(clustersGroup: Seq[Set[Int]],
                                  doneClusters: AtomicInteger): Future[Seq[TravelHotel]] = {
    val hotelIds = clustersGroup.flatten
    for {
      dbHotels <- hotelsDao.get(hotelIds)
      amendments <- HotelAmendments.getAmendments(db, hotelIds)
    } yield {
      val hotelsMap = mutable.HashMap.empty[Int, PartnerHotel]
      for (dbHotel <- dbHotels) hotelsMap.put(dbHotel.id, dbHotel.hotel)
      log.info(s"HotIdx: For ${hotelIds.size} hotelIds got ${dbHotels.size} DB Hotels and ${amendments.size} amendments")
      clustersGroup.flatMap { cluster =>
        // Direct lookups per cluster member instead of scanning the whole group map
        // for every cluster (which was O(groupSize^2) for groups of singletons).
        val clusterHotels = cluster.toSeq.flatMap(hotelsMap.get)
        val travelHotel = buildCluster(clusterHotels, amendments)
          .filter(h => HotelsIndex.isIndexable(h))
        doneClusters.incrementAndGet()
        travelHotel
      }
    }
  }

  /**
   * Builds a travel hotel from the partner hotels of one cluster, applying the parsed
   * amendments found in `id2amendments` for each of those hotels.
   */
  private def buildCluster(cluster: Iterable[PartnerHotel],
                           id2amendments: Map[Int, ProtoHotelAmendments]): Option[TravelHotel] = {
    val amendings = cluster.flatMap { hotel =>
      id2amendments.get(hotel.getId).toIterable.flatMap(HotelAmending.parseProtos)
    }
    TravelHotelBuilder.buildTravelHotel(
      cluster,
      hotelRatings,
      geoIdSetter,
      tree,
      amendings,
      backaPermalinks
    )
  }
}
