package ru.yandex.tours.util.naming

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}

import breeze.linalg.{SparseVector, VectorBuilder}
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap
import ru.yandex.extdata.common.meta.DataType
import ru.yandex.tours.extdata.{DataDef, DataTypes}
import ru.yandex.tours.util.IO

import scala.util.hashing.MurmurHash3

/**
 * TF-IDF model over a hashed vocabulary: every word is mapped to an integer bucket in
 * `[0, maxIndex)` by the companion's `indexOf`, and each bucket carries a counter.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 20.04.16
 *
 * @param maxIndex      number of hash buckets; also the dimension of the vectors built by [[tfIdf]]
 * @param documentCount number of documents the counters were collected from
 * @param wordFreq      bucket index -> counter; the companion's builders increment it once per
 *                      word occurrence
 */
class TfIdfModel(private val maxIndex: Int, private val documentCount: Int, wordFreq: Int2IntOpenHashMap) {

  /**
   * Counter stored for the given bucket. Overridable so that derived models (see
   * `TfIdfModel.merge`) can compute counts without owning a frequency map.
   */
  protected def wordCount(wordIndex: Int): Int = wordFreq.get(wordIndex)

  // Converted once so that the ratio in `tfIdf` is a floating-point division.
  private val docCountDouble = documentCount.toDouble

  /**
   * Builds the sparse TF-IDF vector of a tokenized document.
   *
   * Every occurrence of a word whose bucket has a positive counter contributes
   * `log(documentCount / counter) / doc.size` to that bucket; words with a zero counter are
   * skipped. Contributions to the same bucket are added to the builder separately
   * (NOTE(review): breeze's `VectorBuilder` is expected to sum duplicate indices when
   * materializing the sparse vector - confirm against the breeze version in use).
   *
   * @param doc words of the document
   * @return sparse vector of dimension `maxIndex`; empty for an empty document
   */
  def tfIdf(doc: Seq[String]): SparseVector[Double] = {
    val vector = new VectorBuilder[Double](maxIndex)
    // Hoisted out of the loop: `size` is O(n) for List-backed sequences, which previously made
    // the whole pass quadratic in the document length.
    val docSize = doc.size

    for (word <- doc) {
      val idx = TfIdfModel.indexOf(word, maxIndex)
      val cnt = wordCount(idx)
      if (cnt > 0) {
        // The IDF term is deliberately still narrowed to Float before the division so that the
        // produced weights stay bit-identical to the ones computed before the hoisting.
        vector.add(idx, math.log(docCountDouble / cnt).toFloat / docSize)
      }
    }

    vector.toSparseVector
  }

  /**
   * Serializes the model as a sequence of 4-byte ints: `maxIndex`, `documentCount`, the number
   * of entries in the frequency map, then a `(bucket, counter)` pair per entry. This is the
   * layout read back by `TfIdfModel.parse`.
   *
   * The stream is wrapped in a `DataOutputStream` that is handed to `IO.using`
   * (NOTE(review): presumably this closes `os` when done - confirm in `IO`).
   *
   * @param os destination stream
   */
  def saveTo(os: OutputStream): Unit = {
    IO.using(new DataOutputStream(os)) { data =>
      data.writeInt(maxIndex)
      data.writeInt(documentCount)
      data.writeInt(wordFreq.size())
      val it = wordFreq.int2IntEntrySet().fastIterator()
      while (it.hasNext) {
        val e = it.next()
        data.writeInt(e.getIntKey)
        data.writeInt(e.getIntValue)
      }
    }
  }
}

/**
 * Factory methods, (de)serialization and hashing helpers for [[TfIdfModel]].
 */
object TfIdfModel {
  /** Default number of hash buckets used by the builders (2^19). */
  private final val DefaultMaxWords = 1 << 19

  /**
   * Maps a word to its bucket in `[0, maxIndex)` using MurmurHash3.
   *
   * `abs` is safe here: the magnitude of the remainder is strictly smaller than that of
   * `maxIndex`, so the result can never be `Int.MinValue`.
   */
  private def indexOf(word: String, maxIndex: Int): Int = {
    (MurmurHash3.stringHash(word) % maxIndex).abs
  }

