-
Notifications
You must be signed in to change notification settings - Fork 28.8k
[SPARK-53001] Integrate RocksDB Memory Usage with the Unified Memory Manager #51708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,19 @@ | |
|
||
package org.apache.spark.memory | ||
|
||
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, TimeUnit} | ||
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} | ||
|
||
import scala.jdk.CollectionConverters._ | ||
import scala.util.control.NonFatal | ||
|
||
import org.apache.spark.{SparkConf, SparkIllegalArgumentException} | ||
import org.apache.spark.internal.{config, MDC} | ||
import org.apache.spark.internal.{config, Logging, LogKeys, MDC} | ||
import org.apache.spark.internal.LogKeys._ | ||
import org.apache.spark.internal.config.Tests._ | ||
import org.apache.spark.internal.config.UNMANAGED_MEMORY_POLLING_INTERVAL | ||
import org.apache.spark.storage.BlockId | ||
import org.apache.spark.util.{ThreadUtils, Utils} | ||
|
||
/** | ||
* A [[MemoryManager]] that enforces a soft boundary between execution and storage such that | ||
|
@@ -56,7 +64,47 @@ private[spark] class UnifiedMemoryManager( | |
conf, | ||
numCores, | ||
onHeapStorageRegionSize, | ||
maxHeapMemory - onHeapStorageRegionSize) { | ||
maxHeapMemory - onHeapStorageRegionSize) with Logging { | ||
|
||
/**
 * Unmanaged memory tracking infrastructure.
 *
 * Unmanaged memory refers to memory consumed by components that manage their own memory
 * outside of Spark's unified memory management system. Examples include:
 * - RocksDB state stores used in structured streaming
 * - Native libraries with their own memory management
 * - Off-heap caches managed outside of Spark
 *
 * We track this memory to:
 * 1. Provide visibility into total memory usage on executors
 * 2. Prevent OOM errors by accounting for it in memory allocation decisions
 * 3. Enable better debugging and monitoring of memory-intensive applications
 *
 * The polling mechanism periodically queries registered unmanaged memory consumers
 * to detect inactive consumers and handle cleanup.
 */
// Configuration for polling interval (in milliseconds). A non-positive value disables
// polling, in which case unmanaged memory is not factored into allocation decisions.
private val unmanagedMemoryPollingIntervalMs = conf.get(UNMANAGED_MEMORY_POLLING_INTERVAL)
// Initialize background polling if enabled. The poller lives on the companion object, so
// constructing several managers in one JVM still results in a single polling thread.
if (unmanagedMemoryPollingIntervalMs > 0) {
  UnifiedMemoryManager.startPollingIfNeeded(unmanagedMemoryPollingIntervalMs)
}
|
||
/**
 * Look up the cached unmanaged memory usage, in bytes, for one memory mode.
 *
 * The value is read from the counters kept on the companion object. When polling is
 * disabled (non-positive polling interval) unmanaged memory is ignored and 0 is reported.
 *
 * @param memoryMode The memory mode (ON_HEAP or OFF_HEAP) to get usage for
 * @return The cached unmanaged memory usage in bytes, or 0 when polling is disabled
 */
