Skip to content

Commit a0837c2

Browse files
authored
Merge pull request scala#7268 from joshlemer/hashmap-opt-transform
Optimize s.c.i.HashMap#transform
2 parents a2b0d26 + 08615a1 commit a0837c2

File tree

3 files changed

+82
-9
lines changed

3 files changed

+82
-9
lines changed

src/library/scala/collection/immutable/ChampHashMap.scala

+62-9
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
154154
}
155155
}
156156

157+
override def transform[W](f: (K, V) => W) = {
158+
val transformed = rootNode.transform(f)
159+
if (transformed eq rootNode) this.asInstanceOf[HashMap[K, W]]
160+
else new HashMap(transformed, cachedJavaKeySetHashCode)
161+
}
162+
157163
override def filterImpl(pred: ((K, V)) => Boolean, flipped: Boolean): HashMap[K, V] = {
158164
// This method has been preemptively overridden in order to ensure that an optimizing implementation may be included
159165
// in a minor release without breaking binary compatibility.
@@ -173,15 +179,6 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
173179
super.removeAll(keys)
174180
}
175181

176-
override def transform[W](f: (K, V) => W): HashMap[K, W] = {
177-
// This method has been preemptively overridden in order to ensure that an optimizing implementation may be included
178-
// in a minor release without breaking binary compatibility.
179-
//
180-
// In particular, `transform` could be optimized to traverse the trie node-by-node, swapping out the values of each
181-
// key with the result of applying `f`.
182-
super.transform(f)
183-
}
184-
185182
override def partition(p: ((K, V)) => Boolean): (HashMap[K, V], HashMap[K, V]) = {
186183
// This method has been preemptively overridden in order to ensure that an optimizing implementation may be included
187184
// in a minor release without breaking binary compatibility.
@@ -255,6 +252,7 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
255252
// checks.
256253
super.span(p)
257254
}
255+
258256
}
259257

260258
private[immutable] object MapNode {
@@ -303,6 +301,8 @@ private[immutable] sealed abstract class MapNode[K, +V] extends Node[MapNode[K,
303301

304302
def foreach[U](f: ((K, V)) => U): Unit
305303

304+
def transform[W](f: (K, V) => W): MapNode[K, W]
305+
306306
def copy(): MapNode[K, V]
307307
}
308308

@@ -655,6 +655,44 @@ private final class BitmapIndexedMapNode[K, +V](
655655
}
656656
}
657657

658+
override def transform[W](f: (K, V) => W): BitmapIndexedMapNode[K, W] = {
659+
var newContent: Array[Any] = null
660+
val _payloadArity = payloadArity
661+
val _nodeArity = nodeArity
662+
val newContentLength = content.length
663+
var i = 0
664+
while (i < _payloadArity) {
665+
val key = getKey(i)
666+
val value = getValue(i)
667+
val newValue = f(key, value)
668+
if (newContent eq null) {
669+
if (newValue.asInstanceOf[AnyRef] ne value.asInstanceOf[AnyRef]) {
670+
newContent = content.clone()
671+
newContent(TupleLength * i + 1) = newValue
672+
}
673+
} else {
674+
newContent(TupleLength * i + 1) = newValue
675+
}
676+
i += 1
677+
}
678+
679+
var j = 0
680+
while (j < _nodeArity) {
681+
val node = getNode(j)
682+
val newNode = node.transform(f)
683+
if (newContent eq null) {
684+
if (newNode ne node) {
685+
newContent = content.clone()
686+
newContent(newContentLength - j - 1) = newNode
687+
}
688+
} else
689+
newContent(newContentLength - j - 1) = newNode
690+
j += 1
691+
}
692+
if (newContent eq null) this.asInstanceOf[BitmapIndexedMapNode[K, W]]
693+
else new BitmapIndexedMapNode[K, W](dataMap, nodeMap, newContent, originalHashes, size)
694+
}
695+
658696
override def equals(that: Any): Boolean =
659697
that match {
660698
case node: BitmapIndexedMapNode[K, V] =>
@@ -801,6 +839,21 @@ private final class HashCollisionMapNode[K, +V ](
801839

802840
def foreach[U](f: ((K, V)) => U): Unit = content.foreach(f)
803841

842+
override def transform[W](f: (K, V) => W): HashCollisionMapNode[K, W] = {
843+
val newContent = Vector.newBuilder[(K, W)]
844+
val contentIter = content.iterator
845+
// true if any values have been transformed to a different value via `f`
846+
var anyChanges = false
847+
while(contentIter.hasNext) {
848+
val (k, v) = contentIter.next()
849+
val newValue = f(k, v)
850+
newContent.addOne((k, newValue))
851+
anyChanges ||= (v.asInstanceOf[AnyRef] ne newValue.asInstanceOf[AnyRef])
852+
}
853+
if (anyChanges) new HashCollisionMapNode(originalHash, hash, newContent.result())
854+
else this.asInstanceOf[HashCollisionMapNode[K, W]]
855+
}
856+
804857
override def equals(that: Any): Boolean =
805858
that match {
806859
case node: HashCollisionMapNode[K, V] =>

test/junit/scala/collection/immutable/HashMapTest.scala

+13
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,18 @@ class HashMapTest {
112112
val expected = OldHashMap(A(0) -> 1, A(1) -> 1)
113113
assertEquals(merged, expected)
114114
}
115+
@Test
116+
def transformReturnsOriginalMap() {
117+
case class A(i: Int, j: Int) { override def hashCode = j }
118+
119+
val hashMap = HashMap(
120+
A(1, 1) -> 1,
121+
A(1, 2) -> 1,
122+
A(2, 1) -> 1,
123+
A(2, 2) -> 1
124+
)
125+
126+
assert(hashMap.transform((_, v) => v) eq hashMap)
127+
}
115128
}
116129

test/scalacheck/scala/collection/immutable/ImmutableChampHashMapProperties.scala

+7
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,11 @@ object ImmutableChampHashMapProperties extends Properties("immutable.HashMap") {
124124
val hmb = HashMap.newBuilder[K, V].addAll(xs)
125125
(mb.result() eq mb.result()) && (hmb.result() eq hmb.result())
126126
}
127+
128+
property("transform(f) == map { (k, v) => (k, f(k, v)) }") = forAll { (xs: HashMap[K, V], f: (K, V) => String) =>
129+
xs.transform(f) == xs.map{ case (k, v) => (k, f(k, v)) }
130+
}
131+
property("xs.transform((_, v) => v) eq xs") = forAll { xs: HashMap[K, V] =>
132+
xs.transform((_, v) => v) eq xs
133+
}
127134
}

0 commit comments

Comments
 (0)