Skip to content

Commit ebaa398

Browse files
authored
(torchx/schedulers) Remove redundantly setting dryruninfo._app and _cfg in k8s, gcp_batch, aws_batch, and docker scheduler subclasses since the scheduler abstract class sets it in the dryrun() API (#674)
1 parent a42a36d commit ebaa398

7 files changed

+18
-35
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,11 +442,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
442442
job_def=job_def,
443443
images_to_push=images_to_push,
444444
)
445-
info = AppDryRunInfo(req, repr)
446-
info._app = app
447-
# pyre-fixme: AppDryRunInfo
448-
info._cfg = cfg
449-
return info
445+
return AppDryRunInfo(req, repr)
450446

451447
def _validate(self, app: AppDef, scheduler: str) -> None:
452448
# Skip validation step

torchx/schedulers/docker_scheduler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,7 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo
302302
]
303303
req.containers.append(c)
304304

305-
info = AppDryRunInfo(req, repr)
306-
info._app = app
307-
# pyre-fixme: AppDryRunInfo
308-
info._cfg = cfg
309-
return info
305+
return AppDryRunInfo(req, repr)
310306

311307
def _validate(self, app: AppDef, scheduler: str) -> None:
312308
# Skip validation step

torchx/schedulers/gcp_batch_scheduler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,7 @@ def _submit_dryrun(
313313
job_def=job,
314314
)
315315

316-
info = AppDryRunInfo(req, repr)
317-
info._app = app
318-
# pyre-fixme: AppDryRunInfo
319-
info._cfg = cfg
320-
return info
316+
return AppDryRunInfo(req, repr)
321317

322318
def run_opts(self) -> runopts:
323319
opts = runopts()

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -615,11 +615,7 @@ def _submit_dryrun(
615615
resource=resource,
616616
images_to_push=images_to_push,
617617
)
618-
info = AppDryRunInfo(req, repr)
619-
info._app = app
620-
# pyre-fixme: AppDryRunInfo
621-
info._cfg = cfg
622-
return info
618+
return AppDryRunInfo(req, repr)
623619

624620
def _validate(self, app: AppDef, scheduler: str) -> None:
625621
# Skip validation step

torchx/schedulers/test/gcp_batch_scheduler_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_submit_dryrun(self) -> None:
6969
proj = "test-proj"
7070
loc = "us-west-1"
7171
cfg = GCPBatchOpts(project=proj, location=loc)
72-
info = scheduler._submit_dryrun(app, cfg)
72+
info = scheduler.submit_dryrun(app, cfg)
7373

7474
req = info.request
7575
self.assertEqual(req.project, proj)
@@ -149,7 +149,7 @@ def test_submit_dryrun_throws(self) -> None:
149149
app.roles[0].resource.gpu = 3
150150
cfg = GCPBatchOpts(project="test-proj", location="us-west-1")
151151
with self.assertRaises(ValueError):
152-
scheduler._submit_dryrun(app, cfg)
152+
scheduler.submit_dryrun(app, cfg)
153153

154154
def test_app_id_to_job_full_name(self) -> None:
155155
scheduler = create_scheduler("test")
@@ -369,7 +369,7 @@ def test_submit(self) -> None:
369369
# pyre-fixme: GCPBatchOpts type passed to resolve
370370
resolved_cfg = scheduler.run_opts().resolve(cfg)
371371
# pyre-fixme: _submit_dryrun expects GCPBatchOpts
372-
info = scheduler._submit_dryrun(app, resolved_cfg)
372+
info = scheduler.submit_dryrun(app, resolved_cfg)
373373
id = scheduler.schedule(info)
374374
self.assertEqual(id, "test-proj:us-central1:app-name-42")
375375
self.assertEqual(scheduler._client.create_job.call_count, 1)

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def test_submit_dryrun(self) -> None:
213213
"torchx.schedulers.kubernetes_scheduler.make_unique"
214214
) as make_unique_ctx:
215215
make_unique_ctx.return_value = "app-name-42"
216-
info = scheduler._submit_dryrun(app, cfg)
216+
info = scheduler.submit_dryrun(app, cfg)
217217

