package ru.yandex.tours.db.tables

import ru.yandex.tours.clustering.Clustering.{LinkTrait, LinkWithConfidenceTrait}
import ru.yandex.tours.db.tables.LinkType.LinkType
import ru.yandex.tours.db.{DBWrapper, Transactions}
import ru.yandex.tours.util.graph.BfsWalker
import ru.yandex.tours.util.{Logging, Statistics}
import slick.driver.MySQLDriver.api._

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Success

/**
 * Slick mapping of the `clusterization` table.
 *
 * Each row is a link between two hotels (`parent` and `child`, both foreign keys to
 * `Hotels.table`) recorded as part of a transaction (`transaction_id`, a foreign key to
 * `Transactions.table`). Rows are read and written as [[ClusterLink]] values.
 */
class Clusterization(tag: Tag) extends Table[ClusterLink](tag, "clusterization") {
  /** Auto-incremented primary key of the link row. */
  def id = column[Int]("id", O.PrimaryKey, O.AutoInc)

  /** Hotel id stored in the `parent` column (references `Hotels.table`). */
  def parent = column[Int]("parent")

  /** Hotel id stored in the `child` column (references `Hotels.table`). */
  def child = column[Int]("child")

  /** Id of the transaction this link belongs to (references `Transactions.table`). */
  def transactionId = column[Int]("transaction_id")

  /** Confidence of the link; readers filter on it against a minimum threshold. */
  def confidence = column[Double]("confidence")

  /** Numeric id of the [[LinkType]] value; converted in the `*` projection below. */
  def linkType = column[Int]("link_type")

  // All foreign keys use ForeignKeyAction.Restrict for both onUpdate and onDelete, so the
  // database refuses to update/delete a referenced hotel or transaction key while links exist.
  def childForeignKey = foreignKey("child_fk", child, Hotels.table)(_.id, ForeignKeyAction.Restrict, ForeignKeyAction.Restrict)

  def parentForeignKey = foreignKey("parent_fk", parent, Hotels.table)(_.id, ForeignKeyAction.Restrict, ForeignKeyAction.Restrict)

  def transactionForeignKey = foreignKey("cluster_transaction_fk", transactionId, Transactions.table)(_.id, ForeignKeyAction.Restrict, ForeignKeyAction.Restrict)

  /**
   * Default projection.
   *
   * On read, the stored `link_type` int is turned into a [[LinkType]] via `LinkType(id)`.
   * On write, the enumeration value is stored as its `.id`.
   */
  override def * = (id, parent, child, transactionId, confidence, linkType).shaped <>( {
    case (id, parent, child, transactionId, confidence, linkType) =>
      ClusterLink(id, parent, child, transactionId, confidence, LinkType(linkType))
  }, {
    cl: ClusterLink =>
      Some((cl.id, cl.parent, cl.child, cl.transactionId, cl.confidence, cl.`type`.id))
  })
}

/**
 * Kind of a clusterization link.
 *
 * The numeric ids are what the `link_type` column of `clusterization` stores
 * (`LinkType(id)` on read, `.id` on write). Existing values must therefore keep their ids,
 * otherwise stored rows would be reinterpreted.
 */
object LinkType extends Enumeration {
  type LinkType = Value

  /** Link joining two hotels; during neighbour resolution it adds the other hotel. */
  val MERGE = Value(0)
  /** Link separating two hotels; during neighbour resolution it removes the other hotel. */
  val UNMERGE = Value(1)
}

/**
 * One row of the `clusterization` table.
 *
 * @param id            primary key (auto-incremented by the database)
 * @param parent        hotel id stored in the `parent` column
 * @param child         hotel id stored in the `child` column
 * @param transactionId id of the transaction that recorded this link
 * @param confidence    link confidence; defaults to 1.0
 * @param `type`        whether the link merges or unmerges the hotels; defaults to MERGE
 */
case class ClusterLink(id: Int,
                       parent: Int,
                       child: Int,
                       transactionId: Int,
                       confidence: Double = 1d,
                       `type`: LinkType = LinkType.MERGE) extends LinkWithConfidenceTrait

object Clusterization extends Logging {
  /** Query over the `clusterization` table. */
  val table = TableQuery[Clusterization]

  /** Confidence threshold used when a caller does not supply one. */
  val defaultMinConfidence = 0.7d

