package ru.yandex.tours.util.collections

import java.io.OutputStream
import java.nio.ByteBuffer

import ru.yandex.tours.util.io.ByteBuffers

import scala.collection.mutable.ArrayBuffer

/**
  * Readers and writers for `MappedSeq` / `MappedMap` stored in a `ByteBuffer`.
  *
  * The writers emit, in this order:
  * {{{
  *   data | offsets | index | header
  * }}}
  * where `header` is four ints: number of records, data length, offsets length and
  * index length (lengths in bytes). The readers locate the header at the end of the
  * buffer and slice the three preceding sections from it. The `offsets` and `index`
  * sections, when present, are themselves written as nested sequences in the same format.
  *
  * Created by asoboll on 08.02.17.
  */
trait ByteBufferMappedMap extends AbstractMappedMap {
  /** Size in bytes of the trailing header (four ints). */
  protected val HEADER_SIZE = 4 * Integer.BYTES

  /**
    * Data accessor that deserializes a value from the part of `buffer`
    * selected by a position's `offset` and `limit`.
    *
    * @param buffer buffer holding the serialized records
    * @param dV     deserializer applied to the extracted part
    */
  protected class ByteBufferData[V](buffer: ByteBuffer)
                                   (implicit dV: ByteBuffer => V) extends AbstractData[V]{
    def get(pos: Position): V = dV(ByteBuffers.part(buffer, pos.offset, pos.limit))
  }

  /**
    * Parsed sections of a serialized container.
    *
    * @param size          number of records, as stored in the header
    * @param dataBuffer    section holding the serialized records
    * @param offsetsBuffer section holding record offsets; empty for fixed-size records
    * @param indexBuffer   section holding the serialized keys; empty when there is no index
    */
  protected case class ByteBufferPair(size: Int,
                                      dataBuffer: ByteBuffer,
                                      offsetsBuffer: ByteBuffer,
                                      indexBuffer: ByteBuffer) {
    /**
      * Offsets of the records inside `dataBuffer`.
      *
      * An empty offsets section means all records share one size, derived as
      * data capacity divided by `size` (0 when there are no records). Otherwise
      * the offsets are read as a nested `MappedSeq[Int]`.
      */
    def offsets: AbstractOffsets =
      if (offsetsBuffer.capacity() == 0) {
        val recordSize = if (size != 0) dataBuffer.capacity() / size else 0
        new FixedSizeOffsets(recordSize)
      } else {
        new SeqBasedOffsets(MappedSeq.from[Int](offsetsBuffer)(_.getInt))
      }

    /** Accessor over `dataBuffer` using the given deserializer. */
    def data[V](implicit dV: ByteBuffer => V): AbstractData[V] =
      new ByteBufferData[V](dataBuffer)
  }

  protected object ByteBufferPair {
    /**
      * Splits `buffer` into data, offsets and index sections according to the
      * header stored in its last `HEADER_SIZE` bytes.
      *
      * The position of `buffer` is not modified: slicing is done on a duplicate,
      * so the same buffer can be parsed more than once.
      *
      * @throws IllegalArgumentException if the buffer is shorter than the header, the
      *                                  header contains negative values, or the section
      *                                  sizes do not add up to the buffer limit
      */
    def apply(buffer: ByteBuffer): ByteBufferPair = {
      require(buffer.limit() >= HEADER_SIZE,
        s"Buffer limit ${buffer.limit()} is smaller than header size $HEADER_SIZE")
      val header = ByteBuffers.part(buffer, buffer.limit() - HEADER_SIZE, buffer.limit())

      val size = header.getInt
      val dataSize = header.getInt
      val offsetsSize = header.getInt
      val indexSize = header.getInt

      require(size >= 0 && dataSize >= 0 && offsetsSize >= 0 && indexSize >= 0,
        s"Negative value in header ($size, $dataSize, $offsetsSize, $indexSize)")
      // Sum as Long so that corrupted sizes cannot overflow into a matching value.
      require(buffer.limit() == dataSize.toLong + offsetsSize + indexSize + HEADER_SIZE,
        s"Buffer limit ${buffer.limit} doesnt match header ($size, $dataSize, $offsetsSize, $indexSize)")

      // Independent position/limit over the same content: advancing it while
      // slicing leaves the caller's buffer untouched.
      val cursor = buffer.duplicate()

      // Cuts the next `n` bytes at the cursor into a standalone buffer and advances the cursor.
      def take(n: Int): ByteBuffer = {
        val section = cursor.slice()
        cursor.position(cursor.position() + n)
        section.limit(n)
        section.slice()
      }

      // Sections must be taken in the order they were written.
      val dataBuffer = take(dataSize)
      val offsetsBuffer = take(offsetsSize)
      val indexBuffer = take(indexSize)
      ByteBufferPair(size, dataBuffer, offsetsBuffer, indexBuffer)
    }
  }

