package ru.yandex.tours.util

import java.util.concurrent.ThreadLocalRandom

import scala.collection.mutable
import scala.util.Random

/**
 * Randomization helpers: a random source plus extension methods for sampling
 * iterators and collections and for jittering integers.
 *
 * Every helper draws from [[random]], so overriding it (e.g. with a seeded
 * `new Random(42)` in tests) makes all of them reproducible.
 *
 * Author: Vladislav Dolbilov (darl@yandex-team.ru)
 * Created: 06.05.15
 */
trait Randoms {
  /**
   * Random source used by all helpers below. By default this is the calling
   * thread's `ThreadLocalRandom`, implicitly wrapped into a `scala.util.Random`.
   */
  def random: Random = ThreadLocalRandom.current()

  /** Random string of `len` characters drawn from A-Z, a-z and 0-9 (empty when `len <= 0`). */
  def nextString(len: Int): String = random.alphanumeric.take(len).mkString

  implicit class RandomInt(i: Int) {
    /**
     * Returns `i` plus Gaussian noise whose standard deviation is `percent`% of `i`.
     * The noise is truncated towards zero before being added.
     *
     * @param percent standard deviation of the noise, as a percentage of `i`
     */
    def withDeviation(percent: Int): Int = {
      val sigma = i.toDouble * percent / 100
      i + (sigma * random.nextGaussian()).toInt
    }
  }

  implicit class RandomIterator[T](it: Iterator[T]) {
    /**
     * Bernoulli sample: keeps each element independently with probability `f`,
     * preserving the original order, and materializes the result.
     *
     * @param f inclusion probability, must be in [0.0; 1.0]
     * @throws IllegalArgumentException if `f` is out of range
     */
    def sample(f: Double): Seq[T] = sampleIt(f).toVector

    /**
     * Lazily samples about `sampleSize` of `totalSize` elements.
     * Returns `it` unchanged when `sampleSize >= totalSize`; otherwise keeps each
     * element with probability `sampleSize / totalSize`, so the resulting size is
     * only approximately `sampleSize`.
     */
    def sampleIt(sampleSize: Int, totalSize: Int): Iterator[T] =
      if (sampleSize >= totalSize) it
      else sampleIt(sampleSize.toDouble / totalSize)

    /**
     * Lazy Bernoulli sample: keeps each element independently with probability `f`,
     * preserving order. `it` must not be used after calling this.
     *
     * @param f inclusion probability, must be in [0.0; 1.0]
     * @throws IllegalArgumentException if `f` is out of range
     */
    def sampleIt(f: Double): Iterator[T] = {
      require(f >= 0d && f <= 1d, "fraction should be in range [0.0; 1.0]")
      val rnd = random
      // f == 0 never selects anything: don't walk (possibly infinite) input at all.
      if (f == 0d) Iterator.empty
      // Large f: most elements are kept, so one coin flip per element is simplest.
      else if (f >= 0.4) it.filter(_ => rnd.nextDouble() <= f)
      // Small f: jump over a geometric number of elements, one draw per kept element.
      else skipSampled(rnd, f)
    }

    /** Geometric-skip sampler; requires 0 < f < 1. */
    private def skipSampled(rnd: Random, f: Double): Iterator[T] = new Iterator[T] {
      // log(1 - f), computed accurately even for tiny f.
      private val logQ = scala.math.log1p(-f)
      // Element selected by the last skip, not yet returned by next().
      private var pending: Option[T] = None

      override def hasNext: Boolean = pending.isDefined || advance()

      override def next(): T = (if (hasNext) pending else None) match {
        case Some(elem) =>
          pending = None
          elem
        case None =>
          throw new NoSuchElementException("next on empty iterator")
      }

      /** Skips Geometric(f) elements, buffers the following one; true if one was buffered. */
      private def advance(): Boolean = {
        if (it.hasNext) {
          // 1 - nextDouble() lies in (0, 1]: log is finite, so no clamping bias is needed.
          val u = 1d - rnd.nextDouble()
          skip((scala.math.log(u) / logQ).toInt)
          if (it.hasNext) pending = Some(it.next())
        }
        pending.isDefined
      }

      /** Advances `it` by up to `n` elements. */
      private def skip(n: Int): Unit = {
        var skipped = 0
        while (skipped < n && it.hasNext) {
          it.next()
          skipped += 1
        }
      }
    }
  }

  implicit class RandomIterable[T](it: Iterable[T]) {
    /**
     * Uniformly chosen element of the collection.
     *
     * @throws IllegalArgumentException if the collection is empty
     */
    def randomElement: T = {
      require(it.nonEmpty, "collection should not be empty")
      it.iterator.drop(random.nextInt(it.size)).next()
    }

    /** Shuffles the collection and keeps at most `count` elements. */
    def sample(count: Int): Iterable[T] = random.shuffle(it).take(count)

    /**
     * Random non-empty subset: shuffles the collection and keeps a uniformly
     * chosen number of elements in [1, size].
     *
     * @throws IllegalArgumentException if the collection is empty
     */
    def sample(): Iterable[T] = {
      require(it.nonEmpty, "collection should not be empty")
      random.shuffle(it).take(random.nextInt(it.size) + 1)
    }
  }
}

/** Stand-alone instance of the [[Randoms]] helpers; `import Randoms._` brings them into scope. */
object Randoms extends Randoms {}
