package ru.yandex.tours.util.spray.stream

import akka.actor.FSM.Failure
import akka.actor.{ActorRef, FSM, Props}
import akka.io.Tcp
import akka.stream.actor.ActorSubscriberMessage.{OnComplete, OnError, OnNext}
import akka.stream.actor.{ActorSubscriber, RequestStrategy, ZeroRequestStrategy}
import ru.yandex.common.actor.logging.ActorLoggingSlf4jOverriding
import ru.yandex.tours.util.spray.stream.MarshallerActor._
import spray.http._
import spray.httpx.marshalling.{DelegatingMarshallingContext, MarshallingContext}

import scala.reflect.ClassTag

/**
 * Stream sink which consumes elements,
 * composes them and sends them as chunked HTTP responses.
 *
 * Back-pressure: elements are requested from upstream manually (`ZeroRequestStrategy`),
 * `batchSize` at a time, and only after the previous write has been acknowledged with `Ack`,
 * so at most one write is in flight at any moment.
 *
 * Response layout: optional header (sent from `preStart`), one chunk per batch of elements,
 * optional footer, then `ChunkedMessageEnd`. If nothing at all was written (empty header,
 * empty stream and empty footer) an empty non-chunked response is sent instead.
 *
 * @param renderer renderer for consumed values
 * @param ctx marshalling context for write HTTP responses to
 * @param batchSize number of consumed elements to be sent in a single chunk; must be positive
 *
 * @author dimas
 */
private[stream] class MarshallerActor[T: ClassTag](renderer: Renderer[T], ctx: MarshallingContext, batchSize: Int)
  extends ActorSubscriber
  with FSM[State, Vector[T]]
  with ActorLoggingSlf4jOverriding {

  // Requesting 0 (or fewer) elements violates the Reactive Streams spec and the batch would never flush.
  require(batchSize > 0, s"batchSize must be positive, got $batchSize")

  /** Receiver of follow-up chunks; defined once the chunked response has been started. */
  private var channel: Option[ActorRef] = None

  /** Whether a write has been sent and its `Ack` has not been received yet. */
  private var pendingWrite = false

  /**
   * Error reported by the marshaller through `chunkingCtx.handleError`.
   * The marshaller is invoked synchronously from state functions, and FSM `stop(...)` only builds
   * a state value, so the failure is recorded here and turned into a stop by the caller.
   */
  private var marshallingFailure: Option[Throwable] = None

  private val marshaller = Renderer.renderer2seqMarshaller(renderer)

  private val chunkingCtx: MarshallingContext =
    new DelegatingMarshallingContext(ctx) {
      override def marshalTo(entity: HttpEntity, headers: HttpHeader*): Unit = {
        // Log the size only: stringifying the whole payload costs O(size) even when debug is off.
        log.debug(s"Sending ${entity.data.length} bytes...")
        pendingWrite = true
        channel match {
          case Some(ch) =>
            ch ! MessageChunk(entity.data).withAck(Ack)
          case None =>
            // First write of the response: it starts the chunked message.
            channel = Some(ctx.startChunkedMessage(entity, Some(Ack), headers))
        }
      }

      override def handleError(error: Throwable): Unit = {
        // Calling `stop(...)` here would be a no-op (its result is discarded); see marshallingFailure.
        marshallingFailure = Some(error)
        ctx.handleError(error)
      }

      override def startChunkedMessage(entity: HttpEntity,
                                       sentAck: Option[Any],
                                       headers: Seq[HttpHeader])
                                      (implicit sender: ActorRef) =
        sys.error("Cannot marshal a stream of streams")
    }

  @throws[Exception](classOf[Exception])
  override def preStart(): Unit = {
    super.preStart()

    log.debug(s"Starting actor $self for stream marshalling")

    if (renderer.header.nonEmpty) {
      log.debug("Sending response header...")
      pendingWrite = true
      val entity = HttpEntity(renderer.contentType, renderer.header)
      channel = Some(ctx.startChunkedMessage(entity, Some(Ack)))
    } else {
      // No header to wait for: kick off the first request immediately.
      self ! Ack
    }
  }

  // Will request next elements from source manually
  protected def requestStrategy: RequestStrategy =
    ZeroRequestStrategy

  startWith(Running, Vector.empty[T])

  when(Running) {
    case Event(Ack, _) =>
      log.debug(s"Request next $batchSize elements from source...")
      pendingWrite = false
      request(batchSize)
      stay()

    case Event(OnNext(value: T), batch) =>
      val updated = batch :+ value
      if (updated.size < batchSize) {
        stay() using updated
      } else {
        marshaller(updated, chunkingCtx)
        // The marshaller reports failures via chunkingCtx.handleError; stop the FSM if it did.
        marshallingFailure.fold(stay() using Vector.empty[T])(e => stop(Failure(e)))
      }

    case Event(OnError(e), _) =>
      log.error("Error while marshalling stream response", e)
      chunkingCtx.handleError(e)
      stop(Failure(e))

    case Event(OnComplete, batch) =>
      log.debug("Stream completed. Will complete response...")
      if (batch.nonEmpty)
        marshaller(batch, chunkingCtx)
      else if (!pendingWrite)
        self ! Ack
      else
        log.debug("Stream completed. Write command in progress...")
      // The `Ack` of whatever write is outstanding (or the one sent above) drives SendingFooter.
      marshallingFailure.fold(goto(SendingFooter))(e => stop(Failure(e)))
  }

  when(SendingFooter) {
    case Event(Ack, _) =>
      log.debug("Sending response footer...")
      if (renderer.footer.nonEmpty) {
        // Goes through chunkingCtx so the chunked response gets started here when neither
        // a header nor any element has been written yet.
        chunkingCtx.marshalTo(HttpEntity(renderer.contentType, HttpData(renderer.footer)))
      } else {
        self ! Ack
      }
      goto(SendingChunkedMessageEnd)
  }

  when(SendingChunkedMessageEnd) {
    case Event(Ack, _) =>
      channel match {
        case Some(ch) =>
          log.debug("Sending chunked response end.")
          ch ! ChunkedMessageEnd
        case None =>
          // Header, stream and footer were all empty: no chunked response was ever started.
          log.debug("Nothing was written. Sending empty response.")
          ctx.marshalTo(HttpEntity.Empty)
      }
      stop()
  }

  whenUnhandled {
    case Event(_: Tcp.ConnectionClosed, _) =>
      log.warning("TCP connection closed. Stop streaming.")
      stop()

    case Event(other, _) =>
      log.error(s"Unexpected message $other received. Stop streaming.")
      stop()
  }

  override def postStop(): Unit = {
    log.debug(s"Stop stream marshaller $self")
    super.postStop()
  }

  // Required by Akka FSM after all state functions are registered.
  initialize()
}

object MarshallerActor {

  /** States of the `MarshallerActor` FSM. */
  sealed trait State

  /** Consuming stream elements and writing them out batch by batch. */
  case object Running extends State

  /** Stream finished; the footer is sent on the next acknowledgement. */
  case object SendingFooter extends State

  /** Footer handled; the chunked message is closed on the next acknowledgement. */
  case object SendingChunkedMessageEnd extends State

  /** Chunk sent acknowledgement message */
  private case object Ack

  /**
   * Builds `Props` for a stream marshaller actor.
   *
   * @param renderer  renderer for consumed values
   * @param ctx       marshalling context the HTTP response is written to
   * @param batchSize number of consumed elements sent in a single chunk
   */
  def props[T: ClassTag](renderer: Renderer[T], ctx: MarshallingContext, batchSize: Int): Props =
    Props(new MarshallerActor[T](renderer, ctx, batchSize))
}