package ru.yandex.tours.wizard.parser

import it.unimi.dsi.fastutil.longs.{Long2DoubleOpenHashMap, Long2IntOpenHashMap}
import org.slf4j.LoggerFactory
import ru.yandex.tours.direction.DirectionsStats
import ru.yandex.tours.geo.base.region
import ru.yandex.tours.geo.base.region.Types
import ru.yandex.tours.geo.mapping.GeoMappingHolder
import ru.yandex.tours.hotels.{HotelRatings, HotelsIndex}
import ru.yandex.tours.model.filter.hotel.{GeoIdFilter, StarFilter}
import ru.yandex.tours.query._
import ru.yandex.tours.util.naming.HotelNameId
import ru.yandex.tours.wizard.WizardTracer
import ru.yandex.tours.wizard.query.ParsedUserQuery.QueryPart
import ru.yandex.tours.wizard.query._

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/**
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 09.02.15
 */
object UserRequestParser {
  private val log = LoggerFactory.getLogger(getClass)

  /** Pragmatics that `UserRequestParser#haveMarker` does not count as meaningful markers. */
  val ignoredPragmatics: Set[Pragmatic] = Set(StopWord, Ignored)

  /**
   * Quorum value a hotel-name candidate must exceed to be accepted by the hotel resolution step;
   * it is also the target used to stop bloom-filter checks early. The (misspelled) name is kept
   * for source compatibility with existing callers.
   */
  val QUORUM_TRESHOLD: Double = 1d / 2
}

/**
 * Parses free-text user requests into a [[ParsedUserQuery]] of typed pragmatics (regions, hotels,
 * tour operators, ...) and post-processes the raw parts.
 *
 * @param parser          main pragmatics parser applied to the request
 * @param reqAnsParser    parser whose result is used only when a part spans the whole request
 * @param stopWordParser  parser consulted by [[haveStopWords]]
 * @param tree            region tree used for region ancestry checks
 * @param hotelIndex      hotel index used to filter hotel candidates by stars and by geo id
 * @param hotelRatings    hotel ratings; visit counts break ties between hotel candidates
 * @param geoMapping      source of "known destination" / "departure city" facts
 * @param directionsStats direction priorities used to resolve ambiguous arrival regions
 */
