package ru.yandex.tours.indexer.clusterization

import ru.yandex.tours.clustering.Clustering.LinkWithConfidence
import ru.yandex.tours.db.tables.Clusterization.LinkWithInfo
import ru.yandex.tours.db.tables.LinkType
import ru.yandex.tours.model.hotels.Partners.Partner
import ru.yandex.tours.util.Logging
import ru.yandex.tours.util.collections.{Bag, DisjointSetWithPayload, Graph}

import scala.collection.mutable

/**
 * Filters candidate hotel links against the links already stored for the hotels.
 *
 * On construction the stored links are replayed ordered by `(isManual, timestamp)`, so manual
 * links are applied after automatic ones and a later decision about a pair of hotels overrides
 * an earlier one. Pairs whose final state is MERGE are joined into clusters of a
 * [[DisjointSetWithPayload]] whose payload is the bag of partners of the cluster's hotels;
 * pairs whose final state is UNMERGE become bans that forbid joining the two clusters.
 *
 * @param existLinks     links already stored (automatic and manual)
 * @param hotelToPartner partner of each hotel id; hotels absent from the map contribute no partner
 */
class HotelLinkCleaner(existLinks: Seq[LinkWithInfo], hotelToPartner: Map[Int, Partner]) extends Logging {

  /** Current clusters of hotels; the payload of a cluster is the bag of its hotels' partners. */
  private val disjointSet: DisjointSetWithPayload[Bag[Partner]] = createEmptyDisjointSet

  /**
   * Bans between clusters, keyed by the CURRENT root of a cluster. Each value holds hotel ids
   * (not necessarily roots) that must not end up in the key's cluster. Values are resolved to
   * their current root at lookup time, so only the keys need re-homing when clusters are joined
   * (see `joinClusters`). Every ban is stored in both directions.
   */
  private val bannedByRoot = mutable.Map.empty[Int, Set[Int]]

  // Passing `existLinks` explicitly keeps it constructor-only, so it is not retained as a field.
  replayExistingLinks(existLinks)

  /**
   * Drops links that conflict with the current clusters and joins the clusters of every link it
   * keeps. Links are processed from the highest confidence down.
   *
   * A link is dropped when its hotels are already in one cluster, or when a stored UNMERGE
   * forbids joining their clusters. Otherwise the partner bags of both clusters are combined and
   * filtered with `filterByCount(_ > 2)`:
   *  - nothing left: the link is kept unchanged;
   *  - some entry of `toMap` has a value >= 5 (presumably a partner count): the link is dropped;
   *  - otherwise the link is kept with its confidence halved.
   *
   * State persists across calls: links kept by one call constrain later calls.
   *
   * @param links candidate links
   * @return the kept links, in descending order of their original confidence
   */
  def removeExcessLinksAndSetConfidence(links: Seq[LinkWithConfidence]): Seq[LinkWithConfidence] = this.synchronized {
    links.sortBy(-_.confidence).flatMap { link =>
      val x = disjointSet.get(link.child)
      val y = disjointSet.get(link.parent)
      if (x == y) {
        log.debug(s"Removing $link: hotels in same cluster")
        None
      } else if (isBanned(x, y)) {
        log.info(s"Removing $link: link joins unmerged clusters")
        None
      } else {
        val partnersInResultCluster = Bag.merge(
          disjointSet.getPayload(link.child),
          disjointSet.getPayload(link.parent)
        ).filterByCount(_ > 2)
        if (partnersInResultCluster.nonEmpty) {
          if (partnersInResultCluster.toMap.exists(_._2 >= 5)) {
            log.info(s"Removing $link: too many partners in cluster $partnersInResultCluster")
            None
          } else {
            log.info(s"Reduced confidence for $link: too many partners in cluster $partnersInResultCluster")
            joinClusters(link.child, link.parent, x, y)
            Some(link.copy(confidence = link.confidence / 2))
          }
        } else {
          joinClusters(link.child, link.parent, x, y)
          Some(link)
        }
      }
    }
  }

