Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,16 @@ package object config {
.doubleConf
.createWithDefault(0.6)

private[spark] val UNMANAGED_MEMORY_POLLING_INTERVAL =
ConfigBuilder("spark.memory.unmanagedMemoryPollingInterval")
.doc("Interval for polling unmanaged memory users to track their memory usage. " +
"Unmanaged memory users are components that manage their own memory outside of " +
"Spark's core memory management, such as RocksDB for Streaming State Store. " +
"Setting this to 0 disables unmanaged memory polling.")
.version("4.1.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("1s")
Copy link
Member

@dongjoon-hyun dongjoon-hyun Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be safe and avoid a regression, shall we start with 0 by default, @ericm-db , @anishshri-db, @gatorsmile ?


private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD =
ConfigBuilder("spark.storage.unrollMemoryThreshold")
.doc("Initial memory to request before unrolling any block")
Expand Down
305 changes: 299 additions & 6 deletions core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@

package org.apache.spark.memory

import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkIllegalArgumentException}
import org.apache.spark.internal.{config, MDC}
import org.apache.spark.internal.{config, Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config.Tests._
import org.apache.spark.internal.config.UNMANAGED_MEMORY_POLLING_INTERVAL
import org.apache.spark.storage.BlockId
import org.apache.spark.util.{ThreadUtils, Utils}

/**
* A [[MemoryManager]] that enforces a soft boundary between execution and storage such that
Expand Down Expand Up @@ -56,7 +64,47 @@ private[spark] class UnifiedMemoryManager(
conf,
numCores,
onHeapStorageRegionSize,
maxHeapMemory - onHeapStorageRegionSize) {
maxHeapMemory - onHeapStorageRegionSize) with Logging {

/**
* Unmanaged memory tracking infrastructure.
*
* Unmanaged memory refers to memory consumed by components that manage their own memory
* outside of Spark's unified memory management system. Examples include:
* - RocksDB state stores used in structured streaming
* - Native libraries with their own memory management
* - Off-heap caches managed by unmanaged systems
*
* We track this memory to:
* 1. Provide visibility into total memory usage on executors
* 2. Prevent OOM errors by accounting for it in memory allocation decisions
* 3. Enable better debugging and monitoring of memory-intensive applications
*
* The polling mechanism periodically queries registered unmanaged memory consumers
* to detect inactive consumers and handle cleanup.
*/
// Configuration for polling interval (in milliseconds)
private val unmanagedMemoryPollingIntervalMs = conf.get(UNMANAGED_MEMORY_POLLING_INTERVAL)
// Initialize background polling if enabled
if (unmanagedMemoryPollingIntervalMs > 0) {
UnifiedMemoryManager.startPollingIfNeeded(unmanagedMemoryPollingIntervalMs)
}

/**
* Get the current unmanaged memory usage in bytes for a specific memory mode.
* @param memoryMode The memory mode (ON_HEAP or OFF_HEAP) to get usage for
* @return The current unmanaged memory usage in bytes
*/
private def getUnmanagedMemoryUsed(memoryMode: MemoryMode): Long = {
// Only consider unmanaged memory if polling is enabled
if (unmanagedMemoryPollingIntervalMs <= 0) {
return 0L
}
memoryMode match {
case MemoryMode.ON_HEAP => UnifiedMemoryManager.unmanagedOnHeapUsed.get()
case MemoryMode.OFF_HEAP => UnifiedMemoryManager.unmanagedOffHeapUsed.get()
}
}

private def assertInvariants(): Unit = {
assert(onHeapExecutionMemoryPool.poolSize + onHeapStorageMemoryPool.poolSize == maxHeapMemory)
Expand Down Expand Up @@ -140,9 +188,15 @@ private[spark] class UnifiedMemoryManager(
* in execution memory allocation across tasks, Otherwise, a task may occupy more than
* its fair share of execution memory, mistakenly thinking that other tasks can acquire
* the portion of storage memory that cannot be evicted.
*
* This also factors in unmanaged memory usage to ensure we don't over-allocate memory
* when unmanaged components are consuming significant memory.
*/
def computeMaxExecutionPoolSize(): Long = {
maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode)
val availableMemory = maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
// Reduce available memory by unmanaged memory usage to prevent over-allocation
math.max(0L, availableMemory - unmanagedMemory)
}

executionPool.acquireMemory(
Expand All @@ -165,11 +219,21 @@ private[spark] class UnifiedMemoryManager(
offHeapStorageMemoryPool,
maxOffHeapStorageMemory)
}
if (numBytes > maxMemory) {

// Factor in unmanaged memory usage for the specific memory mode
val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode)
val effectiveMaxMemory = math.max(0L, maxMemory - unmanagedMemory)

if (numBytes > effectiveMaxMemory) {
// Fail fast if the block simply won't fit
logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required space" +
log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" +
log" memory limit (${MDC(NUM_BYTES_MAX, maxMemory)} bytes)")
log" memory limit (${MDC(NUM_BYTES_MAX, effectiveMaxMemory)} bytes)" +
(if (unmanagedMemory > 0) {
log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanagedMemory)} bytes)"
} else {
log""
}))
return false
}
if (numBytes > storagePool.memoryFree) {
Expand All @@ -191,14 +255,189 @@ private[spark] class UnifiedMemoryManager(
}
}

