package ru.yandex.tours.tools.merging

import java.io._

import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import com.google.common.io.CountingInputStream
import com.typesafe.config.ConfigFactory
import ru.yandex.tours.db.GridPoint
import ru.yandex.tours.hotels.clustering.features._
import ru.yandex.tours.hotels.clustering.{ClusteringContext, HotelContext, LocalContext}
import ru.yandex.tours.model.hotels.HotelsHolder.PartnerHotel
import ru.yandex.tours.tools.Tool
import ru.yandex.tours.util.IO
import ru.yandex.tours.util.parsing.Tabbed

import scala.collection.mutable
import scala.util.Random

/**
 * Offline tool that turns labelled hotel-pair corpuses into feature tables.
 *
 * The object body runs the following steps in order:
 *  1. Scan `corpus_hotels.proto` once, recording the byte offset of every hotel record and the
 *     ids of the hotels that fall into each grid cell.
 *  2. For every corpus in `allCorpuses`, evaluate all `features` on each labelled pair and write
 *     one tab-separated line per pair to `ipython/<corpus name>` and to
 *     `ipython/combined_corpus.tsv`.
 *  3. Shuffle the negative and the positive lines separately (fixed seed), and write the first
 *     70% of each to `ipython/training_combined_corpus.tsv` and the rest to
 *     `ipython/holdout_combined_corpus.tsv`.
 *  4. Terminate the JVM via `sys.exit()`.
 *
 * NOTE(review): `allCorpuses`, `getHotelLinks` and `LinkWithHotel` are not defined in this file;
 * presumably they are provided by `CorpusAware` -- confirm.
 */
object CorpusProcessorTool extends Tool with CorpusAware {

  // Explicit types on public implicits keep implicit resolution predictable.
  implicit val akka: ActorSystem = ActorSystem("corpus-processor-tool", ConfigFactory.empty())
  implicit val materializer: ActorMaterializer = ActorMaterializer()

  /** Features evaluated for every hotel pair; their order defines the column order of the output. */
  val features = Seq(
    new NameTfIdfCosFeature(LocalHotelsIndex.tfIdf),
    LocalTfIdfCosFeature,
    AddressNumberFeature,
    AddressTfIdfCosFeature,
    SamePartnerFeature,
    LocalDensityFeature,
    HotelTypesFeature,
    BoundedDistanceFeature,
    PHashNearestFeature,
    PHashAucFeature,
    CountOfSimilarOnPHashFeature,
    NNFeaturesNearestCosFeature,
    NNFeaturesCosAucFeature,
    LevenshteinFeature,
    NameSetFeature,
    PhoneSuffixFeature,
    NameShingleFeature,
    UrlHostFeature,
    StarsFeature,
    IsApartAnyFeature,
    IsApartBoothFeature,
    IsBackaFeature
  )

  /** Lightweight reference to a hotel by id; the full record is re-read on demand via `getHotel`. */
  case class HotelRef(id: Int) extends AnyVal

  /** Header line shared by every output file. */
  val header = Tabbed("id1", "id2", "is_same", "corpus_name", features.map(_.name).mkString("\t"))

  /** Hotel id -> byte offset of its delimited record inside `corpus_hotels.proto`. */
  val offsets = new mutable.HashMap[Int, Long]()

  /** Grid cell -> hotels located in that cell, in file order. */
  val gridMap = new mutable.HashMap[GridPoint, Seq[HotelRef]]()

  println("Starting")
  IO.using(new CountingInputStream(new FileInputStream("corpus_hotels.proto"))) { in =>
    // Accumulate per-cell ids in append-friendly buffers: appending to the immutable Seq stored
    // in gridMap would copy the whole sequence on every hotel added to a cell.
    val cellRefs = new mutable.HashMap[GridPoint, mutable.ArrayBuffer[HotelRef]]()

    // The stream's byte count before a record is parsed is that record's start offset.
    var offset = in.getCount
    var hotel = PartnerHotel.parseDelimitedFrom(in)

    // parseDelimitedFrom yields null once the stream is exhausted.
    while (hotel ne null) {
      offsets += (hotel.getId -> offset)
      val point = GridPoint.fromPoint(hotel.getRawHotel.getPoint)
      val refs = cellRefs.getOrElseUpdate(point, mutable.ArrayBuffer.empty[HotelRef])
      refs += HotelRef(hotel.getId)

      offset = in.getCount
      hotel = PartnerHotel.parseDelimitedFrom(in)
    }

    // Publish immutable snapshots, preserving the per-cell insertion order.
    cellRefs.foreach { case (point, refs) => gridMap.put(point, refs.toVector) }
  }

  /** Random-access handle used to re-read individual hotels by offset. */
  val rafFile = new RandomAccessFile("corpus_hotels.proto", "r")

