Skip to content

Commit

Permalink
Merge pull request #3382 from zainab-ali/ensure
Browse files Browse the repository at this point in the history
Add chunk-preserving ensure function.
  • Loading branch information
mpilquist authored Feb 3, 2024
2 parents 49610a8 + fef93a4 commit 95eeb84
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
32 changes: 32 additions & 0 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,22 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
queue: Queue[F2, Option[Chunk[O2]]]
): Stream[F2, Nothing] = enqueueNoneTerminatedChunks(queue: QueueSink[F2, Option[Chunk[O2]]])

/** Emits the longest prefix of the input for which all elements test true. Raises an error if an element tests false.
*
* @example {{{
* scala> Stream(1, 2, 3, 4).ensure[Fallible](new RuntimeException)(_ != 3).toList
* res0: Either[Throwable,List[Int]] = Left(java.lang.RuntimeException)
* scala> Stream(1, 2, 3, 4).ensure[Fallible](new RuntimeException)(_ != 5).toList
* res0: Either[Throwable,List[Int]] = Right(List(1, 2, 3, 4))
* scala> Stream(1, 2, 3, 4).ensure[Fallible](new RuntimeException)(_ != 3).attempt.toList
* res0: Either[Throwable,List[Either[Throwable, Int]]] = Right(List(Right(1), Right(2), Left(java.lang.RuntimeException)))
* }}}
*/
def ensure[F2[x] >: F[x]](e: Throwable)(p: O => Boolean)(implicit
ev: RaiseThrowable[F2]
): Stream[F2, O] =
this.covary[F2].pull.ensure(e)(p).stream

/** Alias for `flatMap(o => Stream.eval(f(o)))`.
*
* @example {{{
Expand Down Expand Up @@ -4754,6 +4770,15 @@ object Stream extends StreamLowPriority {
case Some((hd, tl)) => Pull.output(hd).as(Some(tl))
}

/** Like `[[takeWhile]]`, but raises an error if an element tests false. */
def ensure(
error: => Throwable
)(predicate: O => Boolean)(implicit F: RaiseThrowable[F]): Pull[F, O, Unit] =
takeWhile_(predicate, takeFailure = false).flatMap {
case None => Pull.done
case Some(_) => Pull.raiseError(error)
}

/** Like `[[unconsN]]`, but leaves the buffered input unconsumed. */
def fetchN(n: Int): Pull[F, Nothing, Option[Stream[F, O]]] =
unconsN(n).map(_.map { case (hd, tl) => tl.cons(hd) })
Expand Down Expand Up @@ -5472,6 +5497,13 @@ object Stream extends StreamLowPriority {
def handleErrorWith[A](s: Stream[F, A])(h: Throwable => Stream[F, A]) =
s.handleErrorWith(h)
def raiseError[A](t: Throwable) = Stream.raiseError[F](t)
override def attempt[A](s: Stream[F, A]): Stream[F, Either[Throwable, A]] = s.attempt
override def rethrow[A, EE <: Throwable](s: Stream[F, Either[EE, A]]): Stream[F, A] =
s.rethrow
override def ensure[A](s: Stream[F, A])(error: => Throwable)(
predicate: A => Boolean
): Stream[F, A] =
s.ensure(error)(predicate)
}

/** `Monoid` instance for `Stream`. */
Expand Down
20 changes: 20 additions & 0 deletions core/shared/src/test/scala/fs2/StreamSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,26 @@ class StreamSuite extends Fs2Suite {
}
}

group("ensure") {
property("preserves chunks") {
forAll { (s: Stream[Pure, Int]) =>
val s1 = s.covary[Fallible].chunks
val s2 = s.covary[Fallible].ensure(new Err)(_ => true).chunks
assertEquals(s1.toList, s2.toList)
}
}
test("fails when predicate fails") {
val err = new Err
val s = Stream(1, 2, 3).ensure[Fallible](err)(_ != 2).attempt
assertEquals(s.toList, Right(List(Right(1), Left(err))))
}
test("succeeds when predicate succeeds") {
val err = new Err
val s = Stream(1, 2, 3).ensure[Fallible](err)(_ != 10)
assertEquals(s.toList, s.covary[Fallible].toList)
}
}

test("eval") {
Stream.eval(SyncIO(23)).compile.lastOrError.assertEquals(23)
}
Expand Down

0 comments on commit 95eeb84

Please sign in to comment.