Skip to content

Commit 3b5df3a

Browse files
clumsyazzhipa
andauthored
feat: add TORCHX_IMAGE to env vars for Docker-based schedulers (#1128) (#1129)
Co-authored-by: Alexander Zhipa <[email protected]>
1 parent fd5c7d4 commit 3b5df3a

File tree

6 files changed

+25
-1
lines changed

6 files changed

+25
-1
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292

9393
ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME"
9494

95+
ENV_TORCHX_IMAGE = "TORCHX_IMAGE"
96+
9597
DEFAULT_ROLE_NAME = "node"
9698

9799
TAG_TORCHX_VER = "torchx.pytorch.org/version"
@@ -540,6 +542,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
540542
role = values.apply(role)
541543
role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx)
542544
role.env[ENV_TORCHX_ROLE_NAME] = str(role.name)
545+
role.env[ENV_TORCHX_IMAGE] = role.image
543546

544547
nodes.append(
545548
_role_to_node_properties(

torchx/schedulers/docker_scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __repr__(self) -> str:
8484
LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
8585
LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"
8686

87+
ENV_TORCHX_IMAGE: str = "TORCHX_IMAGE"
88+
8789
NETWORK = "torchx"
8890

8991

@@ -279,6 +281,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo
279281

280282
# configure distributed host envs
281283
env["TORCHX_RANK0_HOST"] = rank0_name
284+
env[ENV_TORCHX_IMAGE] = replica_role.image
282285

283286
c = DockerContainer(
284287
image=replica_role.image,

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def app_to_resource(
399399
replica_role = values.apply(role)
400400
if role_idx == 0 and replica_id == 0:
401401
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
402+
replica_role.env["TORCHX_IMAGE"] = replica_role.image
402403

403404
pod = role_to_pod(name, replica_role, service_account)
404405
pod.metadata.labels.update(

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AWSBatchOpts,
2323
AWSBatchScheduler,
2424
create_scheduler,
25+
ENV_TORCHX_IMAGE,
2526
ENV_TORCHX_ROLE_NAME,
2627
parse_ulimits,
2728
resource_from_resource_requirements,
@@ -253,6 +254,10 @@ def test_submit_dryrun(self) -> None:
253254
{"name": "FOO", "value": "bar"},
254255
{"name": "TORCHX_ROLE_IDX", "value": "0"},
255256
{"name": "TORCHX_ROLE_NAME", "value": "trainer"},
257+
{
258+
"name": "TORCHX_IMAGE",
259+
"value": "pytorch/torchx:latest",
260+
},
256261
],
257262
"privileged": False,
258263
"resourceRequirements": [
@@ -494,7 +499,11 @@ def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
494499
{
495500
"name": ENV_TORCHX_ROLE_NAME,
496501
"value": "echo",
497-
}
502+
},
503+
{
504+
"name": ENV_TORCHX_IMAGE,
505+
"value": "pytorch/torchx:latest",
506+
},
498507
],
499508
},
500509
}

torchx/schedulers/test/docker_scheduler_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def test_submit_dryrun(self) -> None:
100100
"environment": {
101101
"FOO": "bar",
102102
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
103+
"TORCHX_IMAGE": "pytorch/torchx:latest",
103104
},
104105
"labels": {
105106
"torchx.pytorch.org/app-id": "app_name_42",
@@ -190,6 +191,7 @@ def test_copy_env(self) -> None:
190191
"FOO_1": "f1",
191192
"BAR_1": "b1",
192193
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
194+
"TORCHX_IMAGE": "pytorch/torchx:latest",
193195
},
194196
)
195197

@@ -205,6 +207,7 @@ def test_env(self) -> None:
205207
"FOO": "bar",
206208
"FOO_1": "BAR_1",
207209
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
210+
"TORCHX_IMAGE": "pytorch/torchx:latest",
208211
},
209212
)
210213

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def test_submit_dryrun(self) -> None:
320320
fieldPath: bar
321321
- name: TORCHX_RANK0_HOST
322322
value: localhost
323+
- name: TORCHX_IMAGE
324+
value: pytorch/torchx:latest
323325
image: pytorch/torchx:latest
324326
name: trainerfoo-0
325327
ports:
@@ -521,6 +523,9 @@ def test_rank0_env(self) -> None:
521523
self.assertIn(
522524
V1EnvVar(name="TORCHX_RANK0_HOST", value="localhost"), container0.env
523525
)
526+
self.assertIn(
527+
V1EnvVar(name="TORCHX_IMAGE", value="pytorch/torchx:latest"), container0.env
528+
)
524529
container1 = tasks[1]["template"].spec.containers[0]
525530
self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)
526531

0 commit comments

Comments
 (0)