Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@

ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME"

ENV_TORCHX_IMAGE = "TORCHX_IMAGE"

DEFAULT_ROLE_NAME = "node"

TAG_TORCHX_VER = "torchx.pytorch.org/version"
Expand Down Expand Up @@ -506,6 +508,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
role = values.apply(role)
role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx)
role.env[ENV_TORCHX_ROLE_NAME] = str(role.name)
role.env[ENV_TORCHX_IMAGE] = role.image

nodes.append(
_role_to_node_properties(
Expand Down
3 changes: 3 additions & 0 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def __repr__(self) -> str:
LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"

ENV_TORCHX_IMAGE: str = "TORCHX_IMAGE"

NETWORK = "torchx"


Expand Down Expand Up @@ -279,6 +281,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo

# configure distributed host envs
env["TORCHX_RANK0_HOST"] = rank0_name
env[ENV_TORCHX_IMAGE] = replica_role.image

c = DockerContainer(
image=replica_role.image,
Expand Down
1 change: 1 addition & 0 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def app_to_resource(
replica_role = values.apply(role)
if role_idx == 0 and replica_id == 0:
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
replica_role.env["TORCHX_IMAGE"] = replica_role.image

pod = role_to_pod(name, replica_role, service_account)
pod.metadata.labels.update(
Expand Down
11 changes: 10 additions & 1 deletion torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AWSBatchOpts,
AWSBatchScheduler,
create_scheduler,
ENV_TORCHX_IMAGE,
ENV_TORCHX_ROLE_NAME,
resource_from_resource_requirements,
resource_requirements_from_resource,
Expand Down Expand Up @@ -252,6 +253,10 @@ def test_submit_dryrun(self) -> None:
{"name": "FOO", "value": "bar"},
{"name": "TORCHX_ROLE_IDX", "value": "0"},
{"name": "TORCHX_ROLE_NAME", "value": "trainer"},
{
"name": "TORCHX_IMAGE",
"value": "pytorch/torchx:latest",
},
],
"privileged": False,
"resourceRequirements": [
Expand Down Expand Up @@ -456,7 +461,11 @@ def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
{
"name": ENV_TORCHX_ROLE_NAME,
"value": "echo",
}
},
{
"name": ENV_TORCHX_IMAGE,
"value": "pytorch/torchx:latest",
},
],
},
}
Expand Down
3 changes: 3 additions & 0 deletions torchx/schedulers/test/docker_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_submit_dryrun(self) -> None:
"environment": {
"FOO": "bar",
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
"TORCHX_IMAGE": "pytorch/torchx:latest",
},
"labels": {
"torchx.pytorch.org/app-id": "app_name_42",
Expand Down Expand Up @@ -190,6 +191,7 @@ def test_copy_env(self) -> None:
"FOO_1": "f1",
"BAR_1": "b1",
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
"TORCHX_IMAGE": "pytorch/torchx:latest",
},
)

Expand All @@ -205,6 +207,7 @@ def test_env(self) -> None:
"FOO": "bar",
"FOO_1": "BAR_1",
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
"TORCHX_IMAGE": "pytorch/torchx:latest",
},
)

Expand Down
5 changes: 5 additions & 0 deletions torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ def test_submit_dryrun(self) -> None:
fieldPath: bar
- name: TORCHX_RANK0_HOST
value: localhost
- name: TORCHX_IMAGE
value: pytorch/torchx:latest
image: pytorch/torchx:latest
name: trainerfoo-0
ports:
Expand Down Expand Up @@ -521,6 +523,9 @@ def test_rank0_env(self) -> None:
self.assertIn(
V1EnvVar(name="TORCHX_RANK0_HOST", value="localhost"), container0.env
)
self.assertIn(
V1EnvVar(name="TORCHX_IMAGE", value="pytorch/torchx:latest"), container0.env
)
container1 = tasks[1]["template"].spec.containers[0]
self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)

Expand Down
Loading