package ru.yandex.tours.indexer.hotels

import java.io.{File, OutputStream}
import java.util.concurrent.atomic.AtomicInteger

import akka.actor.ActorRefFactory
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Source
import ru.yandex.tours.avatars.AvatarClient
import ru.yandex.tours.model.BaseModel.ProtoImage
import ru.yandex.tours.model.Image
import ru.yandex.tours.model.hotels.HotelsHolder.PartnerHotel
import ru.yandex.tours.model.hotels.Partners.Partner
import ru.yandex.tours.model.image.ImageProviders
import ru.yandex.tours.model.image.ImageProviders.ImageProvider
import ru.yandex.tours.partners.BookingHttp
import ru.yandex.tours.util.collections.RafBasedMap
import ru.yandex.tours.util.lang.{Futures, _}
import ru.yandex.tours.util.{IO, Logging, ProtoIO}
import ru.yandex.tours.util.Collections._

import scala.collection.JavaConversions._
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Try

/**
 * Attaches images to partner hotels.
 *
 * Images already present in the previous version of a hotel are reused; photos that are
 * missing (or marked as needing recovery) are uploaded through the [[AvatarClient]].
 *
 * @param avatarClient client used to upload photos by url
 * @param shouldDownloadNew when `false` nothing is uploaded and only images of the old hotel
 *                          version are attached
 */
