Skip to content

Commit 28bff07

Browse files
committed
Support accelerator directive for local executor
Signed-off-by: Ben Sherman <[email protected]>
1 parent bfa67ca commit 28bff07

File tree

7 files changed

+261
-8
lines changed

7 files changed

+261
-8
lines changed

docs/executor.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ The `local` executor is useful for developing and testing a pipeline script on y
227227

228228
Resource requests and other job characteristics can be controlled via the following process directives:
229229

230+
- {ref}`process-accelerator`
230231
- {ref}`process-cpus`
231232
- {ref}`process-memory`
232233
- {ref}`process-time`
@@ -241,6 +242,25 @@ The local executor supports two types of tasks:
241242
- Script tasks (processes with a `script` or `shell` block) - executed via a Bash wrapper
242243
- Native tasks (processes with an `exec` block) - executed directly in the JVM.
243244

245+
(local-accelerators)=
246+
247+
### Accelerators
248+
249+
:::{versionadded} 25.10.0
250+
:::
251+
252+
The local executor can use the `accelerator` directive to allocate accelerators such as GPUs. To use accelerators, set the corresponding environment variable:
253+
254+
- `CUDA_VISIBLE_DEVICES` for [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables) applications
255+
256+
- `HIP_VISIBLE_DEVICES` for [HIP](https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/env_variables.html) applications
257+
258+
- `ROCR_VISIBLE_DEVICES` for [AMD ROCm](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html) applications
259+
260+
Set the environment variable to a comma-separated list of device IDs that you want Nextflow to access. Nextflow uses the same environment variable to allocate accelerators for each task that requests them.
261+
262+
For example, to use all GPUs on a node with four NVIDIA GPUs, set `CUDA_VISIBLE_DEVICES` to `0,1,2,3`. If four tasks each request one GPU, they will be executed with `CUDA_VISIBLE_DEVICES` set to `0`, `1`, `2`, and `3`, respectively.
263+
244264
(lsf-executor)=
245265

246266
## LSF

