Commit 03c8d50

Dylan Wong authored and HeartSaVioR committed
[SPARK-52989][SS][4.0] Add explicit close() API to State Store iterators
### What changes were proposed in this pull request?

Backport of #51701. Add an explicit `close()` API to State Store iterators. This PR changes the `ReadStateStore` trait's `prefixScan` and `iterator` methods to return `StateStoreIterator[UnsafeRowPair]` instead of `Iterator[UnsafeRowPair]`; the new type adds a `close()` method. The `exists()` method of `MapStateImpl` is also changed to close the iterator explicitly once it is no longer needed. Additionally, `close()` calls are added in `TimerStateImpl` and `MapStateImplWithTTL`, in their iterators that consume the state store iterators.

### Why are the changes needed?

These changes expose the `close()` method on state store iterators, allowing users of `StateStoreIterator` to explicitly close it and release its underlying resources when it is no longer needed. This avoids having to hold on to an iterator until all rows are consumed and `close()` is called, or until the task completion/failure listener calls `close()` on it.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing unit tests, tests for the wrapper `StateStoreIterator` class, and a new test verifying that `close()` closes the underlying RocksDB iterator.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51863 from dylanwong250/SPARK-52989-backport.

Authored-by: Dylan Wong <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent a4ed4c8 commit 03c8d50
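For context, the caller-side pattern this enables looks roughly like the following. This is a minimal sketch, not code from the patch; `store`, `prefixKey`, and `stateName` stand in for whatever the calling operator has in scope:

```scala
// Minimal sketch of the new usage pattern (identifiers are illustrative).
// Before this change, a partially consumed iterator kept its underlying
// RocksDB resources open until the task completion/failure listener fired.
val iter = store.prefixScan(prefixKey, stateName)
val exists =
  try iter.nonEmpty        // only needs hasNext; rows are never drained
  finally iter.close()     // releases the underlying RocksDB iterator now
```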

File tree

11 files changed: +202 −36 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -56,7 +56,10 @@ class MapStateImpl[K, V](
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
-    store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty
+    val iter = store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName)
+    val result = iter.nonEmpty
+    iter.close()
+    result
   }
 
   /** Get the state value if it exists */
```
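Worth noting about the hunk above: `iter.nonEmpty` only calls `hasNext`, so before this patch the RocksDB iterator backing the prefix scan stayed open until the task-completion listener ran. A slightly more defensive variant (a sketch, not what the patch does) would close the iterator even when `nonEmpty` itself throws:

```scala
// Sketch only; same members (`store`, `stateTypesEncoder`, `stateName`) as in
// MapStateImpl above, but with close() guaranteed via try/finally.
override def exists(): Boolean = {
  val iter = store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName)
  try iter.nonEmpty finally iter.close()
}
```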

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -128,7 +128,9 @@ metrics: Map[String, SQLMetric])
       }
     }
 
-    override protected def close(): Unit = {}
+    override protected def close(): Unit = {
+      unsafeRowPairIterator.close()
+    }
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -199,7 +199,9 @@ class TimerStateImpl(
         }
       }
 
-      override protected def close(): Unit = { }
+      override protected def close(): Unit = {
+        iter.close()
+      }
     }
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 16 additions & 10 deletions
```diff
@@ -83,8 +83,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = map.get(key)
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-    map.iterator()
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter = map.iterator()
+    new StateStoreIterator(iter)
   }
 
   override def abort(): Unit = {}
@@ -93,9 +94,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     s"HDFSReadStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
-    map.prefixScan(prefixKey)
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter = map.prefixScan(prefixKey)
+    new StateStoreIterator(iter)
   }
 
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
@@ -198,15 +201,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
    * Get an iterator of all the store data.
    * This can be called only after committing all the updates made in the current thread.
    */
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     assertUseOfDefaultColFamily(colFamilyName)
-    mapToUpdate.iterator()
+    val iter = mapToUpdate.iterator()
+    new StateStoreIterator(iter)
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     assertUseOfDefaultColFamily(colFamilyName)
-    mapToUpdate.prefixScan(prefixKey)
+    val iter = mapToUpdate.prefixScan(prefixKey)
+    new StateStoreIterator(iter)
   }
 
   override def metrics: StateStoreMetrics = {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -860,7 +860,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): Iterator[ByteArrayPair] = {
+  def iterator(): NextIterator[ByteArrayPair] = {
     val iter = db.newIterator()
     logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, loadedVersion)}")
     iter.seekToFirst()
@@ -896,7 +896,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs for the given column family.
    */
-  def iterator(cfName: String): Iterator[ByteArrayPair] = {
+  def iterator(cfName: String): NextIterator[ByteArrayPair] = {
     if (!useColumnFamilies) {
       iterator()
     } else {
@@ -945,7 +945,7 @@ class RocksDB(
 
   def prefixScan(
       prefix: Array[Byte],
-      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ByteArrayPair] = {
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = {
     val iter = db.newIterator()
     val updatedPrefix = if (useColumnFamilies) {
       encodeStateRowWithPrefix(prefix, cfName)
```
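Returning `NextIterator` rather than a plain `Iterator` matters because Spark's `org.apache.spark.util.NextIterator` exposes an idempotent `closeIfNeeded()`. The sketch below is a simplified stand-in for that contract (the real class has more error handling; this is not its source), showing why `closeIfNeeded` is safe to hand out as a close callback:

```scala
// Simplified, self-contained sketch of the NextIterator contract; NOT the
// actual org.apache.spark.util.NextIterator code.
abstract class MiniNextIterator[U] extends Iterator[U] {
  private var closed = false
  private var gotNext = false
  private var nextValue: U = _
  protected var finished = false

  protected def getNext(): U   // must set `finished = true` at end of data
  protected def close(): Unit  // releases underlying resources (e.g. a JNI iterator)

  // Idempotent: callable from a task-completion listener AND an explicit
  // StateStoreIterator.close() without double-freeing.
  def closeIfNeeded(): Unit = if (!closed) { closed = true; close() }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```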

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 19 additions & 7 deletions
```diff
@@ -179,16 +179,17 @@ private[sql] class RocksDBStateStoreProvider
     rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
   }
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     // Note this verify function only verify on the colFamilyName being valid,
     // we are actually doing prefix when useColumnFamilies,
     // but pass "iterator" to throw correct error message
     verifyColFamilyOperations("iterator", colFamilyName)
     val kvEncoder = keyValueEncoderMap.get(colFamilyName)
     val rowPair = new UnsafeRowPair()
-
     if (useColumnFamilies) {
-      rocksDB.iterator(colFamilyName).map { kv =>
+      val rocksDbIter = rocksDB.iterator(colFamilyName)
+
+      val iter = rocksDbIter.map { kv =>
         rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
           kvEncoder._2.decodeValue(kv.value))
         if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -198,8 +199,12 @@ private[sql] class RocksDBStateStoreProvider
         }
         rowPair
       }
+
+      new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
     } else {
-      rocksDB.iterator().map { kv =>
+      val rocksDbIter = rocksDB.iterator()
+
+      val iter = rocksDbIter.map { kv =>
         rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
           kvEncoder._2.decodeValue(kv.value))
         if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -209,11 +214,14 @@ private[sql] class RocksDBStateStoreProvider
         }
         rowPair
       }
+
+      new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
     }
   }
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-    Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     verifyColFamilyOperations("prefixScan", colFamilyName)
 
     val kvEncoder = keyValueEncoderMap.get(colFamilyName)
@@ -222,11 +230,15 @@ private[sql] class RocksDBStateStoreProvider
 
     val rowPair = new UnsafeRowPair()
     val prefix = kvEncoder._1.encodePrefixKey(prefixKey)
-    rocksDB.prefixScan(prefix, colFamilyName).map { kv =>
+
+    val rocksDbIter = rocksDB.prefixScan(prefix, colFamilyName)
+    val iter = rocksDbIter.map { kv =>
       rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
         kvEncoder._2.decodeValue(kv.value))
       rowPair
     }
+
+    new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
   }
 
   var checkpointInfo: Option[StateStoreCheckpointInfo] = None
```
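The reason each branch captures `rocksDbIter` before mapping: `Iterator.map` returns a plain `Iterator`, through which the source's close hook is no longer reachable, so the provider re-attaches it by passing the method value `rocksDbIter.closeIfNeeded` as the wrapper's `onClose`. A self-contained sketch of that pattern, with all names illustrative:

```scala
// Self-contained sketch: mapping hides the source's close hook, so it is
// forwarded explicitly. `FakeNativeIter` stands in for RocksDB's NextIterator.
import java.io.Closeable

object CloseForwardingSketch extends App {
  class FakeNativeIter extends Iterator[Int] {
    private val data = Iterator(10, 20, 30)
    var closed = false
    def closeIfNeeded(): Unit = if (!closed) closed = true // idempotent
    override def hasNext: Boolean = data.hasNext
    override def next(): Int = data.next()
  }

  class ClosingIter[A](iter: Iterator[A], onClose: () => Unit)
      extends Iterator[A] with Closeable {
    override def hasNext: Boolean = iter.hasNext
    override def next(): A = iter.next()
    override def close(): Unit = onClose()
  }

  val native = new FakeNativeIter
  val decoded = native.map(_ / 10)       // plain Iterator[Int]: no close() here
  val iter = new ClosingIter(decoded, native.closeIfNeeded) // hook re-attached
  println(iter.next())                   // 1
  iter.close()                           // reaches FakeNativeIter via onClose
  println(native.closed)                 // true
}
```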

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 29 additions & 8 deletions
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.io.Closeable
 import java.util.UUID
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
@@ -41,6 +42,25 @@ import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, Stre
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
+/**
+ * Represents an iterator that provides additional functionalities for state store use cases.
+ *
+ * `close()` is useful for freeing underlying iterator resources when the iterator is no longer
+ * needed.
+ *
+ * The caller MUST call `close()` on the iterator if it was not fully consumed, and it is no
+ * longer needed.
+ */
+class StateStoreIterator[A](
+    val iter: Iterator[A],
+    val onClose: () => Unit = () => {}) extends Iterator[A] with Closeable {
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): A = iter.next()
+
+  override def close(): Unit = onClose()
+}
+
 sealed trait StateStoreEncoding {
   override def toString: String = this match {
     case StateStoreEncoding.UnsafeRow => "unsaferow"
@@ -106,10 +126,11 @@ trait ReadStateStore {
    */
   def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
 
   /** Return an iterator containing all the key-value pairs in the StateStore. */
-  def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair]
+  def iterator(
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
 
   /**
    * Clean up the resource.
@@ -196,8 +217,8 @@ trait StateStore extends ReadStateStore {
    * performed after initialization of the iterator. Callers should perform all updates before
    * calling this method if all updates should be visible in the returned iterator.
    */
-  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair]
+  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair]
 
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
@@ -229,14 +250,14 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = store.get(key,
     colFamilyName)
 
-  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair] = store.iterator(colFamilyName)
+  override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.iterator(colFamilyName)
 
   override def abort(): Unit = store.abort()
 
   override def prefixScan(prefixKey: UnsafeRow,
-    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
-    store.prefixScan(prefixKey, colFamilyName)
+    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.prefixScan(prefixKey, colFamilyName)
 
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
     store.valuesIterator(key, colFamilyName)
```
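Since `StateStoreIterator` implements `java.io.Closeable`, scope-based cleanup helpers compose with it. A minimal usage sketch, assuming only the class definition shown in the diff above (the data and the `released` flag are illustrative):

```scala
// Hypothetical usage; relies only on StateStoreIterator as defined above.
import scala.util.Using

var released = false
val iter = new StateStoreIterator(Iterator("a", "b", "c"), () => released = true)

// Using.resource closes the iterator on exit, even if the body throws.
val first = Using.resource(iter) { it => it.next() }
assert(first == "a" && released)
```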

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala

Lines changed: 7 additions & 3 deletions
```diff
@@ -26,8 +26,10 @@ class MemoryStateStore extends StateStore() {
   import scala.jdk.CollectionConverters._
   private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-    map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
+  override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+    val iter =
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
+    new StateStoreIterator(iter)
   }
 
   override def createColFamilyIfAbsent(
@@ -63,7 +65,9 @@ class MemoryStateStore extends StateStore() {
 
   override def hasCommitted: Boolean = true
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -76,12 +76,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
 
   override def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = {
     innerStore.prefixScan(prefixKey, colFamilyName)
   }
 
   override def iterator(
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = {
     innerStore.iterator(colFamilyName)
   }
 
```
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala

Lines changed: 74 additions & 0 deletions
```diff
@@ -1610,6 +1610,80 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamiliesAndEncodingTypes(
+    "closing the iterator also closes the underlying rocksdb iterator",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
+
+    // use the same schema as value schema for single col key schema
+    tryWithProviderResource(newStoreProvider(valueSchema,
+      RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider =>
+      val store = provider.getStore(0)
+      try {
+        val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
+        if (colFamiliesEnabled) {
+          store.createColFamilyIfAbsent(cfName,
+            valueSchema, valueSchema,
+            RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)))
+        }
+
+        val timerTimestamps = Seq(1, 2, 3, 22)
+        timerTimestamps.foreach { ts =>
+          val keyRow = dataToValueRow(ts)
+          val valueRow = dataToValueRow(1)
+          store.put(keyRow, valueRow, cfName)
+          assert(valueRowToData(store.get(keyRow, cfName)) === 1)
+        }
+
+        val iter1 = store.iterator(cfName)
+        for (i <- 1 to 4) {
+          assert(iter1.hasNext)
+          iter1.next()
+        }
+        // We were fully able to process the 4 elements
+        assert(!iter1.hasNext)
+
+        val iter2 = store.iterator(cfName)
+        for (i <- 1 to 2) {
+          assert(iter2.hasNext)
+          iter2.next()
+        }
+        // Close the iterator
+        iter2.close()
+        // After closing, this will call AbstractRocksIterator.isValid which should throw an
+        // exception since it no longer owns the underlying rocksdb iterator
+        val exception1 = intercept[AssertionError] {
+          iter2.next()
+        }
+        // Check that the exception is thrown from AbstractRocksIterator.isValid
+        assert(exception1.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception1.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        // also check for prefix scan
+        val prefix = dataToValueRow(2)
+        val iter3 = store.prefixScan(prefix, cfName)
+
+        iter3.next()
+        assert(!iter3.hasNext)
+
+        val iter4 = store.prefixScan(prefix, cfName)
+        // Immediately close the iterator without calling next
+        iter4.close()
+
+        // Since we closed the iterator, this will throw an exception when we try to call next
+        val exception2 = intercept[AssertionError] {
+          iter4.next()
+        }
+        // Check that the exception is thrown from AbstractRocksIterator.isValid
+        assert(exception2.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception2.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        store.commit()
+      } finally {
+        if (!store.hasCommitted) store.abort()
+      }
+    }
+  }
+
   test("validate rocksdb values iterator correctness") {
     withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
       tryWithProviderResource(newStoreProvider(useColumnFamilies = true,
```