  /** Number of rows fetched per database round trip when loading links. */
  private val PageSize = 300000

  /**
   * A link between two hotels, enriched with metadata of the transaction that recorded it.
   *
   * Instances built by `buildLink` are normalized so that `child < parent`.
   *
   * @param child      the smaller of the two linked hotel ids
   * @param parent     the larger of the two linked hotel ids
   * @param confidence link confidence
   * @param `type`     whether the link merges or unmerges the two hotels
   * @param timestamp  timestamp of the owning transaction
   * @param author     author of the owning transaction; positive values mark manual links
   * @throws IllegalArgumentException if `child >= parent`
   */
  case class LinkWithInfo(child: Int, parent: Int, confidence: Double, `type`: LinkType,
                          timestamp: Long, author: Long) extends LinkTrait {
    require(child < parent, s"`child` ($child) should be less than `parent` ($parent)")

    /** A link is manual when its transaction has a positive author id. */
    def isManual: Boolean = author > 0
  }

  /**
   * Loads the links that directly connect `id1` and `id2`.
   *
   * Only links from enabled transactions with confidence >= `minConfidence` are returned.
   * Self-links are excluded, so the result is empty when `id1 == id2`.
   */
  def retrieveLinks(id1: Int, id2: Int, minConfidence: Double, db: DBWrapper)
                   (implicit ec: ExecutionContext): Future[Seq[LinkWithInfo]] = {
    val set = Set(id1, id2)
    val query = baseRequest(minConfidence).filter(t => t._1.inSet(set) && t._2.inSet(set) && (t._1 =!= t._2))
    run(db, query)
  }

  /**
   * Loads every link from enabled transactions with confidence >= `minConfidence`.
   *
   * The load time is reported through `Statistics.asyncLogTime` under `name`, and the number
   * of loaded links is logged on success.
   */
  def retrieveClusterLinks(minConfidence: Double, db: DBWrapper, name: String)
                          (implicit ec: ExecutionContext): Future[Seq[LinkWithInfo]] = {
    val query = baseRequest(minConfidence)
    val result = run(db, query)
    Statistics.asyncLogTime(s"$name: Cluster retrieving", result)
      .andThen { case Success(links) => log.info(s"${links.size} cluster links retrieved from db") }
  }

  /**
   * Loads the raw rows of every link touching any member of the cluster that contains `id`.
   *
   * The cluster is resolved with [[retrieveCluster]]. Then, for each member, all rows
   * (MERGE and UNMERGE alike) from enabled transactions with confidence >= `minConfidence`
   * that reference that member are fetched. The combined result is de-duplicated.
   */
  def retrieveClusterLinks(db: DBWrapper, id: Int, minConfidence: Double = defaultMinConfidence)
                          (implicit ec: ExecutionContext): Future[Seq[ClusterLink]] = {
    for {
      members <- retrieveCluster(db, id, minConfidence)
      linksPerMember <- Future.traverse(members.toSeq) { member =>
        val query = for {
          link <- table
          transaction <- Transactions.table
          if link.parent === member || link.child === member
          if link.transactionId === transaction.id
          if link.confidence >= minConfidence
          if transaction.isEnabled
        } yield link

        db.run(query.result)
      }
    } yield linksPerMember.flatten.distinct
  }

  /**
   * Returns the hotel ids reachable from `id` by walking neighbour relations as resolved by
   * `retrieveNeighbours`.
   *
   * Whether `id` itself is part of the result depends on `BfsWalker` (TODO confirm).
   */
  def retrieveCluster(db: DBWrapper, id: Int, minConfidence: Double = defaultMinConfidence)
                     (implicit ec: ExecutionContext): Future[Set[Int]] = {
    new BfsWalker(node => retrieveNeighbours(db, node, minConfidence)).getReachableNeighbours(id)
  }

  /** Batch-inserts `links`; yields the inserted row count when the driver reports one. */
  def insert(db: DBWrapper, links: Iterable[ClusterLink]): Future[Option[Int]] = {
    val q = table ++= links
    db.run(q)
  }