docs/migrations/25-10.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ workflow {
2929

3030
This syntax is simpler and easier to use with the {ref}`strict syntax <strict-syntax-page>`. See {ref}`workflow-handlers` for details.
3131

32+
<h3>GPU scheduling for local executor</h3>
33+
34+
The local executor can now schedule GPUs using the `accelerator` directive. This feature is useful when running Nextflow on a single machine with multiple GPUs.
35+
36+
See {ref}`local-accelerators` for details.
37+
3238
## Breaking changes
3339

3440
- The AWS Java SDK used by Nextflow was upgraded from v1 to v2, which introduced some breaking changes to the `aws.client` config options. See {ref}`the guide <aws-java-sdk-v2-page>` for details.

modules/nextflow/src/main/groovy/nextflow/executor/local/LocalTaskHandler.groovy

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
8080

8181
private volatile TaskResult result
8282

83+
String acceleratorEnv
84+
85+
List<String> acceleratorIds
86+
8387
LocalTaskHandler(TaskRun task, LocalExecutor executor) {
8488
super(task)
8589
// create the task handler
@@ -142,11 +146,13 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
142146
final workDir = task.workDir.toFile()
143147
final logFile = new File(workDir, TaskRun.CMD_LOG)
144148

145-
return new ProcessBuilder()
149+
final pb = new ProcessBuilder()
146150
.redirectErrorStream(true)
147151
.redirectOutput(logFile)
148152
.directory(workDir)
149153
.command(cmd)
154+
prepareAccelerators(pb)
155+
return pb
150156
}
151157

152158
protected ProcessBuilder fusionProcessBuilder() {
@@ -162,10 +168,18 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
162168

163169
final logPath = Files.createTempFile('nf-task','.log')
164170

165-
return new ProcessBuilder()
171+
final pb = new ProcessBuilder()
166172
.redirectErrorStream(true)
167173
.redirectOutput(logPath.toFile())
168174
.command(List.of('sh','-c', cmd))
175+
prepareAccelerators(pb)
176+
return pb
177+
}
178+
179+
protected void prepareAccelerators(ProcessBuilder pb) {
180+
if( !acceleratorEnv )
181+
return
182+
pb.environment().put(acceleratorEnv, acceleratorIds.join(','))
169183
}
170184

171185
protected ProcessBuilder createLaunchProcessBuilder() {
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nextflow.processor
18+
19+
import groovy.transform.CompileStatic
20+
import nextflow.SysEnv
21+
import nextflow.util.TrackingSemaphore
22+
23+
/**
24+
* Specialized semaphore that keeps track of accelerators by
25+
* id. The id can be an integer or a UUID.
26+
*
27+
* @author Ben Sherman <[email protected]>
28+
*/
29+
@CompileStatic
30+
class AcceleratorTracker {
31+
32+
private static final List<String> DEVICE_ENV_NAMES = [
33+
'CUDA_VISIBLE_DEVICES',
34+
'HIP_VISIBLE_DEVICES',
35+
'ROCR_VISIBLE_DEVICES'
36+
]
37+
38+
static AcceleratorTracker create() {
39+
return DEVICE_ENV_NAMES.stream()
40+
.filter(name -> SysEnv.containsKey(name))
41+
.map((name) -> {
42+
final ids = SysEnv.get(name).tokenize(',')
43+
return new AcceleratorTracker(name, ids)
44+
})
45+
.findFirst().orElse(new AcceleratorTracker())
46+
}
47+
48+
private final String name
49+
private final TrackingSemaphore semaphore
50+
51+
AcceleratorTracker(String name, List<String> ids) {
52+
this.name = name
53+
this.semaphore = new TrackingSemaphore(ids)
54+
}
55+
56+
AcceleratorTracker() {
57+
this(null, [])
58+
}
59+
60+
String name() {
61+
return name
62+
}
63+
64+
int total() {
65+
return semaphore.totalPermits()
66+
}
67+
68+
int available() {
69+
return semaphore.availablePermits()
70+
}
71+
72+
List<String> acquire(int permits) {
73+
return semaphore.acquire(permits)
74+
}
75+
76+
void release(List<String> ids) {
77+
semaphore.release(ids)
78+
}
79+
80+
}

modules/nextflow/src/main/groovy/nextflow/processor/LocalPollingMonitor.groovy

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import groovy.util.logging.Slf4j
2424
import nextflow.Session
2525
import nextflow.executor.ExecutorConfig
2626
import nextflow.exception.ProcessUnrecoverableException
27+
import nextflow.executor.local.LocalTaskHandler
2728
import nextflow.util.Duration
2829
import nextflow.util.MemoryUnit
2930

@@ -59,6 +60,11 @@ class LocalPollingMonitor extends TaskPollingMonitor {
5960
*/
6061
private final long maxMemory
6162

63+
/**
64+
* Tracks the total and available accelerators in the system
65+
*/
66+
private AcceleratorTracker acceleratorTracker
67+
6268
/**
6369
* Create the task polling monitor with the provided named parameters object.
6470
* <p>
@@ -76,6 +82,7 @@ class LocalPollingMonitor extends TaskPollingMonitor {
7682
super(params)
7783
this.availCpus = maxCpus = params.cpus as int
7884
this.availMemory = maxMemory = params.memory as long
85+
this.acceleratorTracker = AcceleratorTracker.create()
7986
assert availCpus>0, "Local avail `cpus` attribute cannot be zero"
8087
assert availMemory>0, "Local avail `memory` attribute cannot zero"
8188
}
@@ -154,6 +161,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
154161
handler.task.getConfig()?.getMemory()?.toBytes() ?: 1L
155162
}
156163

164+
/**
165+
* @param handler
166+
* A {@link TaskHandler} instance
167+
* @return
168+
* The number of accelerators requested to execute the specified task
169+
*/
170+
private static int accelerators(TaskHandler handler) {
171+
handler.task.getConfig()?.getAccelerator()?.getRequest() ?: 0
172+
}
173+
157174
/**
158175
* Determines if a task can be submitted for execution checking if the resources required
159176
* (cpus and memory) match the amount of avail resource
@@ -179,9 +196,13 @@ class LocalPollingMonitor extends TaskPollingMonitor {
179196
if( taskMemory>maxMemory)
180197
throw new ProcessUnrecoverableException("Process requirement exceeds available memory -- req: ${new MemoryUnit(taskMemory)}; avail: ${new MemoryUnit(maxMemory)}")
181198

182-
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory
199+
final taskAccelerators = accelerators(handler)
200+
if( taskAccelerators > acceleratorTracker.total() )
201+
throw new ProcessUnrecoverableException("Process requirement exceeds available accelerators -- req: $taskAccelerators; avail: ${acceleratorTracker.total()}")
202+
203+
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory && taskAccelerators <= acceleratorTracker.available()
183204
if( !result && log.isTraceEnabled( ) ) {
184-
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)}"
205+
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)} && taskAccelerators: $taskAccelerators <= availAccelerators: ${acceleratorTracker.available()}"
185206
}
186207
return result
187208
}
@@ -194,9 +215,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
194215
*/
195216
@Override
196217
protected void submit(TaskHandler handler) {
197-
super.submit(handler)
198218
availCpus -= cpus(handler)
199219
availMemory -= mem(handler)
220+
221+
final taskAccelerators = accelerators(handler)
222+
if ( taskAccelerators > 0 ) {
223+
((LocalTaskHandler) handler).acceleratorEnv = acceleratorTracker.name()
224+
((LocalTaskHandler) handler).acceleratorIds = acceleratorTracker.acquire(taskAccelerators)
225+
}
226+
227+
super.submit(handler)
200228
}
201229

202230
/**
@@ -209,11 +237,13 @@ class LocalPollingMonitor extends TaskPollingMonitor {
209237
* {@code true} when the task is successfully removed from polling queue,
210238
* {@code false} otherwise
211239
*/
240+
@Override
212241
protected boolean remove(TaskHandler handler) {
213242
final result = super.remove(handler)
214243
if( result ) {
215244
availCpus += cpus(handler)
216245
availMemory += mem(handler)
246+
acceleratorTracker.release(((LocalTaskHandler) handler).acceleratorIds ?: Collections.<String>emptyList())
217247
}
218248
return result
219249
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nextflow.util
18+
19+
import java.util.concurrent.Semaphore
20+
21+
import groovy.transform.CompileStatic
22+
23+
/**
24+
* Specialized semaphore that keeps track of which permits
25+
* are being used.
26+
*
27+
* @author Ben Sherman <[email protected]>
28+
*/
29+
@CompileStatic
30+
class TrackingSemaphore {
31+
private final Semaphore semaphore
32+
private final Map<String,Boolean> availIds
33+
34+
TrackingSemaphore(List<String> ids) {
35+
semaphore = new Semaphore(ids.size())
36+
availIds = new HashMap<>(ids.size())
37+
for( final id : ids )
38+
availIds.put(id, true)
39+
}
40+
41+
int totalPermits() {
42+
return availIds.size()
43+
}
44+
45+
int availablePermits() {
46+
return semaphore.availablePermits()
47+
}
48+
49+
List<String> acquire(int permits) {
50+
semaphore.acquire(permits)
51+
final result = new ArrayList<String>(permits)
52+
for( final entry : availIds.entrySet() ) {
53+
if( entry.getValue() ) {
54+
entry.setValue(false)
55+
result.add(entry.getKey())
56+
}
57+
if( result.size() == permits )
58+
break
59+
}
60+
return result
61+
}
62+
63+
void release(List<String> ids) {
64+
semaphore.release(ids.size())
65+
for( final id : ids )
66+
availIds.put(id, true)
67+
}
68+
69+
}

0 commit comments

Comments
 (0)