diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index eab0490cb..3c652b082 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -92,6 +92,8 @@ ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME" +ENV_TORCHX_IMAGE = "TORCHX_IMAGE" + DEFAULT_ROLE_NAME = "node" TAG_TORCHX_VER = "torchx.pytorch.org/version" @@ -506,6 +508,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ role = values.apply(role) role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx) role.env[ENV_TORCHX_ROLE_NAME] = str(role.name) + role.env[ENV_TORCHX_IMAGE] = role.image nodes.append( _role_to_node_properties( diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index bed591c72..c35a6a034 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -84,6 +84,8 @@ def __repr__(self) -> str: LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name" LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id" +ENV_TORCHX_IMAGE: str = "TORCHX_IMAGE" + NETWORK = "torchx" @@ -279,6 +281,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo # configure distributed host envs env["TORCHX_RANK0_HOST"] = rank0_name + env[ENV_TORCHX_IMAGE] = replica_role.image c = DockerContainer( image=replica_role.image, diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index a0582755c..229d23c4e 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -399,6 +399,7 @@ def app_to_resource( replica_role = values.apply(role) if role_idx == 0 and replica_id == 0: replica_role.env["TORCHX_RANK0_HOST"] = "localhost" + replica_role.env["TORCHX_IMAGE"] = replica_role.image pod = role_to_pod(name, replica_role, service_account) pod.metadata.labels.update( diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index c2a5f65f6..e8ac6a83c 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -22,6 +22,7 @@ AWSBatchOpts, AWSBatchScheduler, create_scheduler, + ENV_TORCHX_IMAGE, ENV_TORCHX_ROLE_NAME, resource_from_resource_requirements, resource_requirements_from_resource, @@ -252,6 +253,10 @@ def test_submit_dryrun(self) -> None: {"name": "FOO", "value": "bar"}, {"name": "TORCHX_ROLE_IDX", "value": "0"}, {"name": "TORCHX_ROLE_NAME", "value": "trainer"}, + { + "name": "TORCHX_IMAGE", + "value": "pytorch/torchx:latest", + }, ], "privileged": False, "resourceRequirements": [ @@ -456,7 +461,11 @@ def _mock_scheduler_running_job(self) -> AWSBatchScheduler: { "name": ENV_TORCHX_ROLE_NAME, "value": "echo", - } + }, + { + "name": ENV_TORCHX_IMAGE, + "value": "pytorch/torchx:latest", + }, ], }, } diff --git a/torchx/schedulers/test/docker_scheduler_test.py b/torchx/schedulers/test/docker_scheduler_test.py index 5b684f3a1..869e14ce5 100644 --- a/torchx/schedulers/test/docker_scheduler_test.py +++ b/torchx/schedulers/test/docker_scheduler_test.py @@ -100,6 +100,7 @@ def test_submit_dryrun(self) -> None: "environment": { "FOO": "bar", "TORCHX_RANK0_HOST": "app_name_42-trainer-0", + "TORCHX_IMAGE": "pytorch/torchx:latest", }, "labels": { "torchx.pytorch.org/app-id": "app_name_42", @@ -190,6 +191,7 @@ def test_copy_env(self) -> None: "FOO_1": "f1", "BAR_1": "b1", "TORCHX_RANK0_HOST": "app_name_42-trainer-0", + "TORCHX_IMAGE": "pytorch/torchx:latest", }, ) @@ -205,6 +207,7 @@ def test_env(self) -> None: "FOO": "bar", "FOO_1": "BAR_1", "TORCHX_RANK0_HOST": "app_name_42-trainer-0", + "TORCHX_IMAGE": "pytorch/torchx:latest", }, ) diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index 88983e95a..c81e7d086 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -320,6 +320,8 @@ def test_submit_dryrun(self) -> None: fieldPath: bar - name: TORCHX_RANK0_HOST value: localhost + - name: TORCHX_IMAGE + value: pytorch/torchx:latest image: pytorch/torchx:latest name: trainerfoo-0 ports: @@ -521,6 +523,9 @@ def test_rank0_env(self) -> None: self.assertIn( V1EnvVar(name="TORCHX_RANK0_HOST", value="localhost"), container0.env ) + self.assertIn( + V1EnvVar(name="TORCHX_IMAGE", value="pytorch/torchx:latest"), container0.env + ) container1 = tasks[1]["template"].spec.containers[0] self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)