Skip to content

Add pick generator specialized for indexed sequences #874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions jvm/src/test/scala/org/scalacheck/GenSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,30 @@ object GenSpecification extends Properties("Gen") with GenSpecificationVersionSp
}
}

property("indexedPick") = forAll { (vec: Vector[Int]) =>
forAll(choose(-1, 2 * vec.length)) { n =>
Try(indexedPick(n, vec)) match {
case Success(g) =>
forAll(g) { m => m.length == n && m.forall(vec.contains) }
case Failure(_) =>
Prop(n < 0 || n > vec.length)
}
}
}

property("indexedPick does not repeat picks") = forAll { (set: Set[Int]) =>
forAll(choose(-1, 2 * set.size)) { n =>
Try(indexedPick(n, set.toVector)) match {
case Success(g) =>
forAll(g) { m =>
m.toSet.size == n
}
case Failure(_) =>
Prop(n < 0 || n > set.size)
}
}
}

/**
* Expect:
* 25% 1, 2, 3
Expand Down
26 changes: 26 additions & 0 deletions src/main/scala/org/scalacheck/Gen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import language.implicitConversions

import rng.Seed
import util.Buildable
import util.MissingSelector
import util.SerializableCanBuildFroms._
import ScalaVersionSpecific._

Expand Down Expand Up @@ -1142,6 +1143,31 @@ object Gen extends GenArities with GenVersionSpecific {
def pick[T](n: Int, g1: Gen[T], g2: Gen[T], gn: Gen[T]*): Gen[Seq[T]] =
pick(n, g1 +: g2 +: gn).flatMap(sequence[Seq[T], T](_))

/** A generator that randomly picks a given number of elements from an IndexedSeq
*
* The elements are guaranteed to be permuted in random order.
*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A quick comment on the runtime improvement over pick would be helpful. Perhaps also that it doesn't repeat elements.

def indexedPick[T](n: Int, l: IndexedSeq[T]): Gen[collection.Seq[T]] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Names that sort close to their relatives improve discoverability: how about pickIndexed?

if (n > l.size || n < 0) throw new IllegalArgumentException(s"invalid choice: $n")
else if (n == 0) Gen.const(Nil)
else gen { (p, seed0) =>
val buf = ArrayBuffer.empty[T]
var seed = seed0
var selector: MissingSelector = MissingSelector.empty
while (buf.size < n) {
val (x, s) = seed.long
// After having chosen k indices, we can choose between (l.size - k) available indices
val idxInAvailable = (x & Long.MaxValue % (l.size - buf.size)).toInt
// Translate from index in available to real index
val (idx, newSelector) = selector.selectAndAdd(idxInAvailable)
selector = newSelector
buf += l(idx)
seed = s
}
r(Some(buf), seed)
}
}

/** Takes a function and returns a generator that generates arbitrary
* results of that function by feeding it with arbitrarily generated input
* parameters. */
Expand Down
76 changes: 76 additions & 0 deletions src/main/scala/org/scalacheck/util/MissingSelector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.scalacheck.util

// Modified red-black order statistic tree to select missing elements.
private[scalacheck] sealed trait MissingSelector {
/** Select (and add) the i-th non-negative integer not present.
* @param i
* @return The i-th non negative integer not present in the original MissingSelector,
* and a new MissingSelector containing this integer.
*/
def selectAndAdd(i: Int): (Int, MissingSelector)
protected def selAndAdd(i: Int): (Int, MissingSelector.Inner)
def size: Int
def toList(others: List[Int] = List.empty[Int]): List[Int]
}

