diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index a83ad716d34..d84ffaf5c2a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -792,6 +792,12 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .bytesConf(ByteUnit.BYTE)
     .createWithDefault(8 * 1024 * 1024)
 
+  val PROFILE_TASK_LIMIT_PER_STAGE = conf("spark.rapids.profile.taskLimitPerStage")
+    .doc("Limit the number of tasks to profile per stage. A value <= 0 will profile all tasks.")
+    .internal()
+    .integerConf
+    .createWithDefault(0)
+
   // ENABLE/DISABLE PROCESSING
 
   val SQL_ENABLED = conf("spark.rapids.sql.enabled")
@@ -2603,6 +2609,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val profileWriteBufferSize: Long = get(PROFILE_WRITE_BUFFER_SIZE)
 
+  lazy val profileTaskLimitPerStage: Int = get(PROFILE_TASK_LIMIT_PER_STAGE)
+
   lazy val isSqlEnabled: Boolean = get(SQL_ENABLED)
 
   lazy val isSqlExecuteOnGPU: Boolean = get(SQL_MODE).equals("executeongpu")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala
index 924a75a7b65..26e9c8711de 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala
@@ -50,10 +50,15 @@ object ProfilerOnExecutor extends Logging {
   private var isProfileActive = false
   private var currentContextMethod: Method = null
   private var getContextMethod: Method = null
+  // Number of tasks counted as completed per stage ID (only used when taskLimit > 0)
+  private val stageTaskCount = mutable.HashMap[Int, Int]()
+  // Max tasks to profile per stage, from profileTaskLimitPerStage; <= 0 means unlimited
+  private var taskLimit = 0
 
   def init(pluginCtx: PluginContext, conf: RapidsConf): Unit = {
     require(writer.isEmpty, "Already initialized")
     timeRanges = conf.profileTimeRangesSeconds.map(parseTimeRanges)
+    taskLimit = conf.profileTaskLimitPerStage
     jobRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_JOBS)
     stageRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_STAGES)
     driverPollMillis = conf.profileDriverPollMillis
@@ -119,9 +124,32 @@ object ProfilerOnExecutor extends Logging {
       val stageId = taskCtx.stageId
       if (stageRanges.contains(stageId)) {
         synchronized {
-          activeStages.add(taskCtx.stageId)
-          enable()
-          startPollingDriver()
+          if (taskLimit <= 0) {
+            // Unlimited tasks per stage
+            activeStages.add(taskCtx.stageId)
+            enable()
+            startPollingDriver()
+          } else {
+            // Limited tasks per stage
+            if (stageTaskCount.getOrElse(stageId, 0) < taskLimit) {
+              activeStages.add(taskCtx.stageId)
+              enable()
+              startPollingDriver()
+            }
+            taskCtx.addTaskCompletionListener[Unit] { _ =>
+              // Spark invokes this listener at task completion, i.e. outside the
+              // synchronized block above, so re-acquire the lock before touching
+              // the shared per-stage counter.
+              synchronized {
+                val currentCount = stageTaskCount.getOrElse(stageId, 0)
+                if (currentCount < taskLimit) {
+                  stageTaskCount(stageId) = currentCount + 1
+                } else {
+                  disable()
+                }
+              }
+            }
+          }
         }
       }
     }