package ru.yandex.tours.query.parser

import java.io.OutputStream
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger

import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap
import ru.yandex.tours.query.Pragmatic
import ru.yandex.tours.query.parser.ParsingTrie.Payload
import ru.yandex.tours.util.Logging
import ru.yandex.tours.util.collections.MappedArray
import ru.yandex.tours.util.io.{ByteBuffers, SegmentedByteBuffer}
import ru.yandex.tours.util.io.SegmentedByteBuffer.SegmentedOutputStream

import scala.collection.mutable
import scala.reflect.ClassTag

/**
 * Lookup of the payloads attached to parsing states, addressed by state id.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 25.01.16
 */
trait ParsingStates {
  /** Number of states held by this instance. */
  def stateCount: Int

  /** Number of payload entries held by this instance. */
  def payloadCount: Int

  /** Number of pragmatics held by this instance. */
  def pragmaticCount: Int

  /** Iterates over the ids of the states held by this instance. */
  def stateIds: Iterator[Int]

  /**
   * Returns the payloads attached to the given state.
   *
   * @param stateId id of the state to look up
   * @return payloads of that state
   */
  def getPragmatics(stateId: Int): Array[ParsingTrie.Payload]

  /**
   * Serializes this instance into the given stream.
   *
   * @param os destination stream
   */
  def writeTo(os: OutputStream): Unit

  /** Short summary: implementation class name followed by the three counters. */
  override def toString: String = {
    val implName = getClass.getSimpleName
    s"$implName(states = $stateCount, payloads = $payloadCount, pragmatics = $pragmaticCount)"
  }
}

object ParsingStates {
  /**
   * Turns an `item -> index` map into an array with every item placed at its index.
   *
   * The indices are expected to cover `0 until map.size`, since the array is
   * allocated with exactly `map.size` slots.
   */
  private def toArray[T: ClassTag](map: collection.Map[T, Int]): Array[T] = {
    val inverted = new Array[T](map.size)
    map.foreach { case (item, idx) => inverted(idx) = item }
    inverted
  }

  /**
   * Builds a [[HeapParsingStates]] from a `state id -> payloads` mapping.
   *
   * Every distinct payload and every distinct pragmatic gets a dense integer id,
   * assigned in order of first appearance; states then refer to payloads by id and
   * payloads refer to pragmatics by id.
   *
   * @param stateToPayload payloads attached to each state
   * @return heap-based states holding the deduplicated data
   */
  def apply(stateToPayload: collection.Map[Int, Array[ParsingTrie.Payload]]): HeapParsingStates = {
    val payloadToId = mutable.HashMap.empty[Payload, Int]
    val pragmaticToId = mutable.HashMap.empty[Pragmatic, Int]

    stateToPayload.foreach { case (_, statePayloads) =>
      statePayloads.foreach { payload =>
        // The default is by-name: it is evaluated only for a key that is not present
        // yet, at which point the current size is the next free id.
        payloadToId.getOrElseUpdate(payload, payloadToId.size)
        pragmaticToId.getOrElseUpdate(payload.pragmatic, pragmaticToId.size)
      }
    }

    val distinctPayloads = toArray(payloadToId)

    val stateToPayloadIds = stateToPayload.map { case (stateId, statePayloads) =>
      stateId -> statePayloads.map(payloadToId)
    }
    val lengths = distinctPayloads.map(_.length)
    val pragmaticIds = distinctPayloads.map(payload => pragmaticToId(payload.pragmatic))
    val distinctPragmatics = toArray(pragmaticToId)

    new HeapParsingStates(
      stateToPayload = stateToPayloadIds,
      payloadIdToLength = lengths,
      payloadIdToPragmaticId = pragmaticIds,
      pragmatics = distinctPragmatics
    )
  }
}

/**
 * On-heap implementation of [[ParsingStates]] backed by a map and plain arrays.
 *
 * @param stateToPayload         state id -> ids of the payloads attached to that state
 * @param payloadIdToLength      payload id -> payload length
 * @param payloadIdToPragmaticId payload id -> index into `pragmatics`
 * @param pragmatics             pragmatics, addressed by pragmatic id
 */
