|
10 | 10 | from unittest import mock
|
11 | 11 |
|
12 | 12 | from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime
|
| 13 | +from ads.jobs.builders.infrastructure.dsc_job_runtime import ( |
| 14 | + MULTI_NODE_JOB_SUPPORT, |
| 15 | +) |
13 | 16 | from ads.jobs.builders.infrastructure.dsc_job_runtime import (
|
14 | 17 | PyTorchDistributedRuntimeHandler as Handler,
|
15 | 18 | )
|
@@ -133,23 +136,26 @@ def test_create_job_runs(self, patched_run, *args):
|
133 | 136 | self.assertIsInstance(main_run, DataScienceJobRun)
|
134 | 137 | self.assertEqual(main_run.id, test_ocid)
|
135 | 138 | kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list]
|
136 |
| - self.assertEqual( |
137 |
| - kwarg_list, |
138 |
| - [ |
139 |
| - { |
140 |
| - "display_name": "None-0", |
141 |
| - "environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"}, |
142 |
| - }, |
143 |
| - { |
144 |
| - "display_name": "None-1", |
145 |
| - "environment_variables": { |
146 |
| - "NODE_RANK": "1", |
147 |
| - "NODE_COUNT": "2", |
148 |
| - "MAIN_JOB_RUN_OCID": test_ocid, |
| 139 | + if MULTI_NODE_JOB_SUPPORT: |
| 140 | + self.assertEqual(kwarg_list, [{}]) |
| 141 | + else: |
| 142 | + self.assertEqual( |
| 143 | + kwarg_list, |
| 144 | + [ |
| 145 | + { |
| 146 | + "display_name": "None-0", |
| 147 | + "environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"}, |
149 | 148 | },
|
150 |
| - }, |
151 |
| - ], |
152 |
| - ) |
| 149 | + { |
| 150 | + "display_name": "None-1", |
| 151 | + "environment_variables": { |
| 152 | + "NODE_RANK": "1", |
| 153 | + "NODE_COUNT": "2", |
| 154 | + "MAIN_JOB_RUN_OCID": test_ocid, |
| 155 | + }, |
| 156 | + }, |
| 157 | + ], |
| 158 | + ) |
153 | 159 |
|
154 | 160 | @mock.patch.dict(
|
155 | 161 | os.environ, {utils.CONST_ENV_INPUT_MAPPINGS: json.dumps({INPUT_SRC: INPUT_DST})}
|
|
0 commit comments