object UnifiedMemoryManager {
object UnifiedMemoryManager extends Logging {

// Set aside a fixed amount of memory for non-storage, non-execution purposes.
// This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve
// sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then
// the memory used for execution and storage will be (1024 - 300) * 0.6 = 434MB by default.
private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024

private val unmanagedMemoryConsumers =
new ConcurrentHashMap[UnmanagedMemoryConsumerId, UnmanagedMemoryConsumer]

// Cached unmanaged memory usage values updated by polling
private val unmanagedOnHeapUsed = new AtomicLong(0L)
private val unmanagedOffHeapUsed = new AtomicLong(0L)

// Atomic flag to ensure polling is only started once per JVM
private val pollingStarted = new AtomicBoolean(false)

/**
* Register an unmanaged memory consumer to track its memory usage.
*
* Unmanaged memory consumers are components that manage their own memory outside
* of Spark's unified memory management system. By registering, their memory usage
* will be periodically polled and factored into Spark's memory allocation decisions.
*
* @param unmanagedMemoryConsumer The consumer to register for memory tracking
*/
def registerUnmanagedMemoryConsumer(
unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation, @ericm-db ?

val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId
unmanagedMemoryConsumers.put(id, unmanagedMemoryConsumer)
}

/**
* Unregister an unmanaged memory consumer.
* This should be called when a component is shutting down to prevent memory leaks
* and ensure accurate memory tracking.
*
* @param unmanagedMemoryConsumer The consumer to unregister. Only used in tests
*/
private[spark] def unregisterUnmanagedMemoryConsumer(
unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId
unmanagedMemoryConsumers.remove(id)
}


/**
* Get the current memory usage in bytes for a specific component type.
* @param componentType The type of component to filter by (e.g., "RocksDB")
* @return Total memory usage in bytes for the specified component type
*/
def getMemoryByComponentType(componentType: String): Long = {
unmanagedMemoryConsumers.asScala.values.toSeq
.filter(_.unmanagedMemoryConsumerId.componentType == componentType)
.map { memoryUser =>
try {
memoryUser.getMemBytesUsed
} catch {
case e: Exception =>
0L
}
}
.sum
}

/**
* Clear all unmanaged memory users.
* This is useful during executor shutdown or cleanup.
* Since each executor runs in its own JVM, this clears all users for this executor.
*/
def clearUnmanagedMemoryUsers(): Unit = {
unmanagedMemoryConsumers.clear()
// Reset cached values when clearing consumers
unmanagedOnHeapUsed.set(0L)
unmanagedOffHeapUsed.set(0L)
}

// Shared polling infrastructure - only one polling thread per JVM
@volatile private var unmanagedMemoryPoller: ScheduledExecutorService = _

/**
* Start unmanaged memory polling if not already started.
* This ensures only one polling thread is created per JVM, regardless of how many
* UnifiedMemoryManager instances are created.
*/
private[memory] def startPollingIfNeeded(pollingIntervalMs: Long): Unit = {
if (pollingStarted.compareAndSet(false, true)) {
unmanagedMemoryPoller = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"unmanaged-memory-poller")

val pollingTask = new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
pollUnmanagedMemoryUsers()
}
}

unmanagedMemoryPoller.scheduleAtFixedRate(
pollingTask,
0L, // initial delay
pollingIntervalMs,
TimeUnit.MILLISECONDS)

logInfo(log"Unmanaged memory polling started with interval " +
log"${MDC(LogKeys.TIME, pollingIntervalMs)}ms")
}
}

