package ru.yandex.tours.db.dao

import java.sql.Timestamp

import org.joda.time.DateTime
import org.xerial.snappy.Snappy
import ru.yandex.tours.db.dao.HotelsDao.{InGridPoint, SkipDeleted, WithIds}
import ru.yandex.tours.db.model._
import ru.yandex.tours.db.tables.{HotelCoordinates, HotelUrls, Hotels}
import ru.yandex.tours.db.{DBWrapper, GridPoint, HotelStableIds, Target}
import ru.yandex.tours.geo
import ru.yandex.tours.geo.Geohash
import ru.yandex.tours.model.BaseModel.Point
import ru.yandex.tours.model.MapRectangle
import ru.yandex.tours.model.hotels.HotelsHolder.PartnerHotel
import ru.yandex.tours.model.hotels.Partners._
import ru.yandex.tours.model.util.Paging
import ru.yandex.tours.util.Logging
import ru.yandex.tours.util.collections.RafBasedMap
import slick.driver.MySQLDriver.api._
import slick.jdbc.TransactionIsolation

import scala.concurrent.{ExecutionContext, Future}
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

/**
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 22.03.16
 */
class HotelsDao(db: DBWrapper)(implicit ec: ExecutionContext) extends Logging {

  db.createIfNotExists(Hotels.table)
  db.createIfNotExists(HotelUrls.table)
  db.createIfNotExists(HotelCoordinates.table)

  /** The unfiltered hotels table as a query. */
  private def emptyQuery: Query[Hotels, DbPartnerHotel, Seq] = Hotels.table

  /**
   * Applies `queries` left to right on top of `emptyQuery`.
   * Every read method goes through this so filters are interpreted the same way everywhere.
   */
  private def filtered(queries: Seq[HotelsDao.Query]): Query[Hotels, DbPartnerHotel, Seq] =
    queries.foldLeft(emptyQuery) { (q, filter) => filter.apply(q) }

  /**
   * Loads every hotel matching `queries` into a [[RafBasedMap]] keyed by hotel id
   * (values stored Snappy-compressed), fetching in id-ordered pages, then freezes the map.
   *
   * If loading, logging or freezing fails, the map is closed before the returned future
   * completes with that failure, so a half-built map is not leaked.
   */
  def retrieveRafMap(queries: HotelsDao.Query*)
                    (implicit ec: ExecutionContext = this.ec): Future[RafBasedMap[Int, PartnerHotel]] = {
    val map = new RafBasedMap[Int, PartnerHotel](
      hotel => Snappy.compress(hotel.toByteArray),
      array => PartnerHotel.parseFrom(Snappy.uncompress(array))
    )
    retrieveRafMapInternal(map, 0, queries: _*)
      .map { _ =>
        val queriesString = if (queries.isEmpty) "<empty>" else queries.mkString(", ")
        log.info(s"Got ${map.size} hotels with filter $queriesString")
        map.freeze()
      }
      .andThen {
        // Covers failures from any stage above, including freeze() itself.
        case Failure(_) => map.close()
      }
  }

  /**
   * Fetches one id-ordered page of hotels starting at `offset`, puts them into `map`,
   * and recurses to the next page until a short page signals the end.
   */
  private def retrieveRafMapInternal(map: RafBasedMap[Int, PartnerHotel],
                                     offset: Int,
                                     queries: HotelsDao.Query*)
                                    (implicit ec: ExecutionContext): Future[Unit] = {
    val limit = 50000
    val page = filtered(queries).sortBy(_.id).drop(offset).take(limit).result
    db.run(page).flatMap { hotels =>
      hotels.foreach(h => map += h.id -> h.hotel)
      if (hotels.size < limit) Future.successful(())
      else retrieveRafMapInternal(map, offset + limit, queries: _*)
    }
  }

  /** Returns the ids of all hotels matching `queries`, in ascending id order. */
  def getHotelIds(queries: HotelsDao.Query*): Future[Seq[Int]] = {
    val ids = mutable.Buffer[Int]()
    getHotelIdsInternal(ids, 0, queries: _*).map(_ => ids)
  }

  /**
   * Appends one page of matching ids (starting at `offset`) to `retrievedIds`
   * and recurses until a short page is returned.
   */
  private def getHotelIdsInternal(retrievedIds: mutable.Buffer[Int],
                                  offset: Int, queries: HotelsDao.Query*): Future[Unit] = {
    val limit = 300000
    // ORDER BY must be applied before OFFSET/LIMIT. Otherwise each page is cut from an
    // unordered result, and ids can be skipped or returned twice across pages.
    val page = filtered(queries).sortBy(_.id).drop(offset).take(limit).map(_.id).result
    db.run(page).flatMap { ids =>
      retrievedIds.appendAll(ids)
      if (ids.size < limit) Future.successful(())
      else getHotelIdsInternal(retrievedIds, offset + limit, queries: _*)
    }
  }

