package ru.yandex.tours.util

import java.io.Closeable

import com.google.common.collect.Iterators
import ru.yandex.tours.util.collections.Bag

import scala.collection.JavaConverters._
import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import scala.language.higherKinds
import scala.util.control.NonFatal

/**
 * Extension methods for Scala collections and iterators, plus a helper for
 * merging sorted iterators.
 *
 * Mix this trait in, or `import Collections._` from the companion object, to bring
 * the implicit enrichments into scope.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 13.01.15
 */
trait Collections {

  /** Enrichment for any `Iterable` subtype `C[T]`. */
  implicit class RichIterable[T, C[X] <: Iterable[X]](seq: C[T]) {

    /**
     * Swaps the components of every pair in a collection of pairs.
     *
     * @return `(b, a)` for each element `(a, b)` of the collection
     */
    def swap[A, B](implicit ev: <:<[T, (A, B)]): Iterable[(B, A)] =
      seq.map(e => ev(e).swap)

    /**
     * Converts `Iterable[(A, B)]` to `Map[A, List[B]]`, grouping the values by key.
     *
     * @return a map from each distinct key to the list of values paired with it
     */
    def toMultiMap[K, V](implicit ev: <:<[T, (K, V)]): Map[K, List[V]] =
      seq.groupBy(e => ev(e)._1).map { case (key, group) =>
        key -> group.iterator.map(e => ev(e)._2).toList
      }

    /**
     * Drops every element whose key (as computed by `keyExtractor`) was already seen.
     *
     * The first occurrence of each key is kept, and the iteration order is preserved.
     *
     * @param keyExtractor computes the de-duplication key of an element
     * @param cbf          builds a result collection of the same shape `C[T]`
     * @return the de-duplicated collection
     */
    def distinctBy[K](keyExtractor: T => K)(implicit cbf: CanBuildFrom[C[_], T, C[T]]): C[T] = {
      val builder = cbf.apply()
      val seen = mutable.HashSet.empty[K]
      seq.foreach { x =>
        // `add` returns true only when the key was not present before
        if (seen.add(keyExtractor(x))) builder += x
      }
      builder.result()
    }

    /**
     * Returns a pair of the minimal and maximal elements of the collection.
     *
     * On ties the earliest element wins, because a candidate is replaced only
     * when a strictly smaller (respectively larger) element is found.
     *
     * @throws UnsupportedOperationException if the collection is empty
     */
    def minMax[B >: T](implicit cmp: Ordering[B]): (T, T) = {
      if (seq.isEmpty) throw new UnsupportedOperationException("empty.minMax")

      var min, max = seq.head
      for (v <- seq) {
        if (cmp.gt(min, v)) min = v
        if (cmp.lt(max, v)) max = v
      }
      (min, max)
    }

    /** Returns the minimal element, or `None` if the collection is empty. */
    def minOpt(implicit cmp: Ordering[T]): Option[T] =
      if (seq.isEmpty) None
      else Some(seq.min)

    /** Returns the element with the minimal `f` value, or `None` if the collection is empty. */
    def minOptBy[B](f: T => B)(implicit cmp: Ordering[B]): Option[T] =
      if (seq.isEmpty) None
      else Some(seq.minBy(f))

    /** Copies every element into a new [[Bag]], adding each element with `+=`. */
    def toBag: Bag[T] = {
      val bag = new Bag[T]()
      seq.foreach(bag += _)
      bag
    }
  }