private[scalacheck] object MissingSelector {
val empty: MissingSelector = Empty

sealed trait Color
object B extends Color
object R extends Color

final class Inner private(
private val color: Color,
private val root: Int,
private val left: MissingSelector,
private val right: MissingSelector,
val size: Int
) extends MissingSelector {
override def selectAndAdd(i: Int): (Int, MissingSelector) = {
val (newI, newTree) = selAndAdd(i)
(newI, newTree.asBlack)
}

override protected def selAndAdd(i: Int): (Int, Inner) = {
val (newI, newTree) = if (i + left.size < root) {
val (newI, newLeft) = left.selAndAdd(i)
(newI, Inner(color, root, newLeft, right))
} else {
val (newI, newRight) = right.selAndAdd(i + left.size + 1)
(newI, Inner(color, root, left, newRight))
}
(newI, newTree.balance)
}

override def toList(others: List[Int]): List[Int] = left.toList(root :: right.toList(others))

private def balance: Inner = this match {
case Inner(B, z, Inner(R, y, Inner(R, x, a, b), c), d) =>
Inner(R, y, Inner(B, x, a, b), Inner(B, z, c, d))
case Inner(B, z, Inner(R, x, a, Inner(R, y, b, c)), d) =>
Inner(R, y, Inner(B, x, a, b), Inner(B, z, c, d))
case Inner(B, x, a, Inner(R, z, Inner(R, y, b, c), d)) =>
Inner(R, y, Inner(B, x, a, b), Inner(B, z, c, d))
case Inner(B, x, a, Inner(R, y, b, Inner(R, z, c, d))) =>
Inner(R, y, Inner(B, x, a, b), Inner(B, z, c, d))
case _ => this
}

private def asBlack: Inner = if (color == B) this else Inner(B, root, left, right)
}

object Inner {
private[MissingSelector] def apply(color: Color, root: Int, left: MissingSelector, right: MissingSelector): Inner =
new Inner(color, root, left, right, left.size + right.size + 1)
def unapply(inner: Inner): Option[(Color, Int, MissingSelector, MissingSelector)] =
Some((inner.color, inner.root, inner.left, inner.right))
}

private object Empty extends MissingSelector {
override def selectAndAdd(i: Int): (Int, MissingSelector) = (i, Inner(R, i, Empty, Empty))
override protected def selAndAdd(i: Int): (Int, Inner) = (i, Inner(R, i, Empty, Empty))
override def size: Int = 0
override def toList(others: List[Int]): List[Int] = others
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.scalacheck
package util

import org.scalacheck.{Gen, Properties}
import org.scalacheck.Prop.forAll
import org.scalacheck.util.MissingSelector._

object MissingSelectorSpecification extends Properties("MissingSelector") {

private val smallIntegerGen: Gen[Int] = Gen.choose(0,1000)

private val missingSelectorGen: Gen[MissingSelector] = Gen.listOf(smallIntegerGen).map { list =>
list.foldLeft(MissingSelector.empty){ case (selector, elem) => selector.selectAndAdd(elem)._2 }
}

property("selectAndAdd adds the selected element to the selector") =
forAll(missingSelectorGen) { selector =>
forAll(smallIntegerGen) { i =>
val (selected, newSelector) = selector.selectAndAdd(i)
newSelector.toList().sorted == (selected :: selector.toList()).sorted
}
}

property("selectAndAdd selects the i-th missing element") =
forAll(missingSelectorGen) { selector =>
forAll(smallIntegerGen) { i =>
val (selected, _) = selector.selectAndAdd(i)
val numNotGreater = selector.toList().filter(_ <= selected).length
i + numNotGreater == selected
}
}

// Red black tree invariants
property("no red node has a red child") = {
def redRed(selector: MissingSelector): Boolean = selector match {
case Inner(R, _, Inner(R, _, _, _), _) | Inner(R, _, _, Inner(R, _, _, _)) => true
case Inner(_, _, left, right) => redRed(left) || redRed(right)
case _ => false
}
forAll(missingSelectorGen)(selector => !redRed(selector))
}

property("all paths root-leaf have the same number of black nodes") = {
def checkBlacks(selector: MissingSelector): Either[Unit, Int] = selector match {
case Inner(color, _, left, right) =>
val numRoot = color match {
case B => 1
case R => 0
}
for {
numLeft <- checkBlacks(left)
numRight <- checkBlacks(right)
res <- if (numLeft == numRight) Right(numRoot + numLeft) else Left(())
} yield res
case _ => Right(1)
}
forAll(missingSelectorGen)(selector => checkBlacks(selector).isRight)
}
}