package ru.yandex.tours.prices

import java.util.concurrent.atomic.AtomicReference

import org.joda.time.format.DateTimeFormat
import org.joda.time.{DateTime, LocalDate}
import ru.yandex.tours.clickhouse.ClickHouseClient
import ru.yandex.tours.geo.base.region.Tree
import ru.yandex.tours.geo.mapping.GeoMappingHolder
import ru.yandex.tours.hotels.HotelsIndex
import ru.yandex.tours.model.BaseModel.Currency
import ru.yandex.tours.model.Languages
import ru.yandex.tours.model.search.SearchProducts.{HotelSnippet, Offer}
import ru.yandex.tours.model.search.{EmptySearchFilter, HotelSearchRequest, OfferSearchRequest}
import ru.yandex.tours.util.lang.Dates._
import ru.yandex.tours.util.parsing.Tabbed
import ru.yandex.tours.util.zoo.SharedValue
import ru.yandex.tours.util.{Logging, Metrics}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext}
import scala.util.{Failure, Success}

/**
 * Buffers observed tour prices and bulk-inserts them into the ClickHouse `price_history` table.
 *
 * Observations arrive through the two `+=` overloads:
 *  - hotel (direction) searches produce one record per (hotel, date, nights) group of snippets;
 *  - offer searches for a single hotel produce one record per (date, nights) group of offers.
 * When nothing indexable was found, a price-less record is stored for every date x nights
 * combination of the request, so the absence of prices is recorded as well.
 *
 * Records are kept in memory and written by [[flush]], which `+=` triggers automatically once the
 * buffer holds `flushAfter` records. Every buffer access is serialized on `bufferLock`, so `+=`
 * and [[flush]] may be called concurrently.
 *
 * @param clickHouseClient client used to run the insert; also supplies the target database name
 * @param hotelsIndex      resolves hotel stars and geo id when rows are serialized (at flush time)
 * @param tree             region tree used to expand a hotel's geo id into its path to the root
 * @param geoMapping       keeps only regions that are known destinations
 * @param enableWrites     when `false`, both `+=` overloads are no-ops
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 30.06.15
 */
