package ru.yandex.tours.geo

import ru.yandex.tours.model.MapRectangle
import ru.yandex.tours.model.geo.MapObject
import ru.yandex.tours.model.BaseModel.Point

/**
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 19.02.16
 */
object Geohash {

  // A geohash here is a Long built from 32 bisection steps. Every step contributes a longitude
  // bit followed by a latitude bit (most significant first), giving 64 interleaved bits.
  // The raw bit pattern is then shifted by Long.MinValue, which flips the top bit so that
  // ordinary signed comparison of hashes orders them like the unsigned interleaved bit strings.

  /**
   * Single bisection step of the encoder.
   *
   * @return `(bit, newMin, newMax)`: bit is 1 and the interval becomes `[mean, max]` when
   *         `value >= mean`, otherwise bit is 0 and the interval becomes `[min, mean]`
   */
  private def bit(value: Double, min: Double, max: Double): (Byte, Double, Double) = {
    val mean = (max + min) / 2
    if (value >= mean) (1, mean, max)
    else (0, min, mean)
  }

  /**
   * Single bisection step of the decoder: the inverse of [[bit]].
   *
   * @return the upper half of `[min, max]` when `bit > 0`, the lower half otherwise
   */
  private def select(bit: Byte, min: Double, max: Double): (Double, Double) = {
    val mean = (max + min) / 2
    if (bit > 0) (mean, max)
    else (min, mean)
  }

  /**
   * Midpoint of the inclusive range `[lower, upper]`, rounded down. Requires `lower <= upper`.
   *
   * `upper - lower` may wrap around, but its bit pattern is still the exact unsigned distance,
   * so the unsigned shift halves it correctly even for `[Long.MinValue, Long.MaxValue]`.
   * (Halving both ends separately with `/` is wrong for negative odd values because integer
   * division truncates toward zero.)
   */
  private def midpoint(lower: Long, upper: Long): Long =
    lower + ((upper - lower) >>> 1)

  /**
   * Merges overlapping or directly adjacent inclusive ranges.
   *
   * @param sorted ranges sorted by lower bound
   * @return merged ranges in the same order; every input range is covered by exactly one output
   *         range, including the last one
   */
  private def mergeRanges(sorted: Seq[(Long, Long)]): Seq[(Long, Long)] = {
    sorted.headOption match {
      case None => sorted
      case Some((firstMin, firstMax)) =>
        val res = Seq.newBuilder[(Long, Long)]
        var lastMin = firstMin
        var lastMax = firstMax
        sorted.tail.foreach { case (min, max) =>
          // `lastMax + 1` would overflow at Long.MaxValue; such a range already reaches the end.
          if (lastMax == Long.MaxValue || min <= lastMax + 1) {
            // Extend the current range; never let a contained range shrink it.
            lastMax = math.max(lastMax, max)
          } else {
            res += (lastMin -> lastMax)
            lastMin = min
            lastMax = max
          }
        }
        // The range being accumulated is always pending at this point and must be emitted.
        res += (lastMin -> lastMax)
        res.result()
    }
  }

  /**
   * Encodes a coordinate pair into a 64-bit geohash.
   *
   * @param lat latitude, bisected within [-90, 90]
   * @param lon longitude, bisected within [-180, 180]
   * @return interleaved hash (longitude bit first), offset by `Long.MinValue`
   */
  def encode(lat: Double, lon: Double): Long = {
    var result = 0L

    var (latMin, latMax) = (-90d, 90d)
    var (lonMin, lonMax) = (-180d, 180d)

    for (_ <- 0 until 32) {
      val (lonBit, newLonMin, newLonMax) = bit(lon, lonMin, lonMax)
      lonMin = newLonMin
      lonMax = newLonMax
      result = result * 2 + lonBit

      val (latBit, newLatMin, newLatMax) = bit(lat, latMin, latMax)
      latMin = newLatMin
      latMax = newLatMax
      result = result * 2 + latBit
    }
    // Flip the top bit so that signed ordering follows the interleaved bit order.
    result + Long.MinValue
  }

  /** Encodes the latitude/longitude of `point`; see `encode(lat, lon)`. */
  def encodePoint(point: Point): Long = {
    encode(point.getLatitude, point.getLongitude)
  }

  /** Encodes the latitude/longitude of `obj`; see `encode(lat, lon)`. */
  def encode(obj: MapObject): Long = encode(obj.latitude, obj.longitude)

  /**
   * Decodes a geohash produced by `encode`.
   *
   * @return `(latitude, longitude)` of the center of the cell remaining after all 32 bisections
   */
  def decode(geohash: Long): (Double, Double) = {
    // Undo the offset applied by `encode` to get the raw interleaved bits back.
    val hash = geohash - Long.MinValue
    var (latMin, latMax) = (-90d, 90d)
    var (lonMin, lonMax) = (-180d, 180d)

    for (i <- 31 to 0 by -1) {
      // Each pair of bits holds the longitude bit (higher) and the latitude bit (lower).
      val lonBit = (hash >> (i * 2 + 1)) & 1
      val (newLonMin, newLonMax) = select(lonBit.toByte, lonMin, lonMax)
      lonMin = newLonMin
      lonMax = newLonMax
      val latBit = (hash >> (i * 2)) & 1
      val (newLatMin, newLatMax) = select(latBit.toByte, latMin, latMax)
      latMin = newLatMin
      latMax = newLatMax
    }

    ((latMin + latMax) / 2, (lonMin + lonMax) / 2)
  }

  /** Decodes `geohash` (see `decode`) into a `Point`. */
  def decodePoint(geohash: Long): Point = {
    val (lat, lon) = decode(geohash)
    Point.newBuilder().setLatitude(lat).setLongitude(lon).build()
  }


  /**
   * Approximates `rect` by a set of geohash ranges.
   *
   * The whole hash space is recursively halved. A half is emitted when its bounding box is
   * contained in `rect` or when `maxLevel` halvings have been made, it is discarded when its
   * bounding box does not intersect `rect`, and it is split further otherwise.
   * Note: a cell that reaches `maxLevel` is emitted without an intersection check.
   *
   * @param rect     rectangle to approximate
   * @param maxLevel maximum recursion depth (number of halvings of the hash space)
   * @return inclusive `(min, max)` geohash ranges, sorted, with overlapping and adjacent
   *         ranges merged
   */
  def approximate(rect: MapRectangle, maxLevel: Int = 20): Seq[(Long, Long)] = {
    def doApprox(lower: Long, upper: Long, level: Int): Seq[MapRectangle] = {
      val (minLat, minLon) = decode(lower)
      val (maxLat, maxLon) = decode(upper)
      val bound = MapRectangle.byBoundaries(minLon, minLat, maxLon, maxLat)

      if (rect.contains(bound)) Seq(bound)
      else if (level >= maxLevel) Seq(bound)
      else if (!bound.intersect(rect)) Seq.empty
      else if (lower == upper) Seq(bound) // a single hash cannot be split any further
      else {
        val mid = midpoint(lower, upper)
        doApprox(lower, mid, level + 1) ++ doApprox(mid + 1, upper, level + 1)
      }
    }

    val sorted = doApprox(Long.MinValue, Long.MaxValue, level = 0)
      .map(r => encode(r.minLat, r.minLon) -> encode(r.maxLat, r.maxLon))
      .sorted

    mergeRanges(sorted)
  }
}