  /**
   * Resolves the current direct neighbours of hotel `id`.
   *
   * Links are replayed in order of `(isManual, timestamp)`: automatic links first, then
   * manual ones, each group oldest to newest. A MERGE adds the other hotel and an UNMERGE
   * removes it, so later (and manual) links override earlier automatic decisions.
   */
  private def retrieveNeighbours(db: DBWrapper,
                                 id: Int,
                                 minConfidence: Double)(implicit ec: ExecutionContext): Future[Set[Int]] = {
    val query = baseRequest(minConfidence).filter(x => x._1 === id || x._2 === id)
    run(db, query).map { links =>
      val result = mutable.HashSet.empty[Int]
      links.sortBy(l => (l.isManual, l.timestamp)).foreach {
        case LinkWithInfo(from, to, _, linkType, _, _) =>
          val neighbour = if (from == id) to else from
          linkType match {
            case LinkType.MERGE => result += neighbour
            case LinkType.UNMERGE => result -= neighbour
          }
      }
      result.toSet
    }
  }

  /**
   * Joins links with their transactions, keeping enabled transactions and links with
   * confidence >= `minConfidence`.
   *
   * Projection: (child, parent, confidence, linkType, timestamp, author, linkRowId). The
   * trailing unique row id is used by `runPart` to page through results.
   */
  private def baseRequest(minConfidence: Double) = {
    for {
      cluster <- table
      transaction <- Transactions.table
      if cluster.transactionId === transaction.id
      if cluster.confidence >= minConfidence
      if transaction.isEnabled
    } yield {
      (cluster.child, cluster.parent, cluster.confidence, cluster.linkType, transaction.timestamp, transaction.author, cluster.id)
    }
  }

  /**
   * Builds a normalized [[LinkWithInfo]] with the smaller id as `child`.
   *
   * @throws IllegalArgumentException if `from == to` (rejected by `LinkWithInfo`'s `require`)
   * @throws NoSuchElementException   if `linkType` is not a known [[LinkType]] id
   */
  private def buildLink(from: Int, to: Int, confidence: Double, linkType: Int,
                        timestamp: Long, author: Long): LinkWithInfo = {
    LinkWithInfo(Math.min(from, to), Math.max(from, to), confidence, LinkType(linkType), timestamp, author)
  }

  /** Builds a [[LinkWithInfo]] from a `baseRequest` row; the trailing row id is ignored. */
  private def buildLink(x: (Int, Int, Double, Int, Long, Long, Int)): LinkWithInfo = {
    buildLink(x._1, x._2, x._3, x._4, x._5, x._6)
  }

  /** Runs a `baseRequest`-shaped query page by page and collects all converted links. */
  private def run(db: DBWrapper,
                  query: Query[(Rep[Int], Rep[Int], Rep[Double], Rep[Int], Rep[Long], Rep[Long], Rep[Int]), (Int, Int, Double, Int, Long, Long, Int), Seq])
                 (implicit ec: ExecutionContext): Future[mutable.Buffer[LinkWithInfo]] = {
    val result = mutable.Buffer[LinkWithInfo]()
    runPart(result, None, db, query).map(_ => result)
  }

  /**
   * Fetches the next page of `query` (ordered by the unique row id in column 7), appends it to
   * `result`, and recurses until a short page is returned.
   *
   * Uses keyset pagination (`rowId > afterId`) rather than OFFSET. OFFSET makes MySQL scan and
   * discard all preceding rows on every page, which is quadratic in total size; keyset
   * pagination seeks directly to the next row.
   *
   * @param afterId row id of the last row already loaded, or `None` for the first page
   */
  private def runPart(result: mutable.Buffer[LinkWithInfo],
                      afterId: Option[Int], db: DBWrapper,
                      query: Query[(Rep[Int], Rep[Int], Rep[Double], Rep[Int], Rep[Long], Rep[Long], Rep[Int]), (Int, Int, Double, Int, Long, Long, Int), Seq])
                     (implicit ec: ExecutionContext): Future[Unit] = {
    val remaining = afterId.fold(query)(lastId => query.filter(_._7 > lastId))
    val page = remaining.sortBy(_._7).take(PageSize)
    db.run(page.result.transactionally).flatMap { rows =>
      result.appendAll(rows.map(buildLink))
      if (rows.size < PageSize) {
        Future.successful(())
      } else {
        // A full page is non-empty, so `last` is safe here.
        runPart(result, Some(rows.last._7), db, query)
      }
    }
  }
}