218218
resource = str(info.request)
219219

@@ -460,9 +460,8 @@ def test_rank0_env(self) -> None:
460460
"torchx.schedulers.kubernetes_scheduler.make_unique"
461461
) as make_unique_ctx:
462462
make_unique_ctx.return_value = "app-name-42"
463-
info = scheduler._submit_dryrun(app, cfg)
463+
info = scheduler.submit_dryrun(app, cfg)
464464

465-
# pyre-fixme[16]; `object` has no attribute `__getitem__`.
466465
tasks = info.request.resource["spec"]["tasks"]
467466
container0 = tasks[0]["template"].spec.containers[0]
468467
self.assertIn("TORCHX_RANK0_HOST", container0.command)
@@ -486,7 +485,7 @@ def test_submit_dryrun_patch(self) -> None:
486485
"torchx.schedulers.kubernetes_scheduler.make_unique"
487486
) as make_unique_ctx:
488487
make_unique_ctx.return_value = "app-name-42"
489-
info = scheduler._submit_dryrun(app, cfg)
488+
info = scheduler.submit_dryrun(app, cfg)
490489

491490
self.assertIn("example.com/some/repo:testhash", str(info.request.resource))
492491
self.assertEqual(
@@ -509,11 +508,11 @@ def test_submit_dryrun_service_account(self) -> None:
509508
"service_account": "srvacc",
510509
}
511510
)
512-
info = scheduler._submit_dryrun(app, cfg)
511+
info = scheduler.submit_dryrun(app, cfg)
513512
self.assertIn("'service_account_name': 'srvacc'", str(info.request.resource))
514513

515514
del cfg["service_account"]
516-
info = scheduler._submit_dryrun(app, cfg)
515+
info = scheduler.submit_dryrun(app, cfg)
517516
self.assertIn("service_account_name': None", str(info.request.resource))
518517

519518
def test_submit_dryrun_priority_class(self) -> None:
@@ -527,11 +526,11 @@ def test_submit_dryrun_priority_class(self) -> None:
527526
}
528527
)
529528

530-
info = scheduler._submit_dryrun(app, cfg)
529+
info = scheduler.submit_dryrun(app, cfg)
531530
self.assertIn("'priorityClassName': 'high'", str(info.request.resource))
532531

533532
del cfg["priority_class"]
534-
info = scheduler._submit_dryrun(app, cfg)
533+
info = scheduler.submit_dryrun(app, cfg)
535534
self.assertNotIn("'priorityClassName'", str(info.request.resource))
536535

537536
@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
@@ -548,7 +547,7 @@ def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
548547
}
549548
)
550549

551-
info = scheduler._submit_dryrun(app, cfg)
550+
info = scheduler.submit_dryrun(app, cfg)
552551
id = scheduler.schedule(info)
553552
self.assertEqual(id, "testnamespace:testid")
554553
call = create_namespaced_custom_object.call_args
@@ -577,7 +576,7 @@ def test_submit_job_name_conflict(
577576
"queue": "testqueue",
578577
}
579578
)
580-
info = scheduler._submit_dryrun(app, cfg)
579+
info = scheduler.submit_dryrun(app, cfg)
581580
with self.assertRaises(ValueError):
582581
scheduler.schedule(info)
583582

@@ -895,4 +894,4 @@ def test_dryrun(self) -> None:
895894
)
896895

897896
with self.assertRaises(ModuleNotFoundError):
898-
scheduler._submit_dryrun(app, cfg)
897+
scheduler.submit_dryrun(app, cfg)

torchx/specs/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def __init__(self, request: T, fmt: Callable[[T], str]) -> None:
657657
# Scheduler or Session implementations
658658
# and are back references to the parameters
659659
# to dryrun() that returned this AppDryRunInfo object
660-
# thus they are set in Session.dryrun() and Scheduler.submit_dryrun()
660+
# thus they are set in Runner.dryrun() and Scheduler.submit_dryrun()
661661
# manually rather than through constructor arguments
662662
# DO NOT create getters or make these public
663663
# unless there is a good reason to

0 commit comments

Comments
 (0)