diff --git a/ads/jobs/builders/infrastructure/dsc_job.py b/ads/jobs/builders/infrastructure/dsc_job.py index cad652bbd..91b978f1a 100644 --- a/ads/jobs/builders/infrastructure/dsc_job.py +++ b/ads/jobs/builders/infrastructure/dsc_job.py @@ -1751,6 +1751,7 @@ def is_multi_node_job(runtime): return ( MULTI_NODE_JOB_SUPPORT and isinstance(runtime, MultiNodeRuntime) + and runtime.replica and runtime.replica > 1 ) diff --git a/ads/jobs/builders/infrastructure/dsc_job_runtime.py b/ads/jobs/builders/infrastructure/dsc_job_runtime.py index a5099f32e..880a413c5 100644 --- a/ads/jobs/builders/infrastructure/dsc_job_runtime.py +++ b/ads/jobs/builders/infrastructure/dsc_job_runtime.py @@ -365,6 +365,11 @@ def _get_node_group(self, dsc_job): dsc_job, "job_node_configuration_details.job_node_group_configuration_details_list", ) + if node_groups is None: + node_groups = get_value( + dsc_job, + "job_node_configuration_details.jobNodeGroupConfigurationDetailsList", + ) if node_groups and len(node_groups) == 1: return node_groups[0] return None @@ -373,6 +378,7 @@ def _get_replica(self, dsc_job, envs): node_group = self._get_node_group(dsc_job) if node_group: replica = get_value(node_group, "replicas") + envs.pop(self.CONST_NODE_COUNT, None) elif not envs: replica = None elif self.CONST_WORKER_COUNT in envs: @@ -399,7 +405,9 @@ def _extract_envs(self, dsc_job): env_attr = "job_configuration_details.environment_variables" node_group = self._get_node_group(dsc_job) if node_group: - envs = get_value(node_group, env_attr) + envs = get_value(node_group, env_attr) or get_value( + node_group, "jobConfigurationDetails.environment_variables" + ) else: envs = get_value(dsc_job, env_attr) if envs: diff --git a/ads/pipeline/ads_pipeline.py b/ads/pipeline/ads_pipeline.py index 73b247876..efa575195 100644 --- a/ads/pipeline/ads_pipeline.py +++ b/ads/pipeline/ads_pipeline.py @@ -1728,15 +1728,19 @@ def __step_details(self, pipeline_details: Dict) -> list: def __step_infrastructure_configuration_details(self, step) -> dict: step_infrastructure_configuration_details = {} - step_infrastructure_configuration_details[ - "blockStorageSizeInGBs" - ] = step.infrastructure.block_storage_size - step_infrastructure_configuration_details[ - "shapeName" - ] = step.infrastructure.shape_name - step_infrastructure_configuration_details[ - "shapeConfigDetails" - ] = step.infrastructure.shape_config_details + step_infrastructure_configuration_details["blockStorageSizeInGBs"] = ( + step.infrastructure.block_storage_size + ) + step_infrastructure_configuration_details["shapeName"] = ( + step.infrastructure.shape_name + ) + step_infrastructure_configuration_details["shapeConfigDetails"] = ( + step.infrastructure.shape_config_details + ) + if getattr(step.infrastructure, "subnet_id", ""): + step_infrastructure_configuration_details["subnetId"] = ( + step.infrastructure.subnet_id + ) return step_infrastructure_configuration_details def __step_configuration_details(self, pipeline_details: Dict, step) -> dict: diff --git a/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py b/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py index 838278a79..49af200e6 100644 --- a/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py +++ b/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py @@ -10,6 +10,9 @@ from unittest import mock from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime +from ads.jobs.builders.infrastructure.dsc_job_runtime import ( + MULTI_NODE_JOB_SUPPORT, +) from ads.jobs.builders.infrastructure.dsc_job_runtime import ( PyTorchDistributedRuntimeHandler as Handler, ) @@ -133,23 +136,26 @@ def test_create_job_runs(self, patched_run, *args): self.assertIsInstance(main_run, DataScienceJobRun) self.assertEqual(main_run.id, test_ocid) kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list] - self.assertEqual( - kwarg_list, - [ - { - "display_name": "None-0", - "environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"}, - }, - { - "display_name": "None-1", - "environment_variables": { - "NODE_RANK": "1", - "NODE_COUNT": "2", - "MAIN_JOB_RUN_OCID": test_ocid, + if MULTI_NODE_JOB_SUPPORT: + self.assertEqual(kwarg_list, [{}]) + else: + self.assertEqual( + kwarg_list, + [ + { + "display_name": "None-0", + "environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"}, }, - }, - ], - ) + { + "display_name": "None-1", + "environment_variables": { + "NODE_RANK": "1", + "NODE_COUNT": "2", + "MAIN_JOB_RUN_OCID": test_ocid, + }, + }, + ], + ) @mock.patch.dict( os.environ, {utils.CONST_ENV_INPUT_MAPPINGS: json.dumps({INPUT_SRC: INPUT_DST})}