package ru.yandex.tours.util.collections

import java.io.{DataOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.apache.commons.io.output.ByteArrayOutputStream

import scala.collection.mutable.ArrayBuffer

/**
 * Read-only, zero-copy view over an array of records serialized by `MappedArray.writeArray`.
 *
 * <h2>Buffer format</h2>
 * All integers are big-endian, as written by `java.io.DataOutputStream`.
 *
 *  4 bytes - int - array size
 *  1 byte - boolean - records with fixed size
 *
 *  if records have fixed size:
 *  4 bytes - int - size of records in bytes
 *
 *  if records have dynamic size:
 *  4 * N bytes - array[int] - record offsets, relative to the start of the data section
 *
 *  rest of buffer:
 *  bytes - i-th item serialized; with dynamic sizes the last record extends to the end
 *  of the buffer
 *
 * The array is read starting at the buffer's current position. The constructor works on a
 * duplicate of `buffer`, so the caller's position is left untouched and the caller's byte
 * order setting does not matter.
 *
 * @param buffer       buffer positioned at the start of a serialized array
 * @param deserializer decodes one record from a buffer sliced to exactly that record's bytes
 * @tparam T           record type
 */
class MappedArray[T](buffer: ByteBuffer, deserializer: ByteBuffer => T) {

  // Private read cursor. `duplicate()` shares the bytes but has an independent position, and
  // its byte order is always BIG_ENDIAN, which matches the DataOutputStream-based writer even
  // when the caller configured `buffer` as little-endian.
  private val cursor: ByteBuffer = buffer.duplicate()

  /** Number of records in the array. */
  val size: Int = cursor.getInt()
  require(size >= 0, s"Corrupted MappedArray header: negative size $size")

  /** Whether all records have the same length, in which case no offset table is stored. */
  // Decoded like DataInput.readBoolean: any non-zero byte means true.
  val isFixedSizeRecords: Boolean = cursor.get() != 0

  def isEmpty: Boolean = size == 0

  // Length of every record in bytes for fixed-size arrays; -1 (unused) otherwise.
  private val recordSize: Int = if (isFixedSizeRecords) cursor.getInt() else -1

  // Bytes per offset-table entry; fixed-size arrays carry no offset table at all.
  private val headerRecordSize: Int = if (isFixedSizeRecords) 0 else Integer.BYTES

  // Offset table: entry i is the start of record i, relative to the beginning of `data`.
  private val header: ByteBuffer = {
    val table = cursor.slice()
    table.limit(size * headerRecordSize)
    table
  }

  // Record payloads: everything after the offset table up to the end of the buffer.
  private val data: ByteBuffer = {
    cursor.position(cursor.position() + size * headerRecordSize)
    cursor.slice()
  }

  // Start of record `idx` within `data`; `idx == size` yields the end of the last record.
  private def offsetFor(idx: Int): Int =
    if (isFixedSizeRecords) idx * recordSize
    else if (idx == size) data.capacity()
    else header.getInt(idx * headerRecordSize)

  /**
   * Deserializes the record at `idx`.
   *
   * The deserializer receives a fresh slice covering exactly that record's bytes (the
   * content is shared with the mapped buffer, not copied), so it may move the slice's
   * position freely.
   *
   * @throws IndexOutOfBoundsException if `idx` is outside `[0, size)`
   */
  def get(idx: Int): T = {
    if (idx < 0 || idx >= size)
      throw new IndexOutOfBoundsException(s"Index $idx out of bounds for MappedArray of size $size")

    val record = data.duplicate()
    record.position(offsetFor(idx))
    record.limit(offsetFor(idx + 1))
    deserializer(record.slice())
  }
}

object MappedArray {

  /**
   * Creates a [[MappedArray]] over `buffer`. The curried signature lets the deserializer be
   * passed as a trailing block.
   */
  def apply[T](buffer: ByteBuffer)(deserializer: ByteBuffer => T): MappedArray[T] =
    new MappedArray[T](buffer, deserializer)

  /**
   * View whose records are decoded by reading an `Int` from the start of each record slice
   * (presumably written with `fixedSize = Some(4)` - confirm against writers).
   */
  def intArray(buffer: ByteBuffer): MappedArray[Int] = new MappedArray[Int](buffer, _.getInt)

  /**
   * Serializes `arr` to `os` in the layout read by [[MappedArray]].
   *
   * Every record is serialized into memory first, because for variable-size records the
   * offset table has to be written before the data. `os` is written to but neither flushed
   * nor closed; the caller owns it.
   *
   * @param arr        records to write
   * @param fixedSize  `Some(n)` when every record serializes to exactly `n` bytes (no offset
   *                   table is written); `None` for variable-size records
   * @param serializer encodes a single record
   * @param os         destination stream
   * @throws IllegalArgumentException if `fixedSize` is negative or a record's serialized
   *                                  length differs from it
   */
  def writeArray[T](arr: Array[T], fixedSize: Option[Int], serializer: T => Array[Byte], os: OutputStream): Unit = {
    fixedSize.foreach(size => require(size >= 0, s"Fixed record size must be non-negative, got $size"))

    val offsets = new ArrayBuffer[Int](arr.length)
    val data = new ByteArrayOutputStream()
    var offset = 0
    for (i <- arr.indices) {
      val bytes = serializer(arr(i))
      // The reader locates record i at i * size, so a single length mismatch would silently
      // shift every following record; fail fast instead of emitting a corrupt array.
      fixedSize.foreach { size =>
        require(bytes.length == size,
          s"Record $i serialized to ${bytes.length} bytes, expected fixed size $size")
      }
      offsets += offset
      data.write(bytes)
      offset += bytes.length
    }

    val out = new DataOutputStream(os)
    out.writeInt(arr.length)
    out.writeByte(if (fixedSize.isDefined) 1 else 0)
    fixedSize match {
      case Some(size) => out.writeInt(size)
      case None       => offsets.foreach(o => out.writeInt(o))
    }
    data.writeTo(out)
  }
}