class ImageEnricher(avatarClient: AvatarClient,
                    shouldDownloadNew: Boolean = true)
                   (implicit ec: ExecutionContext, actorRefFactory: ActorRefFactory) extends Logging {

  private implicit val materializer: ActorMaterializer = ActorMaterializer()

  // Number of hotels processed concurrently by `mapAsync`. All photo urls of a single hotel
  // are submitted to the avatar client at once, so the number of in-flight uploads can be
  // larger than this value.
  private val parallelism = 4

  /**
   * Loads hotels from `file`, attaches images to every hotel and writes the hotels
   * in delimited form to the output provided by `IO.usingAsyncTmp`.
   *
   * @param file file with [[PartnerHotel]]
   * @param partner partner which all hotels belong to (currently not used by this method)
   * @param imageProvider provider which should be used while downloading photos
   * @param oldHotels map with old hotels, used to decrease load on the avatar client
   * @return file with delimited [[PartnerHotel]] with unified images
   */
  def enrich(file: File,
             partner: Partner,
             imageProvider: ImageProvider,
             oldHotels: RafBasedMap[Int, PartnerHotel]): Future[File] = {
    IO.usingAsyncTmp("image_enricher") { os =>
      for {
        hotels <- Try(ProtoIO.loadFromFile(file, PartnerHotel.PARSER)).toFuture
        result <- loadImages(hotels, oldHotels, os, imageProvider)
      } yield result
    }
  }

  /**
   * Counters describing what happened to photo urls. Used only for progress logging.
   * Counters are atomic because `update` is called from future callbacks of hotels
   * that are processed concurrently.
   */
  private class Statistics {
    val downloaded: AtomicInteger = new AtomicInteger()
    val recovered: AtomicInteger = new AtomicInteger()
    val skipped: AtomicInteger = new AtomicInteger()
    val processed: AtomicInteger = new AtomicInteger()
    val failed: AtomicInteger = new AtomicInteger()

    /**
     * Accounts the result of processing one hotel.
     *
     * @param urls all photo urls of the hotel
     * @param unknown urls which had no image in the old hotel version
     * @param needRecover known urls which were uploaded again
     * @param newPhotos successfully uploaded images
     * @param exceptions failures of uploads
     */
    def update(urls: Seq[String],
               unknown: Seq[String],
               needRecover: Seq[String],
               newPhotos: Seq[Image],
               exceptions: Seq[Throwable]): Unit = {

      val unknownNames = unknown.flatMap(getPossibleNames).toSet
      val recoverNames = needRecover.flatMap(getPossibleNames).toSet

      downloaded.addAndGet(newPhotos.count(i => unknownNames.contains(i.name)))
      recovered.addAndGet(newPhotos.count(i => recoverNames.contains(i.name)))
      failed.addAndGet(exceptions.size)
      processed.addAndGet(urls.size)
      // Only known urls that were NOT re-uploaded are really skipped; urls sent for recovery
      // are already counted as either recovered or failed.
      skipped.addAndGet(urls.size - unknown.size - needRecover.size)
    }

    /** Image names under which a photo for `url` may be stored: original and reduced quality. */
    def getPossibleNames(url: String): Seq[String] = {
      Seq(Image.name(url), Image.name(BookingHttp.decreasePhotoQuality(url)))
    }

    override def toString: String = {
      s"$downloaded downloaded. $recovered recovered. $skipped skipped. $failed failed. $processed total processed."
    }
  }

  /**
   * Attaches images to every hotel from `hotels` and writes hotels to `os` in delimited form.
   *
   * @param hotels hotels to process
   * @param oldHotels previous versions of hotels, looked up by hotel id
   * @param os stream to write enriched hotels to
   * @param imageProvider provider of the photos
   * @return future completed when the whole stream has been processed
   */
  private def loadImages(hotels: Iterator[PartnerHotel],
                         oldHotels: RafBasedMap[Int, PartnerHotel],
                         os: OutputStream,
                         imageProvider: ImageProvider): Future[Unit] = {
    // Mutated only from `write`, which is the `runForeach` callback of this single stream.
    var processedHotels = 0
    val statistics = new Statistics()

    def write(hotel: PartnerHotel): Unit = {
      processedHotels += 1
      if (processedHotels % 1000 == 0) {
        log.info(s"Image enricher processed $processedHotels hotels from $imageProvider. $statistics")
      }
      hotel.writeDelimitedTo(os)
    }

    Source.fromIterator(() => hotels)
      .mapAsync(parallelism) { hotel =>
        val rawPhotos = hotel.getRawHotel.getRawImagesList
        val oldHotelVersion = oldHotels.get(hotel.getId)
        getPhotos(imageProvider, rawPhotos, oldHotelVersion, statistics).map {
          photos => hotel.toBuilder.addAllImages(photos).build()
        }
      }.runForeach(write)
  }

  /**
   * Uploads photos by urls through the avatar client.
   * For the booking provider a failed upload is retried once with the reduced-quality url.
   *
   * @param urls urls of photos to upload
   * @param imageProvider provider to upload photos for; set on every returned image
   * @return successfully uploaded images and failures of the rest
   */
  private def downloadPhotos(urls: Seq[String], imageProvider: ImageProvider): Future[(Seq[Image], Seq[Throwable])] = {
    val futures = urls.map { url =>
      val uploaded = avatarClient.put(url, imageProvider)
      // The branch result is bound to a val on purpose: `if (..) {..} else {..}.map(..)`
      // would apply `.map` to the else branch only.
      val withFallback =
        if (imageProvider == ImageProviders.booking) {
          uploaded.recoverWith {
            case _ =>
              avatarClient.put(BookingHttp.decreasePhotoQuality(url), imageProvider)
          }
        } else {
          uploaded
        }
      withFallback.map(i => if (i.provider == imageProvider) i else i.copy(provider = imageProvider))
    }
    Futures.partitionSequence(futures)
  }

  /**
   * Builds the list of images for a hotel.
   *
   * Images of the old hotel version are reused by image name; urls without a known image and
   * urls whose image needs recovery are uploaded again. The result keeps the order of `urls`.
   * If no image could be resolved for any url, images of the old hotel version are returned.
   *
   * @param imageProvider provider of the photos; set on every returned image
   * @param urls raw photo urls of the hotel
   * @param oldHotel previous version of the hotel, if any
   * @param statistics counters to update with the result of processing
   * @return images to attach to the hotel
   */
  private def getPhotos(imageProvider: ImageProvider,
                        urls: Seq[String],
                        oldHotel: Option[PartnerHotel],
                        statistics: Statistics): Future[Seq[ProtoImage]] = {
    val isBooking = ImageProviders.booking == imageProvider

    def buildImageMap(images: Iterable[ProtoImage]) = images.map(i => i.getName -> i).toMap

    // Old images with provider id forced to the requested provider.
    val oldPhotos: Seq[ProtoImage] = oldHotel match {
      case Some(hotel) =>
        hotel.getImagesList.toIndexedSeq.map {
          case image if image.getProviderId == imageProvider.id => image
          case image => image.toBuilder.setProviderId(imageProvider.id).build()
        }
      case None => Seq.empty[ProtoImage]
    }

    if (shouldDownloadNew) {
      val oldPhotosMap = buildImageMap(oldPhotos)

      // For booking an image may be stored under the name of the reduced-quality url.
      val (known, unknown) = urls.partition { url =>
        oldPhotosMap.contains(Image.name(url)) ||
          (isBooking && oldPhotosMap.contains(Image.name(BookingHttp.decreasePhotoQuality(url))))
      }
      val needRecover = known.filter { url =>
        val photo = oldPhotosMap.get(Image.name(url))
          .orElse(oldPhotosMap.get(Image.name(BookingHttp.decreasePhotoQuality(url))))

        // NOTE(review): images carrying raw NNet features are re-uploaded as well —
        // presumably to refresh them; confirm the intent.
        photo.exists(p => Image.fromProto(p).isNeedRecover || p.hasNNetFeaturesRaw)
      }
      downloadPhotos(unknown ++ needRecover, imageProvider).map { case (newPhotos, exceptions) =>
        statistics.update(urls, unknown, needRecover, newPhotos, exceptions)

        // Freshly uploaded images override old ones with the same name.
        val allPhotos = oldPhotosMap ++ buildImageMap(newPhotos.map(_.copy(provider = imageProvider).toProto))
        val result = urls.flatMap { url =>
          allPhotos.get(Image.name(url)).orElse {
            if (isBooking) allPhotos.get(Image.name(BookingHttp.decreasePhotoQuality(url))) else None
          }
        }

        if (exceptions.nonEmpty && log.isDebugEnabled) {
          val messagesCount = exceptions.map(_.getMessage).toBag
          log.debug(s"Got ${exceptions.size} exceptions with messages: $messagesCount")
        }

        // Fallback to old photos if all new photos failed
        if (result.nonEmpty) {
          result
        } else {
          oldPhotos
        }
      }
    } else {
      Future.successful(oldPhotos)
    }
  }
}