  /** Enrichment for any immutable `Map` subtype `C[K, V]`. */
  implicit class RichMap[K, V, C[X, Y] <: Map[X, Y]](map: C[K, V]) {

    /**
     * Swaps keys and values.
     *
     * If several keys share a value, the key encountered last in iteration order wins.
     */
    def swapKV: Map[V, K] =
      map.iterator.map(_.swap).toMap

    /** Inverts the map: each value maps to all the keys that were associated with it. */
    def inverse: Map[V, Seq[K]] =
      map.groupBy(_._2).map { case (value, entries) =>
        value -> entries.keys.toVector
      }

    /**
     * Inverts a map whose values are collections.
     *
     * Every element of every value collection maps to the keys whose collection
     * contained it. Each key appears at most once per element, even if the
     * element occurs several times in that key's collection.
     */
    def inverse2[V2](implicit ev: <:<[V, Iterable[V2]]): Map[V2, Seq[K]] = {
      val pairs = for {
        (key, values) <- map.toVector
        value2 <- ev(values)
      } yield value2 -> key
      pairs.groupBy(_._1).map { case (value2, group) =>
        // de-duplicate keys, matching the set semantics of the previous implementation
        value2 -> group.map(_._2).distinct
      }
    }

    /**
     * Full outer join on keys.
     *
     * For every key present in either map, returns the pair of optional values.
     * A side is `None` where the key is absent from that map.
     */
    def join[V2](map2: Map[K, V2]): Map[K, (Option[V], Option[V2])] = {
      val keys = map.keySet ++ map2.keySet
      keys.map(k => k -> (map.get(k), map2.get(k)))(collection.breakOut)
    }

    /**
     * Applies `f` to every value eagerly and returns a new map.
     *
     * This is the eager counterpart to `Map.mapValues`, which in Scala 2.12 returns a lazy view.
     * The type parameter `That` is unused. It is kept only so that call sites
     * passing explicit type arguments keep compiling.
     */
    def mapValuesStrict[V2, That](f: V => V2): Map[K, V2] =
      map.map { case (key, value) => key -> f(value) }
  }

  /** Enrichment for any `Iterator` subtype `C[T]`. */
  implicit class RichIterator[T, C[X] <: Iterator[X]](it: C[T]) {

    /**
     * Returns the next element, or `None` if the iterator is exhausted.
     *
     * NOTE: this CONSUMES the element; the underlying iterator is advanced.
     */
    def headOption: Option[T] =
      if (it.hasNext) Some(it.next())
      else None

    /**
     * Wraps the iterator so that `action` runs once, the first time `hasNext` reports `false`.
     *
     * The action does not run if the consumer stops before observing exhaustion
     * through `hasNext`.
     */
    def onFinish(action: => Unit): Iterator[T] = new Iterator[T] {
      private[this] var finished = false

      override def hasNext: Boolean = !finished && {
        val more = it.hasNext
        if (!more) {
          // set the flag first so the action runs at most once, even if it throws
          finished = true
          action
        }
        more
      }

      override def next(): T = it.next()
    }

    /**
     * Wraps the iterator so that `action` runs whenever `hasNext` or `next()`
     * throws a non-fatal exception.
     *
     * The original exception is always rethrown. If `action` itself fails
     * non-fatally, that failure is attached to the original exception as a
     * suppressed exception instead of replacing it.
     */
    def onFail(action: => Unit): Iterator[T] = new Iterator[T] {
      override def hasNext: Boolean = guarded(it.hasNext)

      override def next(): T = guarded(it.next())

      /** Evaluates `body`; on a non-fatal failure runs `action`, then rethrows the failure. */
      private def guarded[R](body: => R): R =
        try body
        catch {
          case NonFatal(e) =>
            try action
            catch {
              case NonFatal(cleanupError) =>
                // Throwable forbids self-suppression
                if (cleanupError ne e) e.addSuppressed(cleanupError)
            }
            throw e
        }
    }

    /**
     * Ties `closeable` to this iterator.
     *
     * It is closed when iteration fails with a non-fatal exception, or when
     * `hasNext` first reports exhaustion. It is not closed if the consumer
     * abandons the iterator early.
     */
    def bindTo(closeable: Closeable): Iterator[T] = {
      it.onFail(closeable.close()).onFinish(closeable.close())
    }
  }

  /**
   * Merges several iterators into a single iterator ordered by `ord`, using Guava `Iterators.mergeSorted`.
   *
   * Each input is assumed to be already sorted by `ord`; Guava does not sort its
   * inputs, and equal elements are not de-duplicated.
   */
  def mergeSorted[T](iterators: Iterable[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
    Iterators.mergeSorted(
      iterators.map(_.asJava).asJava,
      ord
    ).asScala
  }
}

/** Singleton instance of [[Collections]], so the enrichments can be used via `import Collections._`. */
object Collections extends Collections {}