  /** Fetches a single hotel by id, if present. */
  def get(id: Int): Future[Option[DbPartnerHotel]] = {
    val q = Hotels.table.filter(_.id === id)
    db.run(q.result).map(_.headOption)
  }

  /** Fetches all hotels whose id is in `ids`. */
  def get(ids: Iterable[Int]): Future[Seq[DbPartnerHotel]] = {
    val q = Hotels.table.filter(_.id.inSet(ids))
    db.run(q.result)
  }

  /**
   * Returns page `paging.page` (zero-based) of hotels matching `queries`,
   * with at most `paging.pageSize` rows per page.
   * Rows are ordered by id so that page boundaries are deterministic.
   */
  def get(paging: Paging, queries: HotelsDao.Query*): Future[Seq[DbPartnerHotel]] = {
    val page = filtered(queries)
      .sortBy(_.id)
      .drop(paging.pageSize * paging.page)
      .take(paging.pageSize)
    db.run(page.result)
  }

  /**
   * Loads non-deleted hotels whose coordinates fall in the grid cells returned by
   * `GridPoint.fromPoint(point).nearPoints`.
   */
  def getNear(point: Point): Future[RafBasedMap[Int, PartnerHotel]] = {
    val gridPoints = GridPoint.fromPoint(point).nearPoints
    retrieveRafMap(InGridPoint(gridPoints: _*), SkipDeleted)
  }

  /** Counts hotels matching `queries`. */
  def count(queries: HotelsDao.Query*): Future[Int] =
    db.run(filtered(queries).length.result)

  /**
   * Streams hotels matching `queries` through `callback`, one row at a time,
   * inside a READ COMMITTED transaction.
   */
  def traverse[U](queries: HotelsDao.Query*)(callback: (DbPartnerHotel) => U): Future[Unit] = {
    val action = filtered(queries).result
      .transactionally
      .withTransactionIsolation(TransactionIsolation.ReadCommitted)

    db.stream(action).foreach { hotel =>
      callback(hotel)
    }
  }

  /** Streams every hotel, unfiltered, through `callback`. */
  def traverseAll[U](callback: (DbPartnerHotel) => U): Future[Unit] = {
    traverse()(callback)
  }

  /** Inserts `hotels` together with their coordinate and URL index rows in one transaction. */
  def insert(hotels: Iterable[PartnerHotel]): Future[Unit] = {
    val insertAction = Hotels.table ++= hotels.map(DbPartnerHotel.apply)
    val q = DBIO.seq(
      insertAction,
      indexAction(hotels)
    ).transactionally
    db.run(q)
  }

  /** Deletes the coordinate and URL index rows of the given hotels. */
  private def cleanAction(hotels: Iterable[PartnerHotel]) = {
    val ids = hotels.map(_.getId)
    DBIO.seq(
      HotelCoordinates.table.filter(_.id.inSet(ids)).delete,
      HotelUrls.table.filter(_.id.inSet(ids)).delete
    )
  }

  /**
   * Builds index rows for the given hotels. A coordinate row (point plus its grid cell) is
   * added for hotels whose raw hotel has a point, and a URL row for those that have a
   * partner URL.
   */
  private def indexAction(hotels: Iterable[PartnerHotel]) = {
    val points = Seq.newBuilder[(Int, Point, GridPoint)]
    val urls = Seq.newBuilder[(Int, String)]

    hotels.foreach { hotel =>
      val raw = hotel.getRawHotel
      if (raw.hasPoint) {
        val point = raw.getPoint
        points += ((hotel.getId, point, GridPoint.fromPoint(point)))
      }
      if (raw.hasPartnerUrl) {
        urls += hotel.getId -> raw.getPartnerUrl
      }
    }
    DBIO.seq(
      HotelCoordinates.table ++= points.result(),
      HotelUrls.table ++= urls.result()
    )
  }

  /** Rewrites the stored hotel rows via `Hotels.batchUpdate`. */
  private def updateAction(hotels: Iterable[PartnerHotel]) = {
    Hotels.batchUpdate(hotels)
  }

  /** Updates a single hotel; see the `Iterable` overload. */
  def update(hotel: PartnerHotel): Future[Unit] = {
    update(Seq(hotel))
  }

  /**
   * Replaces the index rows and updates the hotel rows for `hotels` in one transaction
   * on the master. Hotels are processed in id order; presumably this is so that
   * concurrent updates take row locks in a consistent order (TODO confirm).
   */
  def update(hotels: Iterable[PartnerHotel]): Future[Unit] = {
    val sorted = hotels.toSeq.sortBy(_.getId)
    val q = DBIO.seq(
      cleanAction(sorted),
      indexAction(sorted),
      updateAction(sorted)
    ).transactionally
    db.run(q)(Target.master)
  }