class PriceHistoryStorage(clickHouseClient: ClickHouseClient, hotelsIndex: HotelsIndex,
                          tree: Tree, geoMapping: GeoMappingHolder,
                          enableWrites: SharedValue[Boolean])(implicit ec: ExecutionContext) extends Logging {

  /** Initial capacity of every buffer generation. */
  private val bufferSize = 200000
  /** Buffer size at which `+=` triggers a flush; presumably kept below `bufferSize` for headroom. */
  private val flushAfter = bufferSize - 1000

  /**
   * Guards the buffer held in `buffer`. `ArrayBuffer` is not thread-safe, and an atomic swap alone
   * does not stop a concurrent `+=` from appending to a buffer that [[flush]] is already
   * serializing, so every append and every swap happens while holding this lock.
   */
  private val bufferLock = new Object
  private val buffer = new AtomicReference(new ArrayBuffer[Record](bufferSize))
  private val metrics = Metrics("price.storage")
  private val flushSize = metrics.getHistogram("flush.size")
  private val flushTimer = metrics.getTimer("flush.time")

  // Original DDL of the target table. The ALTER statements after it add the `to` and
  // `is_direction_search` columns, which the insert in `flush` writes as well.
  /*
  clickHouseClient.update(
    """create table if not exists tours.price_history (
      |  date Date,
      |  timestamp DateTime,
      |
      |  is_partial UInt8,
      |
      |  from UInt32,
      |  hotel_id UInt32,
      |  hotel_stars UInt8,
      |  regions Array(UInt32),
      |  when Date,
      |  nights UInt8,
      |  ages String,
      |  ages_array Array(UInt8),
      |
      |  min_price Int32,
      |  pansion Int8,
      |
      |  prices Nested
      |  (
      |    operator_id UInt32,
      |    pansion Int8,
      |    price Int32
      |  )
      |) ENGINE = MergeTree(date, (timestamp, hotel_id, from, ages, when, nights), 8192)
    """.stripMargin)
  */
  // alter table tours.price_history add column to UInt32 after from
  // alter table tours.price_history add column is_direction_search UInt8 after is_partial

  private val dateFormat = DateTimeFormat.forPattern("yyyy-MM-dd")
  private val dateTimeFormat = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss")

  /** A single operator's price for one pansion. */
  private case class Price(operatorId: Int, pansionId: Int, price: Int)

  /**
   * One `price_history` row: the prices observed for a hotel (or, with `hotelId = 0`, for a whole
   * direction) on a given tour date and number of nights.
   */
  private case class Record(timestamp: DateTime,
                            isPartial: Boolean,
                            isDirectionSearch: Boolean,
                            from: Int,
                            hotelId: Int,
                            to: Int,
                            when: LocalDate,
                            nights: Int,
                            ages: Array[Int],
                            prices: Array[Price]) {
    /**
     * Serializes this record as a TabSeparated row whose column order matches the insert in
     * `flush`. A record without prices is encoded as `min_price = -2` and `pansion = -1`.
     */
    def toRowString: String = {
      // The first cheapest price (same element `minBy` picks), computed once for both columns.
      val cheapest = if (prices.isEmpty) None else Some(prices.minBy(_.price))
      val minPrice = cheapest.fold(-2)(_.price)
      val minPansion = cheapest.fold(-1)(_.pansionId)
      val hotel = hotelsIndex.getHotelById(hotelId)
      // Known-destination regions on the hotel's path to the root; when the hotel is unknown or
      // none of its regions is a known destination, fall back to the requested destination.
      val regions = hotel.map(_.geoId).toSeq
        .flatMap(tree.pathToRoot).map(_.id).filter(geoMapping.isKnownDestination)
      val rowRegions = if (regions.nonEmpty) regions else Seq(to)
      val sortedAges = ages.sorted

      Tabbed(
        timestamp.toString(dateFormat),
        timestamp.toString(dateTimeFormat),
        if (isPartial) "1" else "0",
        if (isDirectionSearch) "1" else "0",
        from,
        to,
        hotelId,
        hotel.fold(0)(_.star.id),
        rowRegions.mkString("[", ",", "]"),
        when.toString(dateFormat),
        nights,
        sortedAges.mkString(","),
        sortedAges.mkString("[", ",", "]"),
        minPrice,
        minPansion,
        prices.map(_.operatorId).mkString("[", ",", "]"),
        prices.map(_.pansionId).mkString("[", ",", "]"),
        prices.map(_.price).mkString("[", ",", "]")
      )
    }
  }

  /** Only RUB / Russian-language searches from a known origin with no extra filters are stored. */
  private def shouldIndex(req: HotelSearchRequest): Boolean = {
    req.from != 0 &&
      req.currency == Currency.RUB &&
      req.lang == Languages.ru &&
      req.filter == EmptySearchFilter
  }

  /** Only offers that include a flight and last at least one night are stored. */
  private def shouldIndex(offer: Offer): Boolean = {
    offer.getNights > 0 && offer.getWithFlight
  }

  /**
   * Only unambiguous snippets are stored: with a flight, a single fixed nights count (at least
   * one), a single date and exactly one source (so `getSource(0)` identifies the operator).
   */
  private def shouldIndex(snippet: HotelSnippet): Boolean = {
    snippet.getWithFlight &&
      snippet.getNightsMin >= 1 &&
      snippet.getNightsMin == snippet.getNightsMax &&
      snippet.getDateMin == snippet.getDateMax &&
      snippet.getSourceCount == 1
  }

  /**
   * Builds one price-less record for every date x nights combination of `request`, so that
   * "nothing found" is recorded as well.
   */
  private def emptyRecords(ts: DateTime,
                           request: HotelSearchRequest,
                           hotelId: Int,
                           isPartial: Boolean,
                           isDirectionSearch: Boolean): Iterable[Record] = {
    val records = ArrayBuffer.empty[Record]
    for {
      date <- request.dateRange
      nights <- request.nightsRange
    } {
      records += Record(
        timestamp = ts,
        isPartial = isPartial,
        isDirectionSearch = isDirectionSearch,
        from = request.from,
        hotelId = hotelId,
        to = request.to,
        when = date,
        nights = nights,
        ages = request.agesSerializable.toArray,
        prices = Array.empty
      )
    }
    records
  }

  /**
   * Appends `records` to the current buffer while holding `bufferLock`, then flushes if the
   * buffer has reached `flushAfter` records. The flush itself runs outside the lock.
   */
  private def append(records: Iterable[Record]): Unit = {
    val sizeAfterAppend = bufferLock.synchronized {
      val current = buffer.get()
      current ++= records
      current.size
    }
    if (sizeAfterAppend >= flushAfter) flush()
  }

  /**
   * Records the snippets returned by a hotel (direction) search.
   *
   * Indexable snippets are grouped by (hotel, date, nights), and each group becomes one record
   * with every pansion price of its snippets. When no snippet is indexable, price-less records
   * with `hotelId = 0` are stored for the request's date x nights combinations.
   *
   * @param ts        observation timestamp
   * @param request   the search request; ignored unless it passes the request filter
   * @param snippets  snippets returned by the search
   * @param isPartial stored as-is in the `is_partial` column
   */
  def +=(ts: DateTime, request: HotelSearchRequest, snippets: Seq[HotelSnippet], isPartial: Boolean): Unit = {
    if (enableWrites.get && shouldIndex(request)) {
      val filtered = snippets.filter(shouldIndex)
      val records: Iterable[Record] =
        if (filtered.isEmpty) {
          emptyRecords(ts, request, hotelId = 0, isPartial = isPartial, isDirectionSearch = true)
        } else {
          filtered.groupBy(s => (s.getHotelId, s.getDateMin, s.getNightsMin)).map {
            case ((hotelId, when, nights), group) =>
              val prices = for {
                snippet <- group
                pansion <- snippet.getPansionsList
              } yield {
                // shouldIndex guarantees exactly one source per snippet.
                val operatorId = snippet.getSource(0).getOperatorId
                Price(operatorId, pansion.getPansion.getNumber, pansion.getPrice)
              }
              Record(
                timestamp = ts,
                isPartial = isPartial,
                isDirectionSearch = true,
                from = request.from,
                hotelId = hotelId,
                to = request.to,
                when = when.toLocalDate,
                nights = nights,
                ages = request.agesSerializable.toArray,
                prices = prices.toArray
              )
          }
        }
      append(records)
    }
  }

  /**
   * Records the offers returned by a single-hotel offer search.
   *
   * Indexable offers are grouped by (date, nights), and each group becomes one record. When no
   * offer is indexable (including when `tours` is empty), price-less records for the hotel are
   * stored for the request's date x nights combinations, as the hotel-search overload does.
   *
   * @param ts          observation timestamp
   * @param tourRequest the offer search request; ignored unless its hotel request passes the filter
   * @param tours       offers returned by the search
   * @param isPartial   stored as-is in the `is_partial` column
   */
  def +=(ts: DateTime, tourRequest: OfferSearchRequest, tours: Seq[Offer], isPartial: Boolean): Unit = {
    if (enableWrites.get) {
      val request = tourRequest.hotelRequest
      if (shouldIndex(request)) {
        val filtered = tours.filter(shouldIndex)
        val records: Iterable[Record] =
          if (filtered.isEmpty) {
            emptyRecords(ts, request, hotelId = tourRequest.hotelId, isPartial = isPartial,
              isDirectionSearch = false)
          } else {
            filtered.groupBy(t => (t.getDate, t.getNights)).map {
              case ((date, nights), group) =>
                val prices = group.map { offer =>
                  Price(offer.getSource.getOperatorId, offer.getPansion.getNumber, offer.getPrice)
                }
                Record(
                  timestamp = ts,
                  isPartial = isPartial,
                  isDirectionSearch = false,
                  from = request.from,
                  hotelId = tourRequest.hotelId,
                  to = request.to,
                  when = date.toLocalDate,
                  nights = nights,
                  ages = request.agesSerializable.toArray,
                  prices = prices.toArray
                )
            }
          }
        append(records)
      }
    }
  }

  /**
   * Writes every buffered record to ClickHouse and starts a fresh buffer.
   *
   * Rows are serialized with `toRowString` on the calling thread before the insert is issued. The
   * insert result is only logged; a failed batch is not re-buffered.
   *
   * @param sync when `true`, waits up to one minute for the insert and rethrows its failure
   */
  def flush(sync: Boolean = false): Unit = {
    // Swap under the lock so that no concurrent `+=` can still append to the batch being written.
    val current = bufferLock.synchronized {
      buffer.getAndSet(new ArrayBuffer[Record](bufferSize))
    }
    if (current.nonEmpty) {
      val ctx = flushTimer.time()
      val f = clickHouseClient.update(s"insert into ${clickHouseClient.database}.price_history(date, timestamp, is_partial, " +
        "is_direction_search, from, to, " +
        "hotel_id, hotel_stars, regions, when, nights, " +
        "ages, ages_array, min_price, pansion, " +
        "`prices.operator_id`, `prices.pansion`, `prices.price`) FORMAT TabSeparated",
        current.map(_.toRowString)
      ).andThen {
        case Success(_) =>
          // Timer.Context.stop() returns nanoseconds; convert to milliseconds for the log.
          val elapsed = ctx.stop() / 1000000
          flushSize.update(current.size)
          log.info(s"Updated ${current.size} records in $elapsed ms.")
        case Failure(ex) =>
          log.warn(s"Failed to insert ${current.size} records", ex)
      }
      if (sync) {
        Await.result(f, 1.minute)
      }
    }
  }
}
