@@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
278278
279279 return " " .join (_srun_flags )
280280
281+ def get_command_srun_args () -> str :
282+ if (
283+ self .executor .run_as_group
284+ and self .executor .heterogeneous
285+ and self .executor .resource_group
286+ and self .executor .resource_group [0 ].srun_args is not None
287+ ):
288+ command_srun_args = self .executor .resource_group [0 ].srun_args
289+ else :
290+ command_srun_args = self .executor .srun_args or []
291+
292+ return " " .join (shlex .quote (arg ) for arg in command_srun_args )
293+
281294 ray_log_prefix = job_details .ray_log_prefix
282295 vars_to_fill = {
283296 "sbatch_flags" : sbatch_flags ,
@@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
296309 "ray_log_prefix" : ray_log_prefix ,
297310 "heterogeneous" : self .executor .heterogeneous ,
298311 "resource_group" : self .executor .resource_group if self .executor .heterogeneous else [],
312+ "command_srun_args" : get_command_srun_args (),
299313 }
300314
301315 if self .command_groups :
@@ -1257,9 +1271,9 @@ def start(
12571271 if isinstance (self .executor .tunnel , SSHTunnel ):
12581272 # Rsync workdir honouring .gitignore
12591273 self .executor .tunnel .connect ()
1260- assert self . executor . tunnel . session is not None , (
1261- "Tunnel session is not connected"
1262- )
1274+ assert (
1275+ self . executor . tunnel . session is not None
1276+ ), "Tunnel session is not connected"
12631277 rsync (
12641278 self .executor .tunnel .session ,
12651279 workdir ,
@@ -1314,9 +1328,9 @@ def start(
13141328
13151329 if isinstance (self .executor .tunnel , SSHTunnel ):
13161330 self .executor .tunnel .connect ()
1317- assert self . executor . tunnel . session is not None , (
1318- "Tunnel session is not connected"
1319- )
1331+ assert (
1332+ self . executor . tunnel . session is not None
1333+ ), "Tunnel session is not connected"
13201334 rsync (
13211335 self .executor .tunnel .session ,
13221336 os .path .join (local_code_extraction_path , "" ),
0 commit comments