  /**
   * Rebuilds only the coordinate/URL index rows for `hotels`, in one transaction.
   * Internal API; exposed for tools.
   */
  def _updateIndex(hotels: Iterable[PartnerHotel]): Future[Unit] = {
    val q = DBIO.seq(
      cleanAction(hotels),
      indexAction(hotels)
    ).transactionally
    db.run(q)
  }

  /** Clears the `isNew` flag for the given hotel ids; the future yields the number of rows updated. */
  def publish(hotelIds: Iterable[Int]): Future[Int] = {
    val q = for (h <- Hotels.table if h.id.inSet(hotelIds)) yield h.isNew
    db.run(q.update(false))
  }

}

object HotelsDao {
  /** The Slick hotels query that every [[Query]] filter rewrites. */
  type Q = slick.driver.MySQLDriver.api.Query[Hotels, DbPartnerHotel, Seq]

  /** A composable hotel filter: maps a hotels query to a narrowed hotels query. */
  sealed trait Query {
    protected[dao] def apply(q: Q): Q
  }

  /** Keeps hotels whose `updated` timestamp lies in the half-open interval [start, end). */
  case class UpdatedBetween(start: DateTime, end: DateTime) extends Query {
    override protected[dao] def apply(q: Q): Q = {
      val from = new Timestamp(start.getMillis)
      val until = new Timestamp(end.getMillis)
      q.filter(hotel => hotel.updated >= from && hotel.updated < until)
    }
  }

  /** Drops hotels flagged as deleted. */
  case object SkipDeleted extends Query {
    override protected[dao] def apply(q: Q): Q =
      q.filter(_.isDeleted === false)
  }

  /**
   * Keeps hotels whose id is in `ids`. The ids are split into chunks of 1000, each becoming
   * its own IN (...) clause OR-ed together (presumably to keep individual IN lists bounded).
   */
  case class WithIds(ids: Iterable[Int]) extends Query {
    override protected[dao] def apply(q: Q): Q = {
      q.filter { hotel =>
        val chunks = ids.grouped(1000)
        chunks.foldLeft(false: Rep[Boolean]) { (matched, chunk) =>
          matched || hotel.id.inSet(chunk)
        }
      }
    }

    override def toString: String = s"WithIds(${ids.size} ids: ${ids.take(5)}...)"
  }

  /** Keeps hotels that have a stable-id row belonging to `partner`. */
  case class OnlyPartner(partner: Partner) extends Query {
    override protected[dao] def apply(q: Q): Q = {
      val withStableIds = q.join(HotelStableIds.table).on(_.id === _.id)
      val ofPartner = withStableIds.filter(_._2.partner === partner.id)
      ofPartner.map(_._1)
    }
  }

  /** Keeps hotels whose `isNew` flag equals `isNew`. */
  case class IsNew(isNew: Boolean) extends Query {
    override protected[dao] def apply(q: Q): Q =
      q.filter(_.isNew === isNew)
  }

  /**
   * Keeps hotels whose stored geohash falls into any of the ranges produced by
   * `Geohash.approximate(span)`.
   */
  case class InSpan(span: MapRectangle) extends Query {
    private val approx = Geohash.approximate(span)

    private def in(geohash: Rep[Long]): Rep[Boolean] =
      approx.foldLeft(false: Rep[Boolean]) { (matched, range) =>
        matched || geohash.between(range._1, range._2)
      }

    override protected[dao] def apply(q: Q): Q = {
      val withCoordinates = q.join(HotelCoordinates.table).on(_.id === _.id)
      val inside = withCoordinates.filter(pair => in(pair._2.geohash))
      inside.map(_._1)
    }
  }

  /** Keeps hotels whose coordinate row's grid index is one of `points`' ids. */
  case class InGridPoint(points: GridPoint*) extends Query {
    override protected[dao] def apply(q: Q): Q = {
      val gridIds = points.map(_.id)
      val withCoordinates = q.join(HotelCoordinates.table).on(_.id === _.id)
      val inCells = withCoordinates.filter(_._2.gridIndex.inSet(gridIds))
      inCells.map(_._1)
    }
  }

  /** Keeps hotels that have a URL row exactly equal to `url`. */
  case class WithUrl(url: String) extends Query {
    override protected[dao] def apply(q: Q): Q = {
      val withUrls = q.join(HotelUrls.table).on(_.id === _.id)
      val matching = withUrls.filter(_._2.url === url)
      matching.map(_._1)
    }
  }
}