class HeapParsingStates(stateToPayload: collection.Map[Int, Array[Int]],
                        payloadIdToLength: Array[Int],
                        payloadIdToPragmaticId: Array[Int],
                        pragmatics: Array[Pragmatic]) extends ParsingStates {

  require(payloadIdToLength.length == payloadIdToPragmaticId.length, "Inconsistent sizes")

  override def stateCount: Int = stateToPayload.size
  override def payloadCount: Int = payloadIdToLength.length
  override def pragmaticCount: Int = pragmatics.length

  override def stateIds: Iterator[Int] = stateToPayload.keysIterator

  /**
   * Resolves the payloads of a state from the id-indexed arrays.
   *
   * @param stateId id of the state to look up
   * @return payloads of that state, or an empty array when the state is unknown
   */
  override def getPragmatics(stateId: Int): Array[ParsingTrie.Payload] = {
    stateToPayload.getOrElse(stateId, Array.emptyIntArray).map { payloadId =>
      val length = payloadIdToLength(payloadId)
      val pragmaticId = payloadIdToPragmaticId(payloadId)
      ParsingTrie.Payload(length, pragmatics(pragmaticId))
    }
  }

  /**
   * Writes three named segments to the stream:
   *  - "states": the state count as an int, then for every state its id (int),
   *    the number of its payloads (one byte) and the payload ids (ints);
   *  - "payloads": one fixed-size record of two ints (length, pragmatic id) per payload;
   *  - "pragmatics": the pragmatics, serialized with `Pragmatic.write`.
   *
   * @param os destination stream
   * @throws IllegalArgumentException if some state has more payloads than fit into
   *                                  the one-byte counter (more than 255)
   */
  override def writeTo(os: OutputStream): Unit = {
    // The per-state payload count is stored in a single byte, so only the low 8 bits
    // survive. Reject larger counts before anything is written instead of silently
    // emitting a corrupt "states" segment.
    val maxPayloadsPerState = 0xFF
    for ((stateId, payloadIds) <- stateToPayload) {
      require(
        payloadIds.length <= maxPayloadsPerState,
        s"State $stateId has ${payloadIds.length} payloads, " +
          s"but at most $maxPayloadsPerState can be encoded"
      )
    }

    val segmented = new SegmentedOutputStream(os)

    segmented.writeSegment("states") { os =>
      os.writeInt(stateCount)
      for ((stateId, payloadIds) <- stateToPayload) {
        os.writeInt(stateId)
        os.writeByte(payloadIds.length)
        payloadIds.foreach(os.writeInt)
      }
    }

    // Size of one payload record: two ints. Used both as the declared element size
    // and as the capacity of the per-record buffer, so the two cannot drift apart.
    val payloadRecordSize = Integer.BYTES * 2
    segmented.writeSegment("payloads") { os =>
      MappedArray.writeArray[(Int, Int)](
        payloadIdToLength zip payloadIdToPragmaticId,
        Some(payloadRecordSize),
        pair => ByteBuffer.allocate(payloadRecordSize).putInt(pair._1).putInt(pair._2).array(),
        os
      )
    }
    segmented.writeSegment("pragmatics") { os =>
      MappedArray.writeArray[Pragmatic](
        pragmatics,
        None,
        Pragmatic.write,
        os
      )
    }
  }
}

/**
 * ByteBuffer-based implementation of [[ParsingStates]].
 *
 * Reads the "states", "payloads" and "pragmatics" segments of the given buffer.
 * On construction the "states" segment is scanned once to index the position of
 * every state record; payloads are decoded on each `getPragmatics` call.
 *
 * @param buffer serialized states; also written out as-is by `writeTo`
 */
class MappedParsingStates(buffer: ByteBuffer) extends ParsingStates with Logging {
  private val segmented = SegmentedByteBuffer(buffer)
  log.info("Loading ParsingStates from " + segmented)

  /**
   * state id -> position, inside the "states" segment, of the byte holding that
   * state's payload count. Unknown state ids map to -1.
   */
  private val stateOffsets: Int2IntOpenHashMap = {
    val offsets = new Int2IntOpenHashMap()
    offsets.defaultReturnValue(-1)

    val states = segmented.getBuffer("states")
    val totalStates = states.getInt

    for (_ <- 0 until totalStates) {
      val stateId = states.getInt
      offsets.put(stateId, states.position())
      // The count occupies one byte; decode it as unsigned so that values above 127
      // do not turn negative and move the position backwards.
      val payloadsInState = states.get() & 0xFF
      // Skip the payload ids of this state to reach the next record.
      states.position(states.position() + payloadsInState * Integer.BYTES)
    }

    offsets
  }

  // Each payload record is a pair of ints: (length, pragmatic index).
  private val payloads = MappedArray(segmented.getBuffer("payloads")) { bb => (bb.getInt, bb.getInt) }

  private val pragmatics = MappedArray(segmented.getBuffer("pragmatics")) { Pragmatic.read }

  def stateCount: Int = stateOffsets.size()
  def payloadCount: Int = payloads.size
  def pragmaticCount: Int = pragmatics.size

  /** Not supported by this implementation: always throws `NotImplementedError`. */
  override def stateIds: Iterator[Int] = throw new NotImplementedError("MappedParsingStates.stateIds")

  /**
   * Decodes the payloads of a state from the "states" segment.
   *
   * @param stateId id of the state to look up
   * @return payloads of that state, or an empty array when the state is unknown
   */
  override def getPragmatics(stateId: Int): Array[Payload] = {
    val stateOffset = stateOffsets.get(stateId)

    if (stateOffset == -1) {
      Array.empty[Payload]
    } else {
      val b = segmented.getBuffer("states")
      b.position(stateOffset)
      // Unsigned decoding of the one-byte count: a signed read would yield a negative
      // array size for states with more than 127 payloads.
      val count = b.get() & 0xFF
      val res = new Array[Payload](count)

      var i = 0
      while (i < count) {
        val payloadIdx = b.getInt
        val (length, pragmaticIdx) = payloads.get(payloadIdx)
        val pragmatic = pragmatics.get(pragmaticIdx)
        res(i) = ParsingTrie.Payload(length, pragmatic)
        i += 1
      }

      res
    }
  }

  /** Writes the underlying buffer to the stream via `ByteBuffers.write`. */
  override def writeTo(os: OutputStream): Unit = ByteBuffers.write(buffer, os)
}