private def getUnmanagedMemoryUsed(memoryMode: MemoryMode): Long = {
  val pollingEnabled = unmanagedMemoryPollingIntervalMs > 0
  if (pollingEnabled) {
    memoryMode match {
      case MemoryMode.ON_HEAP => UnifiedMemoryManager.unmanagedOnHeapUsed.get()
      case MemoryMode.OFF_HEAP => UnifiedMemoryManager.unmanagedOffHeapUsed.get()
    }
  } else {
    // Unmanaged memory is only considered while polling is enabled
    0L
  }
}
|
||
private def assertInvariants(): Unit = { | ||
assert(onHeapExecutionMemoryPool.poolSize + onHeapStorageMemoryPool.poolSize == maxHeapMemory) | ||
|
@@ -140,9 +188,15 @@ private[spark] class UnifiedMemoryManager( | |
* in execution memory allocation across tasks, Otherwise, a task may occupy more than | ||
* its fair share of execution memory, mistakenly thinking that other tasks can acquire | ||
* the portion of storage memory that cannot be evicted. | ||
* | ||
* This also factors in unmanaged memory usage to ensure we don't over-allocate memory | ||
* when unmanaged components are consuming significant memory. | ||
*/ | ||
def computeMaxExecutionPoolSize(): Long = {
  // Memory currently claimed by unmanaged consumers (e.g. RocksDB) for this memory mode
  val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode)
  // Execution may grow into everything except the storage memory it cannot evict
  val availableMemory = maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
  // Reduce available memory by unmanaged memory usage to prevent over-allocation;
  // clamp at zero so the pool size never goes negative
  math.max(0L, availableMemory - unmanagedMemory)
}
|
||
executionPool.acquireMemory( | ||
|
@@ -165,11 +219,21 @@ private[spark] class UnifiedMemoryManager( | |
offHeapStorageMemoryPool, | ||
maxOffHeapStorageMemory) | ||
} | ||
if (numBytes > maxMemory) { | ||
|
||
// Factor in unmanaged memory usage for the specific memory mode | ||
val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode) | ||
val effectiveMaxMemory = math.max(0L, maxMemory - unmanagedMemory) | ||
|
||
if (numBytes > effectiveMaxMemory) { | ||
// Fail fast if the block simply won't fit | ||
logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required space" + | ||
log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" + | ||
log" memory limit (${MDC(NUM_BYTES_MAX, maxMemory)} bytes)") | ||
log" memory limit (${MDC(NUM_BYTES_MAX, effectiveMaxMemory)} bytes)" + | ||
(if (unmanagedMemory > 0) { | ||
log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanagedMemory)} bytes)" | ||
} else { | ||
log"" | ||
})) | ||
return false | ||
} | ||
if (numBytes > storagePool.memoryFree) { | ||
|
@@ -191,14 +255,189 @@ private[spark] class UnifiedMemoryManager( | |
} | ||
} | ||
|
||
object UnifiedMemoryManager { | ||
object UnifiedMemoryManager extends Logging { | ||
|
||
// Set aside a fixed amount of memory for non-storage, non-execution purposes.
// This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve
// sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then
// the memory used for execution and storage will be (1024 - 300) * 0.6 = 434MB by default.
private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024

// Registry of unmanaged memory consumers, keyed by consumer id. Registering a consumer
// under an id that is already present replaces the previous entry.
private val unmanagedMemoryConsumers =
  new ConcurrentHashMap[UnmanagedMemoryConsumerId, UnmanagedMemoryConsumer]

// Cached unmanaged memory usage values updated by polling.
// They are reset to 0 by clearUnmanagedMemoryUsers().
private val unmanagedOnHeapUsed = new AtomicLong(0L)
private val unmanagedOffHeapUsed = new AtomicLong(0L)

// Atomic flag to ensure polling is only started once per JVM.
// It is set back to false when the poller is shut down.
private val pollingStarted = new AtomicBoolean(false)
|
||
/**
 * Add an unmanaged memory consumer to the set of consumers tracked by this JVM.
 *
 * Unmanaged memory consumers are components that manage their own memory outside
 * of Spark's unified memory management system. Once registered, a consumer is
 * periodically polled and its usage is factored into Spark's memory allocation decisions.
 * The consumer is stored under its `unmanagedMemoryConsumerId`; registering another
 * consumer with the same id replaces the earlier one.
 *
 * @param unmanagedMemoryConsumer The consumer to register for memory tracking
 */
def registerUnmanagedMemoryConsumer(
    unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
  unmanagedMemoryConsumers.put(
    unmanagedMemoryConsumer.unmanagedMemoryConsumerId,
    unmanagedMemoryConsumer)
}
|
||
/**
 * Remove an unmanaged memory consumer from tracking.
 *
 * The entry stored under the consumer's `unmanagedMemoryConsumerId` is dropped, so the
 * consumer is no longer polled. Calling this when a component shuts down prevents
 * stale registrations and keeps memory tracking accurate. Only used in tests.
 *
 * @param unmanagedMemoryConsumer The consumer to unregister
 */
private[spark] def unregisterUnmanagedMemoryConsumer(
    unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
  unmanagedMemoryConsumers.remove(unmanagedMemoryConsumer.unmanagedMemoryConsumerId)
}
|
||
|
||
/**
 * Get the current memory usage in bytes for a specific component type.
 *
 * Every registered consumer whose id carries the given component type is queried
 * directly (this does not use the cached polling totals). A consumer that fails to
 * report, or that reports a negative value (including the -1L "no longer active"
 * sentinel), contributes 0 so that one bad consumer cannot skew or break the total.
 *
 * @param componentType The type of component to filter by (e.g., "RocksDB")
 * @return Total memory usage in bytes for the specified component type
 */
def getMemoryByComponentType(componentType: String): Long = {
  unmanagedMemoryConsumers.asScala.values.iterator
    .filter(_.unmanagedMemoryConsumerId.componentType == componentType)
    .map { memoryUser =>
      try {
        // Negative readings are sentinels or invalid values, never real usage
        math.max(0L, memoryUser.getMemBytesUsed)
      } catch {
        case NonFatal(e) =>
          logWarning(log"Failed to get memory usage for unmanaged memory user " +
            log"${MDC(LogKeys.OBJECT_ID, memoryUser.unmanagedMemoryConsumerId.toString)} " +
            log"${MDC(LogKeys.EXCEPTION, e)}")
          0L
      }
    }
    .sum
}
|
||
/**
 * Drop every registered unmanaged memory consumer and zero the cached usage totals.
 *
 * This is useful during executor shutdown or cleanup. Since each executor runs in its
 * own JVM, this clears all users for this executor.
 */
def clearUnmanagedMemoryUsers(): Unit = {
  unmanagedMemoryConsumers.clear()
  // With no consumers left, the cached on-heap and off-heap totals must read as empty
  Seq(unmanagedOnHeapUsed, unmanagedOffHeapUsed).foreach(_.set(0L))
}
|
||
// Shared polling infrastructure - only one polling thread per JVM.
// Null while no poller is running.
@volatile private var unmanagedMemoryPoller: ScheduledExecutorService = _

/**
 * Start unmanaged memory polling if not already started.
 * This ensures only one polling thread is created per JVM, regardless of how many
 * UnifiedMemoryManager instances are created.
 *
 * Start-up runs under the same monitor as shutdownUnmanagedMemoryPoller(), so a
 * concurrent shutdown cannot observe a half-initialized poller. If the poller cannot be
 * created or scheduled, the started flag is rolled back so a later call can retry, and
 * the failure is rethrown to the caller.
 *
 * @param pollingIntervalMs Interval between polls in milliseconds; callers pass a
 *                          positive value
 */
private[memory] def startPollingIfNeeded(pollingIntervalMs: Long): Unit = {
  // The CAS keeps the common "already started" path lock-free
  if (pollingStarted.compareAndSet(false, true)) {
    synchronized {
      try {
        unmanagedMemoryPoller = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
          "unmanaged-memory-poller")

        val pollingTask = new Runnable {
          // A failing poll is logged rather than thrown, so it does not cancel later runs
          override def run(): Unit = Utils.tryLogNonFatalError {
            pollUnmanagedMemoryUsers()
          }
        }

        unmanagedMemoryPoller.scheduleAtFixedRate(
          pollingTask,
          0L, // initial delay
          pollingIntervalMs,
          TimeUnit.MILLISECONDS)

        logInfo(log"Unmanaged memory polling started with interval " +
          log"${MDC(LogKeys.TIME, pollingIntervalMs)}ms")
      } catch {
        case NonFatal(e) =>
          // Roll back so the flag does not claim a poller that is not running
          if (unmanagedMemoryPoller != null) {
            unmanagedMemoryPoller.shutdownNow()
            unmanagedMemoryPoller = null
          }
          pollingStarted.set(false)
          throw e
      }
    }
  }
}
|
||
/**
 * Poll every registered unmanaged memory consumer once and refresh the cached totals.
 *
 * For each consumer in a snapshot of the registry:
 *  - a reading of -1L marks the consumer as inactive and it is removed from the registry;
 *  - any other negative reading is logged and counted as 0;
 *  - a non-fatal exception is logged and counted as 0.
 *
 * Removal is conditional on the registry still mapping the id to the consumer that was
 * polled, so a consumer re-registered under the same id after the snapshot was taken is
 * not evicted by a stale reading.
 */
private def pollUnmanagedMemoryUsers(): Unit = {
  val consumers = unmanagedMemoryConsumers.asScala.toSeq

  // Get memory usage for each consumer, handling failures gracefully.
  // Some(bytes) = active consumer, None = consumer reported itself inactive.
  val memoryUsages = consumers.map { case (userId, memoryUser) =>
    val usage: Option[Long] =
      try {
        val memoryUsed = memoryUser.getMemBytesUsed
        if (memoryUsed == -1L) {
          logDebug(log"Unmanaged memory consumer ${MDC(LogKeys.OBJECT_ID, userId.toString)} " +
            log"is no longer active, marking for removal")
          None // Mark for removal
        } else if (memoryUsed < 0L) {
          logWarning(log"Invalid memory usage value ${MDC(LogKeys.NUM_BYTES, memoryUsed)} " +
            log"from unmanaged memory user ${MDC(LogKeys.OBJECT_ID, userId.toString)}")
          Some(0L) // Treat as 0
        } else {
          Some(memoryUsed)
        }
      } catch {
        case NonFatal(e) =>
          logWarning(log"Failed to get memory usage for unmanaged memory user " +
            log"${MDC(LogKeys.OBJECT_ID, userId.toString)} ${MDC(LogKeys.EXCEPTION, e)}")
          Some(0L) // Treat as 0 on error
      }
    (userId, memoryUser, usage)
  }

  // Remove inactive consumers, but only if the registry still holds the polled instance
  memoryUsages.foreach {
    case (userId, memoryUser, None) =>
      if (unmanagedMemoryConsumers.remove(userId, memoryUser)) {
        logInfo(log"Removed inactive unmanaged memory consumer " +
          log"${MDC(LogKeys.OBJECT_ID, userId.toString)}")
      }
    case _ => // still active, nothing to clean up
  }

  // Calculate total memory usage by mode
  val activeUsages = memoryUsages.collect {
    case (_, memoryUser, Some(memoryUsed)) => (memoryUser.memoryMode, memoryUsed)
  }
  val onHeapTotal = activeUsages.collect { case (MemoryMode.ON_HEAP, bytes) => bytes }.sum
  val offHeapTotal = activeUsages.collect { case (MemoryMode.OFF_HEAP, bytes) => bytes }.sum

  // Publish the refreshed totals (each counter is updated atomically on its own)
  unmanagedOnHeapUsed.set(onHeapTotal)
  unmanagedOffHeapUsed.set(offHeapTotal)

  // Log polling results for monitoring
  val totalMemoryUsed = onHeapTotal + offHeapTotal
  val numConsumers = activeUsages.size
  logDebug(s"Unmanaged memory polling completed: $numConsumers consumers, " +
    s"total memory used: ${totalMemoryUsed} bytes " +
    s"(on-heap: ${onHeapTotal}, off-heap: ${offHeapTotal})")
}
|
||
/**
 * Shutdown the unmanaged memory polling thread. Only used in tests.
 *
 * Requests an orderly shutdown and waits up to 5 seconds for the poller to finish; if it
 * does not finish in time, or the waiting thread is interrupted, the poller is stopped
 * forcibly. Afterwards the started flag is reset so polling can be started again.
 * Does nothing when no poller is running.
 */
private[spark] def shutdownUnmanagedMemoryPoller(): Unit = {
  synchronized {
    val poller = unmanagedMemoryPoller
    if (poller != null) {
      poller.shutdown()
      try {
        if (!poller.awaitTermination(5, TimeUnit.SECONDS)) {
          poller.shutdownNow()
        }
      } catch {
        case _: InterruptedException =>
          // Do not leave the poller running just because the wait was cut short
          poller.shutdownNow()
          // Preserve the interrupt status for the caller
          Thread.currentThread().interrupt()
      }
      unmanagedMemoryPoller = null
      pollingStarted.set(false)
      logInfo(log"Unmanaged memory poller shutdown complete")
    }
  }
}
|
||
def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { | ||
val maxMemory = getMaxMemory(conf) | ||
new UnifiedMemoryManager( | ||
|
@@ -242,3 +481,57 @@ object UnifiedMemoryManager { | |
(usableMemory * memoryFraction).toLong | ||
} | ||
} | ||
|
||
/**
 * Key under which an unmanaged memory consumer is registered and tracked.
 *
 * @param componentType The kind of component that owns the memory
 *                      (e.g., "RocksDB", "NativeLibrary")
 * @param instanceKey Key identifying one particular instance of the component. Consumers
 *                    that share the same underlying memory should use a common key so
 *                    that the shared memory is not counted more than once.
 */
case class UnmanagedMemoryConsumerId(componentType: String, instanceKey: String)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indentation, @ericm-db? |
||
|
||
/**
 * Interface for components that consume memory outside of Spark's unified memory management.
 *
 * Components implementing this trait can register themselves with the memory manager
 * to have their memory usage tracked and factored into memory allocation decisions.
 * This helps prevent OOM errors when unmanaged components use significant memory.
 *
 * Examples of unmanaged memory consumers:
 * - RocksDB state stores in structured streaming
 * - Native libraries with custom memory allocation
 * - Off-heap caches managed outside of Spark
 */
trait UnmanagedMemoryConsumer {
  // NOTE(review): a reviewer asked whether this trait should move into a separate file.

  /**
   * Returns the unique identifier for this memory consumer.
   * The identifier is the key under which the consumer is registered, so two consumers
   * returning equal identifiers occupy the same slot in the memory tracking system.
   */
  def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId

  /**
   * Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses.
   * This is used to ensure unmanaged memory usage only affects the correct memory pool.
   */
  def memoryMode: MemoryMode

  /**
   * Returns the current memory usage in bytes.
   *
   * This method is called periodically by the memory polling mechanism to track
   * memory usage over time. Implementations should return the current total memory
   * consumed by this component.
   *
   * @return Current memory usage in bytes. Should return 0 if no memory is currently used.
   *         Return -1L to indicate this consumer is no longer active and should be
   *         automatically removed from tracking. Any other negative value is treated
   *         as invalid by the polling mechanism and counted as 0.
   * @throws Exception if memory usage cannot be determined. The polling mechanism
   *                   will handle exceptions gracefully and log warnings.
   */
  def getMemBytesUsed: Long
}
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be safe and avoid a regression, shall we start with `0` by default, @ericm-db, @anishshri-db, @gatorsmile?