private def pollUnmanagedMemoryUsers(): Unit = {
val consumers = unmanagedMemoryConsumers.asScala.toMap

// Get memory usage for each consumer, handling failures gracefully
val memoryUsages = consumers.map { case (userId, memoryUser) =>
try {
val memoryUsed = memoryUser.getMemBytesUsed
if (memoryUsed == -1L) {
logDebug(log"Unmanaged memory consumer ${MDC(LogKeys.OBJECT_ID, userId.toString)} " +
log"is no longer active, marking for removal")
(userId, memoryUser, None) // Mark for removal
} else if (memoryUsed < 0L) {
logWarning(log"Invalid memory usage value ${MDC(LogKeys.NUM_BYTES, memoryUsed)} " +
log"from unmanaged memory user ${MDC(LogKeys.OBJECT_ID, userId.toString)}")
(userId, memoryUser, Some(0L)) // Treat as 0
} else {
(userId, memoryUser, Some(memoryUsed))
}
} catch {
case NonFatal(e) =>
logWarning(log"Failed to get memory usage for unmanaged memory user " +
log"${MDC(LogKeys.OBJECT_ID, userId.toString)} ${MDC(LogKeys.EXCEPTION, e)}")
(userId, memoryUser, Some(0L)) // Treat as 0 on error
}
}

// Remove inactive consumers
memoryUsages.filter(_._3.isEmpty).foreach { case (userId, _, _) =>
unmanagedMemoryConsumers.remove(userId)
logInfo(log"Removed inactive unmanaged memory consumer " +
log"${MDC(LogKeys.OBJECT_ID, userId.toString)}")
}
// Calculate total memory usage by mode
val activeUsages = memoryUsages.filter(_._3.isDefined)
val onHeapTotal = activeUsages
.filter(_._2.memoryMode == MemoryMode.ON_HEAP)
.map(_._3.get)
.sum
val offHeapTotal = activeUsages
.filter(_._2.memoryMode == MemoryMode.OFF_HEAP)
.map(_._3.get)
.sum
// Update cached values atomically
unmanagedOnHeapUsed.set(onHeapTotal)
unmanagedOffHeapUsed.set(offHeapTotal)
// Log polling results for monitoring
val totalMemoryUsed = onHeapTotal + offHeapTotal
val numConsumers = activeUsages.size
logDebug(s"Unmanaged memory polling completed: $numConsumers consumers, " +
s"total memory used: ${totalMemoryUsed} bytes " +
s"(on-heap: ${onHeapTotal}, off-heap: ${offHeapTotal})")
}

/**
* Shutdown the unmanaged memory polling thread. Only used in tests
*/
private[spark] def shutdownUnmanagedMemoryPoller(): Unit = {
synchronized {
if (unmanagedMemoryPoller != null) {
unmanagedMemoryPoller.shutdown()
try {
if (!unmanagedMemoryPoller.awaitTermination(5, TimeUnit.SECONDS)) {
unmanagedMemoryPoller.shutdownNow()
}
} catch {
case _: InterruptedException =>
Thread.currentThread().interrupt()
}
unmanagedMemoryPoller = null
pollingStarted.set(false)
logInfo(log"Unmanaged memory poller shutdown complete")
}
}
}

def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = {
val maxMemory = getMaxMemory(conf)
new UnifiedMemoryManager(
Expand Down Expand Up @@ -242,3 +481,57 @@ object UnifiedMemoryManager {
(usableMemory * memoryFraction).toLong
}
}

/**
* Identifier for an unmanaged memory consumer.
*
* @param componentType The type of component (e.g., "RocksDB", "NativeLibrary")
* @param instanceKey A unique key to identify this specific instance of the component.
* For shared memory consumers, this should be a common key across
* all instances to avoid double counting.
*/
case class UnmanagedMemoryConsumerId(
componentType: String,
instanceKey: String
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation, @ericm-db ?


/**
* Interface for components that consume memory outside of Spark's unified memory management.
*
* Components implementing this trait can register themselves with the memory manager
* to have their memory usage tracked and factored into memory allocation decisions.
* This helps prevent OOM errors when unmanaged components use significant memory.
*
* Examples of unmanaged memory consumers:
* - RocksDB state stores in structured streaming
* - Native libraries with custom memory allocation
* - Off-heap caches managed outside of Spark
*/
trait UnmanagedMemoryConsumer {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move this into a separate file, UnmanagedMemoryConsumer.scala?

/**
* Returns the unique identifier for this memory consumer.
* The identifier is used to track and manage the consumer in the memory tracking system.
*/
def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId

/**
* Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses.
* This is used to ensure unmanaged memory usage only affects the correct memory pool.
*/
def memoryMode: MemoryMode

/**
* Returns the current memory usage in bytes.
*
* This method is called periodically by the memory polling mechanism to track
* memory usage over time. Implementations should return the current total memory
* consumed by this component.
*
* @return Current memory usage in bytes. Should return 0 if no memory is currently used.
* Return -1L to indicate this consumer is no longer active and should be
* automatically removed from tracking.
* @throws Exception if memory usage cannot be determined. The polling mechanism
* will handle exceptions gracefully and log warnings.
*/
def getMemBytesUsed: Long
}
Loading