package ru.yandex.tours.util

import _root_.spray.http.ContentTypes._
import _root_.spray.http.HttpHeaders.Location
import _root_.spray.http._
import _root_.spray.httpx.marshalling.Marshaller
import _root_.spray.routing.Directives._
import _root_.spray.routing._
import com.google.protobuf.Message
import com.googlecode.protobuf.format.{HtmlFormat, JsonFormat}
import org.apache.commons.io.output.ByteArrayOutputStream
import org.json.{JSONArray, JSONObject}
import ru.yandex.util.spray.Protobuf
import shapeless.HList

/* @author berkut@yandex-team.ru */

package object spray extends Logging {

  /** Adds `/` operators to [[Uri]] that append one more segment to the URI's path. */
  implicit class RichUri(uri: Uri) {

    /** Returns a copy of the URI whose path has `segment` appended. */
    def /(segment: String): Uri = {
      val extendedPath = uri.path / segment
      uri.copy(path = extendedPath)
    }

    /** Same as the `String` variant, using the decimal rendering of `segment`. */
    def /(segment: Long): Uri = this / segment.toString
  }

  /**
   * Route that answers `200 OK` with `json` as the payload and an empty error list.
   *
   * @param json array placed under the response's data key
   */
  def completeJsonOk(json: JSONArray): RequestContext => Unit =
    completeJson(StatusCodes.OK, data = json, errors = new JSONArray())

  /**
   * Route that answers with `StatusCodes.PermanentRedirect`, a `Location` header pointing
   * at `uri`, `json` as the payload and an empty error list.
   *
   * @param uri  redirect target put into the `Location` header
   * @param json array placed under the response's data key
   */
  def completeJsonRedirect(uri: Uri, json: JSONArray): RequestContext => Unit = {
    val locationHeader = Location(uri)
    respondWithHeader(locationHeader) {
      completeJson(StatusCodes.PermanentRedirect, data = json, errors = new JSONArray())
    }
  }

  /**
   * Single-object variant of the redirect response: `json` is wrapped into a one-element
   * array before being handed to the array-based overload.
   *
   * @param uri  redirect target
   * @param json object to return as the only payload element
   */
  def completeJsonRedirect(uri: Uri, json: JSONObject): RequestContext => Unit = {
    val payload = toArray(json)
    completeJsonRedirect(uri, payload)
  }

  /**
   * Single-object variant of the OK response: `json` is wrapped into a one-element array
   * before being handed to the array-based overload.
   *
   * @param json object to return as the only payload element
   */
  def completeJsonOk(json: JSONObject): RequestContext => Unit = {
    val payload = toArray(json)
    completeJsonOk(payload)
  }

  /**
   * Route that answers with status `sc`, an empty payload and `json` as the error list.
   *
   * @param sc   status code of the response
   * @param json array placed under the response's errors key
   */
  def completeJsonError(sc: StatusCode, json: JSONArray): RequestContext => Unit =
    completeJson(sc, data = new JSONArray(), errors = json)

  /**
   * Route that answers with status `sc` and a single error message.
   *
   * @param sc    status code of the response
   * @param error message that becomes the only element of the error list
   */
  def completeJsonError(sc: StatusCode, error: String): RequestContext => Unit = {
    val errors = new JSONArray().put(error)
    completeJsonError(sc, errors)
  }

  /**
   * Shared builder for the JSON envelope: an object with a `data` array and an `errors`
   * array, completed as a string under the `application/json` media type.
   *
   * @param statusCode status code of the response
   * @param data       value stored under the `data` key
   * @param errors     value stored under the `errors` key
   */
  private def completeJson(statusCode: StatusCode, data: JSONArray, errors: JSONArray): RequestContext => Unit =
    respondWithMediaType(MediaTypes.`application/json`) {
      val envelope = new JSONObject()
      envelope.put("data", data)
      envelope.put("errors", errors)
      complete(statusCode, envelope.toString)
    }

  /**
   * Logs `e` together with `message` and answers the request with
   * `500 Internal Server Error`. A token from `Randoms.nextString(16)` is written to both
   * the log line and the response so that the two can be correlated.
   *
   * @param message description of what failed; echoed to the client
   * @param e       the failure to log
   * @param http    request context to complete
   */
  def completeErrorWithMessage(message: String, e: Throwable)(http: RequestContext): Unit = {
    val incidentId = Randoms.nextString(16)
    log.error(s"$incidentId: $message", e)
    val errorRoute = completeJsonError(StatusCodes.InternalServerError, s"$message! See logs $incidentId")
    errorRoute(http)
  }

  /**
   * Marshaller for protobuf messages. Depending on the content type it is invoked with,
   * the message is rendered as JSON text, as HTML text, or as its raw serialized bytes.
   */
  private implicit val messageMarshaller: Marshaller[Message] =
    Marshaller.of(
      Protobuf.contentType,
      `application/octet-stream`, NoContentType,
      `application/json`, ContentType(MediaTypes.`text/html`)
    ) {
      case (msg, ct, ctx) =>
        // JSON is compared on the full content type, HTML only on the media type.
        val entity =
          if (ct == `application/json`) HttpEntity(ct, JsonFormat.printToString(msg))
          else if (ct.mediaType == MediaTypes.`text/html`) HttpEntity(ct, HtmlFormat.printToString(msg))
          // NOTE(review): every remaining content type, including Protobuf.contentType,
          // is labelled application/octet-stream here — confirm this is intended.
          else HttpEntity(`application/octet-stream`, msg.toByteArray)
        ctx.marshalTo(entity)
    }

  /**
   * Route that completes the request with `message`, rendered by the implicit protobuf
   * marshaller of this package object.
   *
   * @param message protobuf message to send
   */
  def completeProto(message: Message): RequestContext => Unit =
    complete(message)

  /**
   * Route that completes the request with the given message when present, or with an
   * empty `application/octet-stream` entity when it is absent.
   *
   * @param optMessage optional protobuf message to send
   */
  def completeProto(optMessage: Option[Message]): RequestContext => Unit =
    optMessage.fold[RequestContext => Unit] {
      complete(HttpEntity(`application/octet-stream`, Array.emptyByteArray))
    } { message =>
      complete(message)
    }

  /**
   * Route that completes the request with all `messages` written one after another via
   * `writeDelimitedTo`, under the `application/octet-stream` media type.
   *
   * @param messages messages to serialize, in order
   */
  def completeProtoSeq(messages: Seq[Message]): RequestContext => Unit = {
    val binaryMediaType = `application/octet-stream`.mediaType
    respondWithMediaType(binaryMediaType) {
      val payload = IO.using(new ByteArrayOutputStream()) { buffer =>
        for (message <- messages) {
          message.writeDelimitedTo(buffer)
        }
        buffer.toByteArray
      }
      complete(payload)
    }
  }

  /**
   * Maps a message observed as the outcome of a request to a status code.
   *
   * Plain and chunked-start responses yield their own status, a `Rejected` yields
   * `BadRequest`, chunk and chunk-end messages yield `Continue`; anything else is logged
   * at debug level and yields `ServiceUnavailable`.
   *
   * @param response the message to classify
   * @return the status code associated with `response`
   */
  def responseToStatusCode(response: Any): StatusCode = response match {
    case plain: HttpResponse =>
      plain.status
    case _: Rejected =>
      StatusCodes.BadRequest
    case Confirmed(ChunkedResponseStart(HttpResponse(startStatus, _, _, _)), _) =>
      startStatus
    case Confirmed(_: MessageChunk, _) =>
      StatusCodes.Continue
    case _: ChunkedMessageEnd =>
      StatusCodes.Continue
    case unknown =>
      log.debug(s"Unknown message $unknown")
      StatusCodes.ServiceUnavailable
  }

  /**
   * Single-value convenience over `hextract`: runs `directive` against `query` and
   * returns the one value it extracts.
   *
   * @param query     query string to feed to the directive
   * @param directive directive extracting exactly one value
   * @return the extracted value; failures surface as the errors raised by `hextract`
   */
  def extract[T](query: Uri.Query, directive: Directive1[T]): T = {
    val extracted = hextract(query, directive)
    extracted.head
  }

  /**
   * Runs `directive` against a synthetic `GET /` request carrying `query` and returns the
   * values the directive extracts.
   *
   * @param query     query string to feed to the directive
   * @param directive directive whose extracted values are wanted
   * @tparam T        HList type of the extracted values
   * @return the first set of values the directive hands to its inner route
   * @throws RuntimeException (raised via `sys.error`) if the directive rejects the request
   *                          or never invokes its inner route
   */
  def hextract[T <: HList](query: Uri.Query, directive: Directive[T]): T = {
    // The extraction is captured in a local variable rather than with a `return` inside
    // the lambda: a non-local return is implemented by throwing a control exception
    // through the directive's own code, which breaks as soon as anything on the way
    // catches Throwable.
    var captured: Option[T] = None
    val capturingRoute = directive happly { extracted =>
      // Keep the first extraction only, matching the previous "return on first hit".
      if (captured.isEmpty) captured = Some(extracted)
      (_: RequestContext) => ()
    }

    // There is no responder behind the synthetic request (hence `null`), so rejections
    // are turned into errors instead of being sent anywhere.
    val syntheticContext: RequestContext = new RequestContext(
      HttpRequest(HttpMethods.GET, Uri("/").withQuery(query)),
      null,
      Uri.Path.Empty
    ) {
      override def reject(rejection: Rejection): Unit = sys.error(s"$query not matched directive: $rejection")
      override def reject(rejections: Rejection*): Unit = sys.error(s"$query not matched directive: ${rejections.mkString(", ")}")
    }

    capturingRoute(syntheticContext)
    captured.getOrElse(sys.error(s"$query not matched directive"))
  }

  /** Wraps a single JSON object into a new one-element JSON array. */
  private def toArray(x: JSONObject): JSONArray = {
    val wrapper = new JSONArray()
    wrapper.put(x)
  }
}