  object MappedSeq {
    /**
      * Reads a sequence from `buffer`.
      *
      * @param buffer serialized sequence, as produced by [[MappedSeqWriter]]
      * @param d      deserializer for a single element
      * @throws IllegalArgumentException if the buffer is malformed or has an index section
      */
    def from[V](buffer: ByteBuffer)
               (implicit d: ByteBuffer => V): MappedSeq[V] = {
      val buffers = ByteBufferPair(buffer)
      require(buffers.indexBuffer.capacity() == 0, "MappedSeq should not have index")

      new MappedSeq[V](buffers.size, buffers.offsets, buffers.data)
    }
  }

  /**
    * Base writer: streams serialized records to `os` and, on [[finish]], appends
    * the offsets section, the index section and the header.
    *
    * While all records have the same length no offsets are kept. On the first
    * record of a different length the writer switches to explicit offsets.
    *
    * @param os stream receiving the output; it is not closed by this class
    */
  protected abstract class MappedWriter[V](os: OutputStream) {
    protected var size = 0
    protected var dataLength = 0

    // Common length of all records seen so far; None before the first record
    // and after switching to explicit offsets.
    protected var recordSize: Option[Int] = None
    // Record boundaries (start of each record plus the end of the last one);
    // stays empty while records are fixed-length.
    protected val offsets = ArrayBuffer.empty[Int]

    protected var finished = false

    /** Fails if [[finish]] has already completed. */
    protected def checkState(): Unit =
      if (finished) throw new IllegalStateException("Already finished writing")

    /** Accounts for a record of `length` bytes that has just been written. */
    protected def updateStats(length: Int): Unit = {
      dataLength += length
      if (offsets.nonEmpty) {            // not fixed length
        offsets += dataLength
      } else recordSize match {
        case None =>                     // init fixed length (first record)
          recordSize = Some(length)
        case Some(l) if l == length =>   // fixed length
        case Some(l) =>                  // init not fixed length (doesnt match)
          recordSize = None
          // Boundaries of the `size` previous records, all `l` bytes long...
          offsets ++= (0 to size).map(_ * l)
          // ...followed by the end of the record just written.
          offsets += dataLength
      }
      size += 1
    }

    /** Writes one serialized record to the stream and updates the statistics. */
    protected def writeBytes(bytes: Array[Byte]): Unit = {
      os.write(bytes)
      updateStats(bytes.length)
    }

    /** Serializes and writes a single item; called after the state check. */
    protected def writeInternal(item: V): Unit

    /**
      * Writes a single item.
      *
      * @throws IllegalStateException if the writer is already finished
      */
    def write(item: V): Unit = {
      checkState()
      writeInternal(item)
    }

    /**
      * Writes all items in traversal order.
      *
      * @throws IllegalStateException if the writer is already finished
      */
    def write(items: TraversableOnce[V]): Unit = {
      checkState()
      items.foreach(writeInternal)
    }

    /**
      * Writes the offsets section as a nested sequence of ints.
      *
      * @return number of bytes written; 0 when records are fixed-length
      */
    protected def writeOffsets(): Int = {
      if (offsets.nonEmpty) {
        def serializeInt(x: Int) = ByteBuffer.allocate(Integer.BYTES).putInt(x).array()
        val offsetWriter = new MappedSeqWriter[Int](os)(serializeInt)
        offsetWriter.write(offsets)
        offsetWriter.finish()
      } else 0
    }

