@@ -368,8 +368,8 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
368368 mock_tunnel .run .assert_called_once ()
369369
370370
371- def test_ray_template_env_var (slurm_scheduler , slurm_executor ):
372- """Test that NEMO_RUN_SLURM_RAY_TEMPLATE environment variable selects the correct template."""
371+ def test_ray_template_executor (slurm_scheduler , slurm_executor , temp_dir ):
372+ """Test that executor.ray_template selects the correct template."""
373373 from nemo_run .config import USE_WITH_RAY_CLUSTER_KEY
374374 from nemo_run .run .ray .slurm import SlurmRayRequest
375375
@@ -387,20 +387,26 @@ def test_ray_template_env_var(slurm_scheduler, slurm_executor):
387387 ):
388388 slurm_scheduler .tunnel = mock .MagicMock ()
389389
390- # Test default template name
390+ # Test default template name (ray.sub.j2)
391+ assert slurm_executor .ray_template == "ray.sub.j2"
391392 with mock .patch ("nemo_run.core.execution.utils.fill_template" ) as mock_fill :
392393 mock_fill .return_value = "#!/bin/bash\n # Mock script"
393394 dryrun_info = slurm_scheduler ._submit_dryrun (app_def , slurm_executor )
394395 assert isinstance (dryrun_info .request , SlurmRayRequest )
395396 assert dryrun_info .request .template_name == "ray.sub.j2"
396397
397- # Test custom template name via environment variable
398- with (
399- mock .patch .dict (os .environ , {"NEMO_RUN_SLURM_RAY_TEMPLATE" : "ray_enroot.sub.j2" }),
400- mock .patch ("nemo_run.core.execution.utils.fill_template" ) as mock_fill ,
401- ):
398+ # Test custom template name via executor
399+ custom_executor = SlurmExecutor (
400+ account = "test_account" ,
401+ job_dir = temp_dir ,
402+ nodes = 1 ,
403+ ntasks_per_node = 1 ,
404+ tunnel = LocalTunnel (job_dir = temp_dir ),
405+ ray_template = "ray_enroot.sub.j2" ,
406+ )
407+ with mock .patch ("nemo_run.core.execution.utils.fill_template" ) as mock_fill :
402408 mock_fill .return_value = "#!/bin/bash\n # Mock script"
403- dryrun_info = slurm_scheduler ._submit_dryrun (app_def , slurm_executor )
409+ dryrun_info = slurm_scheduler ._submit_dryrun (app_def , custom_executor )
404410 assert isinstance (dryrun_info .request , SlurmRayRequest )
405411 assert dryrun_info .request .template_name == "ray_enroot.sub.j2"
406412
0 commit comments