Skip to content

Commit 402c93f

Browse files
authored
Merge pull request scala#7124 from joshlemer/tap-each
[11098] Add IterableOps#tapEach method
2 parents 04f79e7 + fe1c6a1 commit 402c93f

File tree

9 files changed

+137
-4
lines changed

9 files changed

+137
-4
lines changed

src/library/scala/collection/Iterable.scala

+2
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,8 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
819819
*/
820820
def inits: Iterator[C] = iterateUntilEmpty(_.init)
821821

822+
override def tapEach[U](f: A => U): C = fromSpecific(new View.Map(this, { a: A => f(a); a }))
823+
822824
// A helper for tails and inits.
823825
private[this] def iterateUntilEmpty(f: Iterable[A] => Iterable[A]): Iterator[C] = {
824826
val it = Iterator.iterate(toIterable)(f).takeWhile(x => !x.isEmpty)

src/library/scala/collection/IterableOnce.scala

+11
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,17 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
458458
*/
459459
def span(p: A => Boolean): (C, C)
460460

461+
/** Applies a side-effecting function to each element in this collection.
462+
* Strict collections will apply `f` to their elements immediately, while lazy collections
463+
* like Views and LazyLists will only apply `f` on each element if and when that element
464+
* is evaluated, and each time that element is evaluated.
465+
*
466+
* @param f a function to apply to each element in this $coll
467+
* @tparam U the return type of f
468+
* @return The same logical collection as this
469+
*/
470+
def tapEach[U](f: A => U): C
471+
461472
/////////////////////////////////////////////////////////////// Concrete methods based on iterator
462473

463474
def knownSize: Int = -1

src/library/scala/collection/Iterator.scala

+10
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,16 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
863863
}
864864
}
865865