  /**
   * Reads a model in the layout written by `TfIdfModel#saveTo`: `maxIndex`, `documentCount`,
   * the entry count, then that many `(bucket, counter)` int pairs.
   *
   * @param is source stream, wrapped in a `DataInputStream` that is handed to `IO.using`
   * @return the deserialized model
   */
  def parse(is: InputStream): TfIdfModel = {
    IO.using(new DataInputStream(is)) { data =>
      val maxIndex = data.readInt()
      val documentCount = data.readInt()
      val size = data.readInt()
      val wordFreq = new Int2IntOpenHashMap(size)

      for (_ <- 0 until size) {
        // Read order matters: bucket index first, then its counter.
        val idx = data.readInt()
        val count = data.readInt()
        wordFreq.put(idx, count)
      }
      new TfIdfModel(maxIndex, documentCount, wordFreq)
    }
  }

  /**
   * Creates an incremental model builder.
   *
   * @param maxWords number of hash buckets of the resulting model
   */
  def newBuilder(maxWords: Int = DefaultMaxWords): Builder = {
    // Previously the argument was dropped and the builder always used the default size.
    new Builder(maxWords)
  }

  /**
   * Mutable accumulator of word statistics; feed documents with `+=` and call `result()`.
   *
   * @param maxWords number of hash buckets of the resulting model
   */
  class Builder(maxWords: Int = DefaultMaxWords) {
    private val wordFreq = new Int2IntOpenHashMap()
    private var documentCount = 0

    /** Registers one document: bumps the document counter and the bucket of every word occurrence. */
    def += (doc: Seq[String]): Unit = {
      documentCount += 1
      for (word <- doc) {
        val idx = indexOf(word, maxWords)
        wordFreq.addTo(idx, 1)
      }
    }

    /**
     * Model over the statistics collected so far. The frequency map is passed as-is (not
     * copied), so further `+=` calls on this builder affect the returned model's counters.
     */
    def result(): TfIdfModel = new TfIdfModel(maxWords, documentCount, wordFreq)
  }

  /**
   * Builds a model from a collection of tokenized documents in a single pass.
   *
   * @param documents documents, each given as a sequence of words
   * @param maxWords  number of hash buckets of the resulting model
   */
  def build(documents: TraversableOnce[Seq[String]], maxWords: Int = DefaultMaxWords): TfIdfModel = {
    // Delegates to Builder so that the counting logic lives in exactly one place.
    val builder = new Builder(maxWords)
    documents.foreach(doc => builder += doc)
    builder.result()
  }

  /**
   * Combines two models into a read-only view: document counts are summed and bucket counters
   * are computed on demand as the sum of both models' counters.
   *
   * The returned model cannot be serialized; its `saveTo` always fails.
   *
   * @throws IllegalArgumentException if the models use a different number of buckets
   */
  def merge(model1: TfIdfModel, model2: TfIdfModel): TfIdfModel = {
    require(
      model1.maxIndex == model2.maxIndex,
      s"Cannot merge models with different maxIndex: ${model1.maxIndex} != ${model2.maxIndex}"
    )
    // The map handed to the superclass is an unused placeholder: lookups go through the override.
    new TfIdfModel(model1.maxIndex, model1.documentCount + model2.documentCount, new Int2IntOpenHashMap()) {
      override protected def wordCount(wordIndex: Int): Int = {
        model1.wordCount(wordIndex) + model2.wordCount(wordIndex)
      }
      override def saveTo(os: OutputStream): Unit = {
        sys.error("Cannot write merged TfIdfModel")
      }
    }
  }
}

/**
 * Data definition binding the `DataTypes.hotelsTfIdfModel` data type to the [[TfIdfModel]]
 * binary parser.
 */
object HotelsTfIdfModel extends DataDef[TfIdfModel] {

  /** Data type under which the hotels TF-IDF model is registered. */
  override def dataType: DataType = {
    DataTypes.hotelsTfIdfModel
  }

  /** Deserializes the model by delegating to `TfIdfModel.parse`. */
  override def parse(is: InputStream): TfIdfModel = {
    val model = TfIdfModel.parse(is)
    model
  }
}