Skip to content

Commit c9b2ef6

Browse files
committed
Merge pull request scala#1298 from pavelpavlov/SI-5767
SI-5767 fix + protecting public FlatHashMap API
2 parents a9f95dc + dbd641f commit c9b2ef6

19 files changed

+284
-102
lines changed

src/compiler/scala/tools/nsc/transform/LambdaLift.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ abstract class LambdaLift extends InfoTransform {
154154
private def markCalled(sym: Symbol, owner: Symbol) {
155155
debuglog("mark called: " + sym + " of " + sym.owner + " is called by " + owner)
156156
symSet(called, owner) addEntry sym
157-
if (sym.enclClass != owner.enclClass) calledFromInner addEntry sym
157+
if (sym.enclClass != owner.enclClass) calledFromInner += sym
158158
}
159159

160160
/** The traverse function */

src/library/scala/collection/mutable/FlatHashTable.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
4444
*/
4545
@transient protected var sizemap: Array[Int] = null
4646

47-
@transient var seedvalue: Int = tableSizeSeed
47+
@transient protected var seedvalue: Int = tableSizeSeed
4848

4949
import HashTable.powerOfTwo
5050

@@ -109,7 +109,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
109109
}
110110

111111
/** Finds an entry in the hash table if such an element exists. */
112-
def findEntry(elem: A): Option[A] = {
112+
protected def findEntry(elem: A): Option[A] = {
113113
var h = index(elemHashCode(elem))
114114
var entry = table(h)
115115
while (null != entry && entry != elem) {
@@ -120,7 +120,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
120120
}
121121

122122
/** Checks whether an element is contained in the hash table. */
123-
def containsEntry(elem: A): Boolean = {
123+
protected def containsEntry(elem: A): Boolean = {
124124
var h = index(elemHashCode(elem))
125125
var entry = table(h)
126126
while (null != entry && entry != elem) {
@@ -133,7 +133,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
133133
/** Add entry if not yet in table.
134134
* @return Returns `true` if a new entry was added, `false` otherwise.
135135
*/
136-
def addEntry(elem: A) : Boolean = {
136+
protected def addEntry(elem: A) : Boolean = {
137137
var h = index(elemHashCode(elem))
138138
var entry = table(h)
139139
while (null != entry) {
@@ -150,7 +150,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
150150
}
151151

152152
/** Removes an entry from the hash table, returning an option value with the element, or `None` if it didn't exist. */
153-
def removeEntry(elem: A) : Option[A] = {
153+
protected def removeEntry(elem: A) : Option[A] = {
154154
if (tableDebug) checkConsistent()
155155
def precedes(i: Int, j: Int) = {
156156
val d = table.length >> 1
@@ -185,7 +185,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
185185
None
186186
}
187187

188-
def iterator: Iterator[A] = new AbstractIterator[A] {
188+
protected def iterator: Iterator[A] = new AbstractIterator[A] {
189189
private var i = 0
190190
def hasNext: Boolean = {
191191
while (i < table.length && (null == table(i))) i += 1

src/library/scala/collection/mutable/HashMap.scala

+18-11
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,31 @@ extends AbstractMap[A, B]
4949
type Entry = DefaultEntry[A, B]
5050

5151
override def empty: HashMap[A, B] = HashMap.empty[A, B]
52-
override def clear() = clearTable()
52+
override def clear() { clearTable() }
5353
override def size: Int = tableSize
5454

5555
def this() = this(null)
5656

5757
override def par = new ParHashMap[A, B](hashTableContents)
5858

5959
// contains and apply overridden to avoid option allocations.
60-
override def contains(key: A) = findEntry(key) != null
60+
override def contains(key: A): Boolean = findEntry(key) != null
61+
6162
override def apply(key: A): B = {
6263
val result = findEntry(key)
63-
if (result == null) default(key)
64+
if (result eq null) default(key)
6465
else result.value
6566
}
6667

6768
def get(key: A): Option[B] = {
6869
val e = findEntry(key)
69-
if (e == null) None
70+
if (e eq null) None
7071
else Some(e.value)
7172
}
7273

7374
override def put(key: A, value: B): Option[B] = {
74-
val e = findEntry(key)
75-
if (e == null) { addEntry(new Entry(key, value)); None }
75+
val e = findOrAddEntry(key, value)
76+
if (e eq null) None
7677
else { val v = e.value; e.value = value; Some(v) }
7778
}
7879

@@ -85,9 +86,8 @@ extends AbstractMap[A, B]
8586
}
8687

8788
def += (kv: (A, B)): this.type = {
88-
val e = findEntry(kv._1)
89-
if (e == null) addEntry(new Entry(kv._1, kv._2))
90-
else e.value = kv._2
89+
val e = findOrAddEntry(kv._1, kv._2)
90+
if (e ne null) e.value = kv._2
9191
this
9292
}
9393

@@ -127,12 +127,19 @@ extends AbstractMap[A, B]
127127
if (!isSizeMapDefined) sizeMapInitAndRebuild
128128
} else sizeMapDisable
129129

130+
protected def createNewEntry[B1](key: A, value: B1): Entry = {
131+
new Entry(key, value.asInstanceOf[B])
132+
}
133+
130134
private def writeObject(out: java.io.ObjectOutputStream) {
131-
serializeTo(out, _.value)
135+
serializeTo(out, { entry =>
136+
out.writeObject(entry.key)
137+
out.writeObject(entry.value)
138+
})
132139
}
133140

134141
private def readObject(in: java.io.ObjectInputStream) {
135-
init[B](in, new Entry(_, _))
142+
init(in, createNewEntry(in.readObject().asInstanceOf[A], in.readObject()))
136143
}
137144

138145
}

src/library/scala/collection/mutable/HashSet.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ extends AbstractSet[A]
5353

5454
override def companion: GenericCompanion[HashSet] = HashSet
5555

56-
override def size = tableSize
56+
override def size: Int = tableSize
5757

5858
def contains(elem: A): Boolean = containsEntry(elem)
5959

@@ -67,7 +67,9 @@ extends AbstractSet[A]
6767

6868
override def remove(elem: A): Boolean = removeEntry(elem).isDefined
6969

70-
override def clear() = clearTable()
70+
override def clear() { clearTable() }
71+
72+
override def iterator: Iterator[A] = super[FlatHashTable].iterator
7173

7274
override def foreach[U](f: A => U) {
7375
var i = 0

src/library/scala/collection/mutable/HashTable.scala

+40-18
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ package mutable
3232
* @tparam A type of the elements contained in this hash table.
3333
*/
3434
trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashUtils[A] {
35+
// Replacing Entry type parameter by abstract type member here allows to not expose to public
36+
// implementation-specific entry classes such as `DefaultEntry` or `LinkedEntry`.
37+
// However, I'm afraid it's too late now for such breaking change.
3538
import HashTable._
3639

3740
@transient protected var _loadFactor = defaultLoadFactor
@@ -52,7 +55,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
5255
*/
5356
@transient protected var sizemap: Array[Int] = null
5457

55-
@transient var seedvalue: Int = tableSizeSeed
58+
@transient protected var seedvalue: Int = tableSizeSeed
5659

5760
protected def tableSizeSeed = Integer.bitCount(table.length - 1)
5861

@@ -75,11 +78,10 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
7578
}
7679

7780
/**
78-
* Initializes the collection from the input stream. `f` will be called for each key/value pair
79-
* read from the input stream in the order determined by the stream. This is useful for
80-
* structures where iteration order is important (e.g. LinkedHashMap).
81+
* Initializes the collection from the input stream. `readEntry` will be called for each
82+
* entry to be read from the input stream.
8183
*/
82-
private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) {
84+
private[collection] def init(in: java.io.ObjectInputStream, readEntry: => Entry) {
8385
in.defaultReadObject
8486

8587
_loadFactor = in.readInt()
@@ -100,35 +102,34 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
100102

101103
var index = 0
102104
while (index < size) {
103-
addEntry(f(in.readObject().asInstanceOf[A], in.readObject().asInstanceOf[B]))
105+
addEntry(readEntry)
104106
index += 1
105107
}
106108
}
107109

108110
/**
109111
* Serializes the collection to the output stream by saving the load factor, collection
110-
* size, collection keys and collection values. `value` is responsible for providing a value
111-
* from an entry.
112+
* size and collection entries. `writeEntry` is responsible for writing an entry to the stream.
112113
*
113-
* `foreach` determines the order in which the key/value pairs are saved to the stream. To
114+
* `foreachEntry` determines the order in which the key/value pairs are saved to the stream. To
114115
* deserialize, `init` should be used.
115116
*/
116-
private[collection] def serializeTo[B](out: java.io.ObjectOutputStream, value: Entry => B) {
117+
private[collection] def serializeTo(out: java.io.ObjectOutputStream, writeEntry: Entry => Unit) {
117118
out.defaultWriteObject
118119
out.writeInt(_loadFactor)
119120
out.writeInt(tableSize)
120121
out.writeInt(seedvalue)
121122
out.writeBoolean(isSizeMapDefined)
122-
foreachEntry { entry =>
123-
out.writeObject(entry.key)
124-
out.writeObject(value(entry))
125-
}
123+
124+
foreachEntry(writeEntry)
126125
}
127126

128127
/** Find entry with given key in table, null if not found.
129128
*/
130-
protected def findEntry(key: A): Entry = {
131-
val h = index(elemHashCode(key))
129+
protected def findEntry(key: A): Entry =
130+
findEntry0(key, index(elemHashCode(key)))
131+
132+
private[this] def findEntry0(key: A, h: Int): Entry = {
132133
var e = table(h).asInstanceOf[Entry]
133134
while (e != null && !elemEquals(e.key, key)) e = e.next
134135
e
@@ -138,7 +139,10 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
138139
* pre: no entry with same key exists
139140
*/
140141
protected def addEntry(e: Entry) {
141-
val h = index(elemHashCode(e.key))
142+
addEntry0(e, index(elemHashCode(e.key)))
143+
}
144+
145+
private[this] def addEntry0(e: Entry, h: Int) {
142146
e.next = table(h).asInstanceOf[Entry]
143147
table(h) = e
144148
tableSize = tableSize + 1
@@ -147,6 +151,24 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
147151
resize(2 * table.length)
148152
}
149153

154+
/** Find entry with given key in table, or add new one if not found.
155+
* May be somewhat faster then `findEntry`/`addEntry` pair as it
156+
* computes entry's hash index only once.
157+
* Returns entry found in table or null.
158+
* New entries are created by calling `createNewEntry` method.
159+
*/
160+
protected def findOrAddEntry[B](key: A, value: B): Entry = {
161+
val h = index(elemHashCode(key))
162+
val e = findEntry0(key, h)
163+
if (e ne null) e else { addEntry0(createNewEntry(key, value), h); null }
164+
}
165+
166+
/** Creates new entry to be immediately inserted into the hashtable.
167+
* This method is guaranteed to be called only once and in case that the entry
168+
* will be added. In other words, an implementation may be side-effecting.
169+
*/
170+
protected def createNewEntry[B](key: A, value: B): Entry
171+
150172
/** Remove entry from table if present.
151173
*/
152174
protected def removeEntry(key: A) : Entry = {
@@ -195,7 +217,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
195217
}
196218

197219
/** Avoid iterator for a 2x faster traversal. */
198-
protected def foreachEntry[C](f: Entry => C) {
220+
protected def foreachEntry[U](f: Entry => U) {
199221
val iterTable = table
200222
var idx = lastPopulatedIndex
201223
var es = iterTable(idx)

src/library/scala/collection/mutable/LinkedHashMap.scala

+18-25
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,9 @@ class LinkedHashMap[A, B] extends AbstractMap[A, B]
6767
}
6868

6969
override def put(key: A, value: B): Option[B] = {
70-
val e = findEntry(key)
71-
if (e == null) {
72-
val e = new Entry(key, value)
73-
addEntry(e)
74-
updateLinkedEntries(e)
75-
None
76-
} else {
77-
val v = e.value
78-
e.value = value
79-
Some(v)
80-
}
81-
}
82-
83-
private def updateLinkedEntries(e: Entry) {
84-
if (firstEntry == null) firstEntry = e
85-
else { lastEntry.later = e; e.earlier = lastEntry }
86-
lastEntry = e
70+
val e = findOrAddEntry(key, value)
71+
if (e eq null) None
72+
else { val v = e.value; e.value = value; Some(v) }
8773
}
8874

8975
override def remove(key: A): Option[B] = {
@@ -143,38 +129,45 @@ class LinkedHashMap[A, B] extends AbstractMap[A, B]
143129
else Iterator.empty.next
144130
}
145131

146-
override def foreach[U](f: ((A, B)) => U) = {
132+
override def foreach[U](f: ((A, B)) => U) {
147133
var cur = firstEntry
148134
while (cur ne null) {
149135
f((cur.key, cur.value))
150136
cur = cur.later
151137
}
152138
}
153139

154-
protected override def foreachEntry[C](f: Entry => C) {
140+
protected override def foreachEntry[U](f: Entry => U) {
155141
var cur = firstEntry
156142
while (cur ne null) {
157143
f(cur)
158144
cur = cur.later
159145
}
160146
}
161147

148+
protected def createNewEntry[B1](key: A, value: B1): Entry = {
149+
val e = new Entry(key, value.asInstanceOf[B])
150+
if (firstEntry eq null) firstEntry = e
151+
else { lastEntry.later = e; e.earlier = lastEntry }
152+
lastEntry = e
153+
e
154+
}
155+
162156
override def clear() {
163157
clearTable()
164158
firstEntry = null
165159
}
166160

167161
private def writeObject(out: java.io.ObjectOutputStream) {
168-
serializeTo(out, _.value)
162+
serializeTo(out, { entry =>
163+
out.writeObject(entry.key)
164+
out.writeObject(entry.value)
165+
})
169166
}
170167

171168
private def readObject(in: java.io.ObjectInputStream) {
172169
firstEntry = null
173170
lastEntry = null
174-
init[B](in, { (key, value) =>
175-
val entry = new Entry(key, value)
176-
updateLinkedEntries(entry)
177-
entry
178-
})
171+
init(in, createNewEntry(in.readObject().asInstanceOf[A], in.readObject()))
179172
}
180173
}

0 commit comments

Comments
 (0)