Skip to content

Commit b5395b8

Browse files
committed
Update test for multi-node job run.
1 parent 1a16cf8 commit b5395b8

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from unittest import mock
1111

1212
from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime
13+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
14+
MULTI_NODE_JOB_SUPPORT,
15+
)
1316
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1417
PyTorchDistributedRuntimeHandler as Handler,
1518
)
@@ -133,23 +136,26 @@ def test_create_job_runs(self, patched_run, *args):
133136
self.assertIsInstance(main_run, DataScienceJobRun)
134137
self.assertEqual(main_run.id, test_ocid)
135138
kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list]
136-
self.assertEqual(
137-
kwarg_list,
138-
[
139-
{
140-
"display_name": "None-0",
141-
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
142-
},
143-
{
144-
"display_name": "None-1",
145-
"environment_variables": {
146-
"NODE_RANK": "1",
147-
"NODE_COUNT": "2",
148-
"MAIN_JOB_RUN_OCID": test_ocid,
139+
if MULTI_NODE_JOB_SUPPORT:
140+
self.assertEqual(kwarg_list, [{}])
141+
else:
142+
self.assertEqual(
143+
kwarg_list,
144+
[
145+
{
146+
"display_name": "None-0",
147+
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
149148
},
150-
},
151-
],
152-
)
149+
{
150+
"display_name": "None-1",
151+
"environment_variables": {
152+
"NODE_RANK": "1",
153+
"NODE_COUNT": "2",
154+
"MAIN_JOB_RUN_OCID": test_ocid,
155+
},
156+
},
157+
],
158+
)
153159

154160
@mock.patch.dict(
155161
os.environ, {utils.CONST_ENV_INPUT_MAPPINGS: json.dumps({INPUT_SRC: INPUT_DST})}

0 commit comments

Comments
 (0)