866+
override def tapEach[U](f: A => U): Iterator[A] = new AbstractIterator[A] {
867+
override def knownSize = self.knownSize
868+
override def hasNext = self.hasNext
869+
override def next() = {
870+
val _next = self.next()
871+
f(_next)
872+
_next
873+
}
874+
}
875+
866876
/** Converts this iterator to a string.
867877
*
868878
* @return `"<iterator>"`

src/library/scala/collection/StrictOptimizedIterableOps.scala

+6
Original file line numberDiff line numberDiff line change
@@ -249,4 +249,10 @@ trait StrictOptimizedIterableOps[+A, +CC[_], +C]
249249
(l.result(), r.result())
250250
}
251251

252+
// Optimization avoids creation of second collection
253+
override def tapEach[U](f: A => U): C = {
254+
foreach(f)
255+
coll
256+
}
257+
252258
}

src/library/scala/collection/immutable/LazyList.scala

+2
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ final class LazyList[+A] private(private[this] var lazyState: () => LazyList.Sta
380380
if (knownIsEmpty) LazyList.empty
381381
else (mapImpl(f): @inline)
382382

383+
override def tapEach[U](f: A => U): LazyList[A] = map { a => f(a); a}
384+
383385
private def mapImpl[B](f: A => B): LazyList[B] =
384386
newLL {
385387
if (isEmpty) State.Empty

test/junit/scala/collection/IteratorTest.scala

+34-2
Original file line numberDiff line numberDiff line change
@@ -535,24 +535,33 @@ class IteratorTest {
535535
// Avoid reaching seq1 through test class. Avoid testing Array.iterator.
536536
class C extends Iterable[String] {
537537
val ss = Array("first", "second")
538+
538539
def iterator = new Iterator[String] {
539540
var i = 0
541+
540542
def hasNext = i < ss.length
543+
541544
def next() =
542-
if (hasNext) { val res = ss(i) ; i += 1 ; res }
545+
if (hasNext) {
546+
val res = ss(i); i += 1; res
547+
}
543548
else Iterator.empty.next()
544549
}
550+
545551
def apply(i: Int) = ss(i)
546552
}
547553
val seq1 = new WeakReference(new C)
548554
val seq2 = List("third")
549555
val it0: Iterator[Int] = Iterator(1, 2)
550556
lazy val it: Iterator[String] = it0.flatMap {
551557
case 1 => seq1.get
552-
case _ => check() ; seq2
558+
case _ => check(); seq2
553559
}
560+
554561
def check() = assertNotReachable(seq1.get, it)(())
562+
555563
def checkHasElement() = assertNotReachable(seq1.get.apply(1), it)(())
564+
556565
assert(it.hasNext)
557566
assertEquals("first", it.next())
558567

@@ -568,4 +577,27 @@ class IteratorTest {
568577
}
569578
assert(!it.hasNext)
570579
}
580+
581+
@Test def tapEach(): Unit = {
582+
locally {
583+
var i = 0
584+
val tapped = Iterator(-1, -1, -1).tapEach(_ => i += 1)
585+
assertEquals(true, tapped.hasNext)
586+
assertEquals(0, i)
587+
}
588+
589+
locally {
590+
var i = 0
591+
val tapped = Iterator(-1, -1, -1).tapEach(_ => i += 1)
592+
assertEquals(-3, tapped.sum)
593+
assertEquals(3, i)
594+
}
595+
596+
locally {
597+
var i = 0
598+
val tapped = Iterator(-1, -1, -1).tapEach(_ => i += 1)
599+
assertEquals(-1, tapped.next())
600+
assertEquals(1, i)
601+
}
602+
}
571603
}

test/junit/scala/collection/ViewTest.scala

+22-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package scala.collection
22

33
import scala.collection.immutable.List
4-
54
import org.junit.Assert._
65
import org.junit.Test
76
import org.junit.runner.RunWith
87
import org.junit.runners.JUnit4
8+
99
import language.postfixOps
10+
import scala.collection.mutable.ListBuffer
1011

1112
@RunWith(classOf[JUnit4])
1213
class ViewTest {
@@ -17,7 +18,8 @@ class ViewTest {
1718

1819
import scala.language.postfixOps
1920
assertEquals(Iterable.empty[Int], iter.view take Int.MinValue to Iterable)
20-
assertEquals(Iterable.empty[Int], iter.view takeRight Int.MinValue to Iterable)
21+
assertEquals(Iterable.empty[Int],
22+
iter.view takeRight Int.MinValue to Iterable)
2123
assertEquals(iter, iter.view drop Int.MinValue to Iterable)
2224
assertEquals(iter, iter.view dropRight Int.MinValue to Iterable)
2325
}
@@ -69,4 +71,22 @@ class ViewTest {
6971
check(immutable.Map(1 -> "a", 2 -> "b")) // MapView
7072
}
7173

74+
@Test
75+
def tapEach: Unit = {
76+
val lb = ListBuffer[Int]()
77+
78+
val v =
79+
View(1, 2, 3)
80+
.tapEach(lb += _)
81+
.map(_ => 10)
82+
.tapEach(lb += _)
83+
.tapEach(_ => lb += -1)
84+
85+
assertEquals(ListBuffer[Int](), lb)
86+
87+
val strict = v.to(Seq)
88+
assertEquals(strict, Seq(10, 10, 10))
89+
assertEquals(lb, Seq(1, 10, -1, 2, 10, -1, 3, 10, -1))
90+
}
91+
7292
}

test/junit/scala/collection/immutable/LazyListTest.scala

+30
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.junit.Test
66
import org.junit.Assert._
77

88
import scala.collection.Iterator
9+
import scala.collection.mutable.ListBuffer
910
import scala.ref.WeakReference
1011
import scala.util.Try
1112

@@ -358,6 +359,15 @@ class LazyListTest {
358359
assertLazyHeadWhenNextHeadEvaluated(op)
359360
}
360361

362+
@Test
363+
def tapEach_properlyLazy(): Unit = {
364+
val op = lazyListOp(_.tapEach(_ + 1))
365+
assertRepeatedlyFullyLazy(op)
366+
assertLazyTailWhenHeadEvaluated(op)
367+
assertLazyHeadWhenTailEvaluated(op)
368+
assertLazyHeadWhenNextHeadEvaluated(op)
369+
}
370+
361371
@Test
362372
def collect_properlyLazy(): Unit = {
363373
val op = lazyListOp(_ collect { case i if i % 2 != 0 => i})
@@ -1000,4 +1010,24 @@ class LazyListTest {
10001010
assertEquals(s"mkString 3 $i", goal(3,i), precyc(3,i).mkString)
10011011
}
10021012
}
1013+
@Test
1014+
def tapEach: Unit = {
1015+
1016+
/** @param makeLL must make a lazylist that evaluates to Seq(1,2,3,4,5) */
1017+
def check(makeLL: => LazyList[Int]): Unit = {
1018+
val lb = ListBuffer[Int]()
1019+
val ll = makeLL.tapEach(lb += _)
1020+
assertEquals(ListBuffer[Int](), lb)
1021+
assertEquals(Vector(1, 2), ll.take(2).to(Vector))
1022+
assertEquals(ListBuffer(1, 2), lb)
1023+
assertEquals(4, ll(3))
1024+
assertEquals(ListBuffer(1, 2, 4), lb)
1025+
assertEquals(Vector(1,2,3,4,5), ll.to(Vector))
1026+
assertEquals(ListBuffer(1, 2, 4, 3, 5), lb)
1027+
}
1028+
1029+
check(LazyList.from(Iterator(1, 2, 3, 4, 5)))
1030+
check(LazyList.from(Vector(1, 2, 3, 4, 5)))
1031+
check(LazyList.tabulate(5)(_ + 1))
1032+
}
10031033
}

test/junit/scala/collection/immutable/VectorTest.scala

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import org.junit.runner.RunWith
55
import org.junit.runners.JUnit4
66
import org.junit.Test
77

8+
import scala.collection.mutable.ListBuffer
9+
810
@RunWith(classOf[JUnit4])
911
class VectorTest {
1012

@@ -161,4 +163,22 @@ class VectorTest {
161163
assertEquals(-1, test.indexOf(1000))
162164
}
163165

166+
@Test
167+
def tapEach(): Unit = {
168+
val lb = ListBuffer[Int]()
169+
170+
val v =
171+
Vector(1,2,3)
172+
.tapEach(lb += _)
173+
.tapEach(lb += _)
174+
175+
assertEquals(ListBuffer(1,2,3,1,2,3), lb)
176+
assertEquals(Vector(1,2,3), v)
177+
178+
179+
val f: Any => Unit = println
180+
181+
// test that type info is not lost
182+
val x: Vector[Char] = Vector[Char]().tapEach(f)
183+
}
164184
}

0 commit comments

Comments
 (0)