package ru.yandex.tours.db.geomapping

import ru.yandex.extdata.common.meta.DataType
import ru.yandex.tours.db.{DBWrapper, Transactions}
import ru.yandex.tours.extdata.DataTypes
import ru.yandex.tours.geo.base.region
import ru.yandex.tours.geo.mapping.GeoMappingShort
import ru.yandex.tours.model.hotels.Partners._
import slick.driver.MySQLDriver.api._
import ru.yandex.tours.util.lang.Futures._

import scala.concurrent.{ExecutionContext, Future}

/**
  * Reads and writes the partner geo mappings of a region.
  *
  * Mappings are kept per [[DataType]] (countries, cities, departures, airports). On construction
  * `db.createIfNotExists` is invoked with `GeoMappingTables.getQuery` for each of these types.
  *
  * Created by asoboll on 10.12.15.
  *
  * @param db database wrapper used for every query and write
  * @param ex execution context on which future callbacks run
  */
class DbGeoMappings(db: DBWrapper)(implicit ex: ExecutionContext) {
  /** Every data type whose mappings this class creates, reads and updates. */
  private val dataTypes = Iterable(DataTypes.countries, DataTypes.cities, DataTypes.departures, DataTypes.airports)

  // Reuse `dataTypes` instead of repeating the list, so the created tables cannot drift from
  // the tables that are queried and updated below.
  dataTypes.foreach(dataType => db.createIfNotExists(GeoMappingTables.getQuery(dataType)))

  /**
    * Determines whether `geoId`, or one of the regions on its path to the root, is banned in any of
    * the mapping tables.
    *
    * @return [[BannedThis]] if `geoId` itself is banned; otherwise [[BannedAt]] holding the first banned
    *         region found along `regionTree.pathToRoot(geoId)`; otherwise [[NotBanned]]
    */
  private def getBanStatus(geoId: Int, regionTree: region.Tree): Future[BanStatus] = {
    val regionParents = regionTree.pathToRoot(geoId).map(_.id)
    // One query per data type; all of them are started here and run concurrently.
    val bannedQueries = dataTypes.map { dataType =>
      db.run(GeoMappingTables.queryBanned(dataType, regionParents).result)
    }
    Future.sequence(bannedQueries).map { results =>
      val bannedIds = results.flatMap(GeoMappingRecordProcessor.getBanned).toSet
      if (bannedIds.contains(geoId)) {
        BannedThis
      } else {
        // Walk the path in order rather than taking `head` of an unordered Set, so the reported
        // ancestor is deterministic (the nearest one, assuming pathToRoot goes from geoId upwards).
        regionParents.find(bannedIds.contains) match {
          case Some(bannedAncestor) => BannedAt(bannedAncestor)
          case None => NotBanned
        }
      }
    }
  }

  /**
    * Loads the ban status and the country, city, departure and airport mappings of `geoId`.
    *
    * @param geoId      region to load
    * @param regionTree region tree used to resolve the region's ancestors for the ban status
    * @return the combined [[RegionGeoMappings]] of the region
    */
  def getRegionMapping(geoId: Int, regionTree: region.Tree): Future[RegionGeoMappings] = {
    // The five lookups are independent: start them all before composing, instead of making each
    // one wait for the previous one inside the for-comprehension.
    val banStatusF = getBanStatus(geoId, regionTree).logTiming("load_ban_status")
    val countryMapF = getMappingByType(DataTypes.countries, geoId).logTiming("load_country_map")
    val cityMapF = getMappingByType(DataTypes.cities, geoId).logTiming("load_city_map")
    val departureMapF = getMappingByType(DataTypes.departures, geoId).logTiming("load_departure_map")
    val airportMapF = getMappingByType(DataTypes.airports, geoId).logTiming("load_airport_map")
    for {
      banStatus <- banStatusF
      countryMap <- countryMapF
      cityMap <- cityMapF
      departureMap <- departureMapF
      airportMap <- airportMapF
    } yield RegionGeoMappings(geoId, banStatus, countryMap, cityMap, departureMap, airportMap)
  }

  /**
    * Writes the difference between `regionMappings` and the currently stored mappings of its region.
    *
    * Runs inside `Transactions.withTransaction(db, uid)`. For every data type a diff patch is put.
    * In addition, the ban patch is put into the countries table. All puts are started together and
    * combined with `Future.sequence`.
    *
    * @param regionMappings desired state of the region's mappings
    * @param regionTree     region tree used to load the current ban status
    * @param uid            id passed to the transaction (presumably the acting user's id — confirm)
    */
  def update(regionMappings: RegionGeoMappings, regionTree: region.Tree, uid: Long) = {
    Transactions.withTransaction(db, uid) { transaction =>
      getRegionMapping(regionMappings.geoId, regionTree).flatMap { oldMappings =>
        val banPatch = regionMappings.banPatch(oldMappings)
        val diffPuts = dataTypes.toSeq.map { dataType =>
          GeoMappingTables.put(db, dataType, transaction, regionMappings.diffPatch(oldMappings, dataType))
        }
        Future.sequence(diffPuts :+ GeoMappingTables.put(db, DataTypes.countries, transaction, banPatch))
      }
    }
  }

  /**
    * Loads the partner -> partner id mapping of `geoId` from the table of `dataType`.
    * Multiple rows for the same partner are merged by `GeoMappingRecordProcessor.combineDuplicates`.
    */
  private def getMappingByType(dataType: DataType, geoId: Int): Future[Map[Partner, String]] =
    db.run(GeoMappingTables.queryByGeoId(dataType, geoId).result).map { rows =>
      val partnerPairs = GeoMappingRecordProcessor.getMapping(rows.map(GeoMappingRecordExt.tupled)).map {
        case GeoMappingShort(partner, _, partnerId) => partner -> partnerId
      }
      GeoMappingRecordProcessor.combineDuplicates(partnerPairs)
    }
}