class UserRequestParser(parser: PragmaticsParser,
                        reqAnsParser: PragmaticsParser,
                        stopWordParser: PragmaticsParser,
                        tree: region.Tree,
                        hotelIndex: HotelsIndex,
                        hotelRatings: HotelRatings,
                        geoMapping: GeoMappingHolder,
                        directionsStats: DirectionsStats) {

  import UserRequestParser.log

  /** Region id checked against a region's ancestors to decide it is "in Russia". */
  private val RussiaRegionId = 225

  /**
   * Tells whether `userRequest` has at least one meaningful marker: a part found by `parser`
   * whose pragmatic is not in [[UserRequestParser.ignoredPragmatics]].
   */
  def haveMarker(userRequest: String): Boolean = {
    parser.exists(userRequest, q => !UserRequestParser.ignoredPragmatics.contains(q.pragmatic))
  }

  /** Delegates to `stopWordParser.findAny`; presumably true when any stop word is present. */
  def haveStopWords(userRequest: String): Boolean = {
    stopWordParser.findAny(userRequest)
  }

  /**
   * Absorbs a leading preposition into the part.
   *
   * For an undirected region the preposition also fixes its direction: "с "/"из " turn it into a
   * departure, "в "/"по "/"во "/"на "/"in " into an arrival. Tour operator and hotel name parts
   * only absorb their typical prepositions (single-argument `stealPrefix`, presumably keeping the
   * pragmatic). Every other part is returned unchanged.
   */
  private def stealPrepositions(part: QueryPart): QueryPart = part.pragmatic match {
    case SomeRegion(rid) =>
      part
        .stealPrefix("с ", DepartureRegion(rid))
        .stealPrefix("из ", DepartureRegion(rid))
        .stealPrefix("в ", ArrivalRegion(rid))
        .stealPrefix("по ", ArrivalRegion(rid))
        .stealPrefix("во ", ArrivalRegion(rid))
        .stealPrefix("на ", ArrivalRegion(rid))
        .stealPrefix("in ", ArrivalRegion(rid))
    case _: TourOperatorMarker =>
      part
        .stealPrefix("от ")
        .stealPrefix("в ")
        .stealPrefix("у ")
        .stealPrefix("на ")
        .stealPrefix("через ")
    case _: HotelName =>
      part.stealPrefix("в ")
    case _ => part
  }

  /**
   * Builds a function that turns an undirected [[SomeRegion]] part into an arrival or a
   * departure, using the whole `query` as context. Other parts pass through unchanged.
   *
   * Regions outside Russia always become arrivals. For regions in Russia the first matching rule
   * wins:
   *  1. the query mentions a hotel -> arrival;
   *  2. the query already has an arrival -> departure;
   *  3. known destination that is not a departure city -> arrival;
   *  4. region type `<= Types.FederalSubject` -> arrival;
   *  5. the query mentions a tour operator -> departure;
   *  6. otherwise -> arrival.
   */
  private def guessRegionDirection(query: ParsedUserQuery): (QueryPart) => QueryPart = {
    val withHotelMarker = query.has[HotelMarker]
    val withParsedHotel = query.has[HotelName]
    val withTourOperator = query.has[TourOperatorMarker]
    val withArrival = query.has[ArrivalRegion]
    part => {
      part.pragmatic match {
        case SomeRegion(rid) =>
          val region = tree.region(rid)
          val parents = for {
            region <- region.toSeq
            parent <- tree.pathToRoot(region)
          } yield parent
          val inRussia = parents.exists(_.id == RussiaRegionId)
          val newPragmatic =
            if (inRussia) {
              if (withHotelMarker || withParsedHotel) {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to arrival: have hotel")
                ArrivalRegion(rid)
              } else if (withArrival) {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to departure: have arrival")
                DepartureRegion(rid)
              } else if (geoMapping.isKnownDestination(rid) && !geoMapping.isDepartureCity(rid)) {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to arrival: is known destination")
                ArrivalRegion(rid)
              } else if (region.exists(_.`type` <= Types.FederalSubject)) {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to arrival: region type")
                ArrivalRegion(rid)
              } else if (withTourOperator) {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to departure: with tour operator")
                DepartureRegion(rid)
              } else {
                if (log.isDebugEnabled) log.debug(s"Promoted $part to arrival: in russia")
                ArrivalRegion(rid)
              }
            } else {
              ArrivalRegion(rid)
            }
          part.copy(pragmatic = newPragmatic)
        case _ => part
      }
    }
  }

  /** Keeps arrival regions only when `geoMapping` knows them as destinations; other parts pass. */
  private def isKnownDestination(part: QueryPart) = part.pragmatic match {
    case ArrivalRegion(regionId) => geoMapping.isKnownDestination(regionId)
    case _ => true
  }

  /**
   * Builds a function that, for an arrival region part, also emits merged parts with every
   * colliding geo-region part that comes after it, when one of the two regions lies under the
   * other in the region tree. The merged part takes the descendant (more specific) region.
   * The original part is always kept as the last element.
   */
  private def unionRegions(query: ParsedUserQuery): QueryPart => Seq[QueryPart] = {
    def isUnder(region1: Int, region2: Int) = tree.pathToRoot(region1).map(_.id).contains(region2)
    val regions = query.queryParts.filter(_.pragmatic.isInstanceOf[GeoRegion])
    part => {
      part.pragmatic match {
        case arrival: ArrivalRegion =>
          regions.filter(_.collide(part)).filter(_.isAfter(part)).flatMap { collide =>
            val rid = collide.pragmatic.asInstanceOf[GeoRegion].regionId
            if (isUnder(rid, arrival.regionId)) Seq(part.union(collide, ArrivalRegion(rid)))
            else if (isUnder(arrival.regionId, rid)) Seq(part.union(collide, ArrivalRegion(arrival.regionId)))
            else Seq.empty
          } :+ part
        case _ => Seq(part)
      }
    }
  }

  /**
   * Builds a function that drops competing arrival regions. Among [[ArrivalRegion]] parts spanning
   * exactly the same positions, only the one with the highest `directionsStats` priority survives.
   */
  private def resolveRegionConflicts(query: ParsedUserQuery): QueryPart => Seq[QueryPart] = {
    def getPriority(qp: QueryPart): Double = {
      directionsStats.getPriority(qp.pragmatic.asInstanceOf[ArrivalRegion].regionId)
    }
    val toDelete = query.queryParts.filter(_.pragmatic.isInstanceOf[ArrivalRegion])
      .groupBy(qp => (qp.startPosition, qp.endPosition)).filter(_._2.size > 1)
      .flatMap { case (_, group) =>
        val bestRegionPart = group.maxBy(qp => getPriority(qp))
        group.filterNot(_ == bestRegionPart)
      }.toSet

    part => {
      if (toDelete.contains(part)) Seq.empty
      else Seq(part)
    }
  }

  /**
   * Builds a function that resolves hotel-name fragments into a single hotel.
   *
   * All [[HotelNamePart]]s of the query are scored against candidate hotel names. The best
   * candidate (highest quorum * score * boost, ties broken by visits) replaces every hotel-name
   * part that contains it with a [[HotelName]] part. All other hotel-name parts are dropped, and
   * non-hotel parts pass through unchanged.
   */
  private def parseHotel(query: ParsedUserQuery): QueryPart => Seq[QueryPart] = {
    // Hotels matching the requested star counts (all hotels when no stars were requested).
    val stars = query.collectOf[Stars].map(_.count).toSet
    def filterHotels(hotelIds: Iterable[Long]): Iterable[Long] = {
      if (stars.isEmpty) hotelIds
      else {
        val matched = hotelIndex.filter(hotelIds.map(new HotelNameId(_).hotelId), Seq(new StarFilter(stars)), None)
        hotelIds.filter(n => matched.contains(new HotelNameId(n).hotelId))
      }
    }

    // Hotels located in any mentioned region (all hotels when no region was mentioned).
    val regions = query.collectOf[ArrivalRegion].map(_.regionId) ++ query.collectOf[SomeRegion].map(_.regionId)
    def filterByDestination(hotelIds: Iterable[Long]): Iterable[Long] = {
      if (regions.isEmpty) hotelIds
      else {
        val matched = hotelIndex.filter(hotelIds.map(new HotelNameId(_).hotelId), Seq(new GeoIdFilter(regions)), None)
        hotelIds.filter(n => matched.contains(new HotelNameId(n).hotelId))
      }
    }

    // One hotel-name part per distinct slice.
    val hotelNameParts = query.queryParts.filter(_.pragmatic.isInstanceOf[HotelNamePart])
      .groupBy(_.slice).map(_._2.head)

    // Parts that are the only candidate at their (start, length) position. Every accepted hotel
    // name must contain all of them. Singleton groups are selected by size: `hotelNameParts` comes
    // out of `Map.map`, so its groups are Lists rather than Vectors. The former `case Vector(qp)`
    // pattern therefore never matched and silently left this collection empty.
    val strongParts = hotelNameParts.groupBy(qp => (qp.startPosition, qp.length)).values.collect {
      case group if group.size == 1 => group.head
    }

    val (bloomHotelNameParts, mapHotelNameParts) = hotelNameParts
      .partition(_.pragmatic.isInstanceOf[HotelNamePartBloom])

    val bestHotelNameFromParts: Option[HotelNameId] = {
      val scores = new Long2DoubleOpenHashMap()
      val quorum = new Long2DoubleOpenHashMap()
      var boost = Map.empty[Long, Double]

      val ratings = hotelRatings.copy()
      // Ordered lexicographically: combined score first, then visit count as a tie-breaker.
      def relevance(name: Long) = {
        val hotelId = new HotelNameId(name).hotelId
        (quorum.get(name) * scores.get(name) * boost.getOrElse(name, 1d), ratings.getVisits(hotelId))
      }

      // Every candidate starts with -totalPenalty; each matched part wins part of it back.
      // When there are no hotel-name parts antiPenalty is infinite, but then no loop below runs.
      val totalPenalty = 1d / 2
      val antiPenalty = totalPenalty / hotelNameParts.size
      val bloomBonus = antiPenalty + 1d / 4 // magic number of words
      val bloomSize = bloomHotelNameParts.size

      // Per bloom part: filter, score and pruning threshold. With threshold(i), even matching all
      // remaining (bloomSize - i) bloom parts cannot lift the value above QUORUM_TRESHOLD.
      // In debug mode pruning is effectively disabled so that full scores can be logged.
      val (bloomFilters, bloomScores, bloomThresholds) = {
        val q = bloomHotelNameParts.zipWithIndex.map { case (part, i) =>
          val threshold =
            if (!log.isDebugEnabled) UserRequestParser.QUORUM_TRESHOLD - ((bloomSize - i) * bloomBonus)
            else -totalPenalty
          (part.pragmatic.asInstanceOf[HotelNamePartBloom], part.score / 10, threshold)
        }.unzip3
        (q._1.toArray, q._2.toArray, q._3.toArray)
      }

      // Exact (map-backed) parts: spread the part score over at most 10 names and add quorum.
      // The map values are presumably the word count of each hotel name — TODO confirm.
      for {
        part <- mapHotelNameParts
        HotelNamePartMap(hotelIds) <- Option(part.pragmatic)
        score = part.score / (hotelIds.size min 10)
      } {
        val it = hotelIds.long2IntEntrySet().fastIterator()
        while (it.hasNext) {
          val e = it.next()
          val nameId = e.getLongKey
          val words = e.getIntValue
          scores.addTo(nameId, score)
          quorum.addTo(nameId, 1d / words + antiPenalty)
        }
      }

      // Finalize each candidate's quorum (strong-part check, bloom bonuses) and keep passing ones.
      val candidates = ArrayBuffer.empty[Long]
      for (e <- quorum.long2DoubleEntrySet().fastIterator.asScala) {
        val nameId = e.getLongKey
        var value = e.getDoubleValue - totalPenalty
        if (strongParts.forall { part =>
          part.pragmatic.asInstanceOf[HotelNamePart].contains(nameId)
        }) {
          var i = 0
          while (i < bloomSize && value > bloomThresholds(i)) {
            if (bloomFilters(i).contains(nameId)) {
              value += bloomBonus
              scores.addTo(nameId, bloomScores(i))
            }
            i += 1
          }
        } else {
          value = 0
        }
        e.setValue(value)
        if (value > UserRequestParser.QUORUM_TRESHOLD) {
          candidates += nameId
        }
      }

      // Double the boost of candidates matching the requested stars, then keep only candidates
      // located in the mentioned regions.
      for (id <- filterHotels(candidates)) {
        boost += id -> boost.getOrElse(id, 1d) * 2.0d
      }
      val byDestination = filterByDestination(candidates).toSet
      val ids = candidates.filter(byDestination.contains)

      if (log.isDebugEnabled) {
        ids.sortBy(relevance).reverse.take(20).foreach { id =>
          log.debug("  " + new HotelNameId(id).toString + " -> " + quorum.get(id) + " " + scores.get(id) +
            " " + boost.getOrElse(id, 1d))
        }
      }

      if (ids.nonEmpty) Some(new HotelNameId(ids.maxBy(relevance)))
      else None
    }

    part => {
      part.pragmatic match {
        case p: HotelNamePart =>
          // Hotel-name parts containing the winning hotel collapse into it; the rest are dropped.
          bestHotelNameFromParts.filter(n => p.contains(n.id)).toList.map { best =>
            part.copy(pragmatic = HotelName(best.hotelId), boost = 1.4d)
          }
        case _ =>
          Seq(part)
      }
    }
  }

  /**
   * Parses `userRequest` into a [[ParsedUserQuery]].
   *
   * If `reqAnsParser` yields a part spanning the whole request and the result has no holes, that
   * result is returned as is. Otherwise:
   *  1. `parser` output is merged with `parsedParts` and restricted to parts satisfying
   *     `atWordStart`;
   *  2. prepositions are absorbed;
   *  3. if holes still remain, the best match is returned straight away;
   *  4. otherwise the parts go through direction guessing, unknown-destination filtering, region
   *     union, region conflict resolution and hotel-name resolution before the best match is
   *     selected.
   *
   * @param userRequest raw user request text
   * @param parsedParts externally pre-parsed parts merged with the `parser` output
   * @return the best-matching parse with its holes filled
   */
  def parse(userRequest: String, parsedParts: Seq[ParsedUserQuery.QueryPart]): ParsedUserQuery = {
    val fullRequest = reqAnsParser.parse(userRequest)
      .filter(p => p.startPosition == 0 && p.length == p.userRequest.length)
    WizardTracer.checkpoint("parsed_reqans")

    if (fullRequest.containHoles) {
      val rawParts = parser.parse(userRequest)
      // NOTE(review): if ParsedUserQuery.exists passes QueryPart values, comparing them with the
      // Ignored pragmatic is always unequal, so this only triggers on an empty parse — confirm
      // whether `_.pragmatic != Ignored` was intended.
      if (!rawParts.exists(_ != Ignored)) {
        return rawParts.copy(queryParts = Vector.empty)
      }

      val parts = (rawParts + parsedParts).filter(_.atWordStart)
      WizardTracer.checkpoint("parsed_pragmatics")

      if (log.isDebugEnabled) {
        log.debug(s"Raw parts [$userRequest]:")
        parts.queryParts.foreach { part =>
          log.debug(part.toString)
        }
      }

      val withPrepositions = parts.map(stealPrepositions)

      if (withPrepositions.containHoles) {
        return withPrepositions.selectBestMatch.fillHoles
      }

      val preprocessed = withPrepositions
        .map2(guessRegionDirection)
        .filter(isKnownDestination)
        .flatMap2(unionRegions)
        .flatMap2(resolveRegionConflicts)
        .flatMap2(parseHotel)
      WizardTracer.checkpoint("preprocessed_pragmatics")

      if (log.isDebugEnabled) {
        log.debug(s"Preprocessed request [$userRequest]:")
        preprocessed.queryParts.foreach { part =>
          log.debug(part.toString)
        }
      }

      preprocessed
        .selectBestMatch
        .fillHoles
    } else {
      fullRequest
    }
  }
}