  // InputStream view over rafFile; it holds no state of its own, so a single instance is reused.
  private val rafInput: InputStream = new InputStream {
    override def read(): Int = rafFile.read()
    override def read(b: Array[Byte], off: Int, len: Int): Int = rafFile.read(b, off, len)
  }

  /** Whether a hotel with the given id was seen while scanning the corpus file. */
  def containsHotel(id: Int): Boolean = offsets.contains(id)

  /**
   * Re-reads the hotel with the given id from the corpus file.
   *
   * Repositions the shared file pointer of `rafFile`, so calls must not overlap.
   *
   * @param id hotel id
   * @return the parsed hotel, or `None` when the id is unknown or no record could be parsed at
   *         the remembered offset
   */
  def getHotel(id: Int): Option[PartnerHotel] =
    offsets.get(id).flatMap { offset =>
      rafFile.seek(offset)
      // Option(...) turns a null parse result (end of stream) into None instead of Some(null).
      Option(PartnerHotel.parseDelimitedFrom(rafInput))
    }

  println(s"${offsets.size} hotels loaded")

  /** Output lines of pairs labelled as different hotels. */
  val falses = mutable.Buffer.empty[String]

  /** Output lines of pairs labelled as the same hotel. */
  val trues = mutable.Buffer.empty[String]

  // Progress counter and start timestamp used only for the periodic log output.
  var i = 0
  val started = System.currentTimeMillis()

  /**
   * Builds the clustering context for a pair of hotels that share one local context.
   *
   * @param a            first hotel of the pair
   * @param b            second hotel of the pair
   * @param localContext local context used for both hotels
   */
  def ctxFor(a: PartnerHotel, b: PartnerHotel, localContext: LocalContext): ClusteringContext =
    ClusteringContext(
      HotelContext(a, localContext),
      HotelContext(b, localContext)
    )

  try {
    IO.printFile("ipython/combined_corpus.tsv") { combinedPw =>
      combinedPw.println(header)
      for (corpus <- allCorpuses) {
        println(s"Processing corpus $corpus")
        IO.printFile(s"ipython/${corpus.name}") { pw =>
          pw.println(header)
          // Group links by the grid cell of their `from` hotel so that one LocalContext is built
          // per cell and shared by all links in it.
          val linksWithNear = getHotelLinks(containsHotel, getHotel, corpus)
            .groupBy(l => GridPoint.fromPoint(l.from.getRawHotel.getPoint))
            .iterator
            .flatMap { case (point, links) =>
              val near = point.nearPoints.flatMap(gridMap.getOrElse(_, Seq.empty))
              val nearHotels = near.iterator.flatMap(ref => getHotel(ref.id)).filterNot(_.getIsDeleted)
              val localCtx = new LocalContext(nearHotels)
              println(s"Loaded $point: ${near.length} hotels, ${links.size} links")
              links.map(_ -> localCtx)
            }

          linksWithNear.foreach {
            case (LinkWithHotel(from, to, isMerge, _corpus), near) =>
              val ctx = ctxFor(from, to, near)
              val line = Tabbed(
                from.getId,
                to.getId,
                if (isMerge) 1 else 0,
                _corpus.name,
                features.map(_.apply(ctx)).mkString("\t")
              )
              i += 1
              if (i % 100 == 0) {
                val elapsed = System.currentTimeMillis() - started
                println(s"Processed $i pairs in ${elapsed / 1000} seconds")
                println(s"Average processing time = ${elapsed / i} ms. per line")
              }
              if (isMerge) {
                trues += line
              } else {
                falses += line
              }
              pw.println(line)
              combinedPw.println(line)
          }
        }
      }
    }
  } finally {
    // Nothing below this point in the object reads hotels, so release the file handle here,
    // also when processing fails.
    rafFile.close()
  }

  /** Fixed seed keeps the train/holdout split reproducible between runs. */
  val random = new Random(0)

  /**
   * Shuffles `lines` and writes the first 70% (rounded down) to `trainPw` and the remainder to
   * `holdoutPw`.
   */
  private def writeSplit(lines: Seq[String], trainPw: PrintWriter, holdoutPw: PrintWriter): Unit = {
    val splitRate = 0.7
    val (train, holdout) = random.shuffle(lines).splitAt((splitRate * lines.length).toInt)
    train.foreach(trainPw.println)
    holdout.foreach(holdoutPw.println)
  }


  IO.printFile("ipython/training_combined_corpus.tsv") { trainPw =>
    IO.printFile("ipython/holdout_combined_corpus.tsv") { holdoutPw =>
      trainPw.println(header)
      holdoutPw.println(header)
      // Negatives and positives are split separately, so both files get roughly the same
      // proportion of each label.
      writeSplit(falses, trainPw, holdoutPw)
      writeSplit(trues, trainPw, holdoutPw)
    }
  }


  println("Done")
  sys.exit()
}