  /**
   * Replays stored links into `disjointSet` and `bannedByRoot`. Called once, from the constructor.
   *
   * @param stored links to replay; the last MERGE/UNMERGE for a pair of hotels wins
   */
  private def replayExistingLinks(stored: Seq[LinkWithInfo]): Unit = {
    val merged = mutable.Set.empty[(Int, Int)]
    val unmerged = mutable.Set.empty[(Int, Int)]

    stored.sortBy(l => (l.isManual, l.timestamp)).foreach {
      case LinkWithInfo(child, parent, _, linkType, _, _) =>
        // Merging/unmerging relates two hotels regardless of which one is the child, so the key
        // ignores direction: an UNMERGE of (b, a) must cancel an earlier MERGE of (a, b).
        val pair = unorderedPair(child, parent)
        linkType match {
          case LinkType.MERGE =>
            merged += pair
            unmerged -= pair
          case LinkType.UNMERGE =>
            merged -= pair
            unmerged += pair
        }
    }

    new Graph(merged).getConnectedComponents.foreach { cluster =>
      val id1 = cluster.head
      for (id2 <- cluster.tail) {
        disjointSet.join(id1, id2)
      }
    }

    unmerged.foreach { case (id1, id2) => addBan(id1, id2) }
  }

  /** Orders the two ids so that a pair of hotels has one key independent of link direction. */
  private def unorderedPair(a: Int, b: Int): (Int, Int) =
    if (a <= b) (a, b) else (b, a)

  /**
   * Forbids joining the clusters that currently contain hotels `a` and `b`. This is a no-op when
   * they are already in one cluster: such links are rejected as "same cluster" anyway.
   */
  private def addBan(a: Int, b: Int): Unit = {
    val rootA = disjointSet.get(a)
    val rootB = disjointSet.get(b)
    if (rootA != rootB) {
      bannedByRoot(rootA) = bannedByRoot.getOrElse(rootA, Set.empty[Int]) + b
      bannedByRoot(rootB) = bannedByRoot.getOrElse(rootB, Set.empty[Int]) + a
    }
  }

  /**
   * Whether a ban forbids joining the clusters rooted at `rootX` and `rootY`. Checking one side
   * is enough because `addBan` records each ban in both directions.
   */
  private def isBanned(rootX: Int, rootY: Int): Boolean =
    bannedByRoot.get(rootX).exists(_.exists(hotel => disjointSet.get(hotel) == rootY))

  /**
   * Joins the clusters of `child` and `parent`, currently rooted at `rootX` and `rootY`. The
   * bans of both clusters are moved under the joined cluster's root, so they keep applying
   * whichever root the disjoint set picks and extend to the hotels that were just joined.
   */
  private def joinClusters(child: Int, parent: Int, rootX: Int, rootY: Int): Unit = {
    disjointSet.join(child, parent)
    val bans = bannedByRoot.remove(rootX).getOrElse(Set.empty[Int]) ++
      bannedByRoot.remove(rootY).getOrElse(Set.empty[Int])
    if (bans.nonEmpty) {
      bannedByRoot(disjointSet.get(child)) = bans
    }
  }

  /**
   * Creates an empty disjoint set over hotel ids. A hotel's initial payload is a bag holding its
   * partner from `hotelToPartner` (empty if unknown). Payloads of joined clusters are combined
   * with `Bag.merge`.
   */
  private def createEmptyDisjointSet: DisjointSetWithPayload[Bag[Partner]] = {
    new DisjointSetWithPayload[Bag[Partner]]() {
      override protected def initValue(x: Int): Bag[Partner] = {
        val bag = new Bag[Partner]
        hotelToPartner.get(x).foreach { p => bag += p }
        bag
      }

      override protected def merge(a: Bag[Partner], b: Bag[Partner]): Bag[Partner] = {
        Bag.merge(a, b)
      }
    }
  }
}