    /**
      * Writes the index section.
      *
      * @return number of bytes written; this base implementation writes nothing
      */
    protected def writeIndex(): Int = 0

    /** Writes the trailing header: record count and the three section lengths. */
    protected def writeHeader(offsetLength: Int, indexLength: Int): Unit = {
      val header = ByteBuffer.allocate(HEADER_SIZE)
        .putInt(size)
        .putInt(dataLength)
        .putInt(offsetLength)
        .putInt(indexLength)
        .array()
      os.write(header)
    }

    /**
      * Completes the output by writing offsets, index and header.
      * No further writes are accepted afterwards.
      *
      * @return total number of bytes written by this writer
      * @throws IllegalStateException if the writer is already finished
      */
    def finish(): Int = {
      checkState()
      val offsetLength = writeOffsets()
      val indexLength = writeIndex()
      writeHeader(offsetLength, indexLength)
      finished = true
      dataLength + offsetLength + indexLength + HEADER_SIZE
    }
  }

  /**
    * Writer for a plain sequence of values.
    *
    * @param os         stream receiving the output
    * @param serializer converts an element to its byte representation
    */
  class MappedSeqWriter[V](os: OutputStream)(implicit serializer: V => Array[Byte])
    extends MappedWriter[V](os) {

    protected def writeInternal(item: V): Unit = {
      val bytes = serializer(item)
      writeBytes(bytes)
    }
  }

  /**
    * Writer for key/value pairs: values go to the data section, serialized keys
    * are collected in memory and written as the index section on finish.
    *
    * NOTE(review): keys are written in the order they are supplied and no ordering
    * check is made here; lookups presumably need them sorted by the `Ordering[K]`
    * given to `MappedMap.from` -- confirm against the index implementation.
    *
    * @param os stream receiving the output
    * @param sK key serializer
    * @param sV value serializer
    */
  class MappedMapWriter[K, V](os: OutputStream)(implicit sK: K => Array[Byte], sV: V => Array[Byte])
    extends MappedWriter[(K, V)](os) {

    private val index = ArrayBuffer.empty[Array[Byte]]

    protected def writeInternal(item: (K, V)): Unit = {
      val (key, value) = item
      // Serialize both parts before touching any state, and record the key only
      // after the value has been written, so that a failing serializer or stream
      // cannot leave a key in the index without a matching record.
      val keyBytes = sK(key)
      val valueBytes = sV(value)
      writeBytes(valueBytes)
      index += keyBytes
    }

    /**
      * Writes the collected keys as a nested sequence.
      *
      * @return number of bytes written; 0 when no entries were written
      */
    override protected def writeIndex(): Int = {
      if (index.nonEmpty) {
        val indexWriter = new MappedSeqWriter[Array[Byte]](os)(identity)
        indexWriter.write(index)
        indexWriter.finish()
      } else 0
    }
  }

  object MappedMap {
    /**
      * Reads a map from `buffer`.
      *
      * @param buffer serialized map, as produced by [[MappedMapWriter]]
      * @param ord    ordering of the keys
      * @param dK     key deserializer
      * @param dV     value deserializer
      * @throws IllegalArgumentException if the buffer is malformed, or it holds
      *                                  records but has no index section
      */
    def from[K, V](buffer: ByteBuffer)
                  (implicit ord: Ordering[K],
                   dK: ByteBuffer => K,
                   dV: ByteBuffer => V): MappedMap[K, V] = {
      val buffers = ByteBufferPair(buffer)

      val index = if (buffers.size > 0) {
        require(buffers.indexBuffer.capacity() > 0, "non empty MappedMap should have index")
        new SeqBasedIndex[K](MappedSeq.from[K](buffers.indexBuffer))
      } else emptyIndex[K]
      new MappedMap[K, V](buffers.size, index, buffers.offsets, buffers.data)
    }
  }
}

/** Default instance of the [[ByteBufferMappedMap]] trait, so its members can be used without mixing it in. */
object ByteBufferMappedMap extends ByteBufferMappedMap
