package ru.yandex.tours.util.collections

import scala.collection.mutable

/**
 * A [[DisjointSet]] that additionally keeps one payload value per set.
 *
 * Payloads are stored in a map keyed by the value the inherited `get` returns for an
 * element (presumably the set representative -- confirm against `DisjointSet`).
 * A set that has never taken part in a payload-merging `join` has no stored entry;
 * its payload is computed on demand with `initValue`.
 *
 * @param expectedSize passed through unchanged to the `DisjointSet` constructor
 * @tparam T type of the payload attached to every set
 */
abstract class DisjointSetWithPayload[T](expectedSize: Int = 1000 * 1000) extends DisjointSet(expectedSize) {
  // Payloads keyed by `get(element)`; only sets produced by `join` have an entry.
  private val map = mutable.Map.empty[Int, T]

  /**
   * Returns the payload of the set containing `x`.
   *
   * @param x element whose set payload is requested
   * @return the stored payload for `get(x)`, or `initValue(x)` when nothing is stored yet
   */
  def getPayload(x: Int): T = {
    map.getOrElse(get(x), initValue(x))
  }

  /**
   * Joins the sets of `x` and `y` and stores the merged payload for the resulting set.
   *
   * When both elements already map to the same key, the payload is left untouched:
   * merging a payload with itself would corrupt non-idempotent payloads (counters, sums).
   *
   * @param x element of the first set
   * @param y element of the second set
   */
  override def join(x: Int, y: Int): Unit = {
    // Capture the keys BEFORE joining: these are the keys the payloads are stored under.
    val xKey = get(x)
    val yKey = get(y)
    if (xKey == yKey) {
      // Same set already: still delegate to the base class, but do not re-merge the payload.
      super.join(x, y)
    } else {
      val result = merge(getPayload(x), getPayload(y))
      // Drop the entries under the pre-join keys (not under x / y themselves), so that
      // no stale payload is left behind for a key that stops being used after the join.
      map.remove(xKey)
      map.remove(yKey)
      super.join(x, y)
      // Re-read the key after the join: the base class decides which one survives.
      map.put(get(x), result)
    }
  }

  /**
   * Produces the payload for an element whose set has no stored payload yet.
   *
   * @param x the element that was queried
   */
  protected def initValue(x: Int): T

  /**
   * Combines the payloads of two sets that are being joined.
   *
   * @param a payload of the set containing the first `join` argument
   * @param b payload of the set containing the second `join` argument
   * @return payload for the joined set
   */
  protected def merge(a: T, b: T): T
}
