Skip to content

Commit fd5c7d4

Browse files
clumsy and azzhipa authored
feat: add ulimits support to aws_batch (#1126) (#1127)
Co-authored-by: Alexander Zhipa <[email protected]>
1 parent 195419b commit fd5c7d4

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,37 @@
9999
TAG_TORCHX_USER = "torchx.pytorch.org/user"
100100

101101

102+
def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
    """
    Convert ulimit specs into a list of ``name``/``softLimit``/``hardLimit`` dicts.

    Each item of ``ulimits_list`` is a single spec in the format
    ``name:softLimit:hardLimit`` (for example ``nofile:65536:65536``).
    Whitespace around the spec and around each field is ignored, and blank
    items are skipped. Limits are parsed as integers; ``-1`` is passed
    through unchanged.

    NOTE(review): splitting a comma-separated option value into this list is
    presumably done by the caller's option parser -- this function does not
    split on commas itself.

    Args:
        ulimits_list: ulimit specs, one ``name:softLimit:hardLimit`` per item.

    Returns:
        One dict per non-blank spec, in input order, with keys ``name`` (str),
        ``softLimit`` (int) and ``hardLimit`` (int). Empty for empty input.

    Raises:
        ValueError: if a spec does not have exactly three ``:``-separated
            fields, has an empty name, or has a non-integer limit.
    """
    if not ulimits_list:
        return []

    ulimits: List[Dict[str, Any]] = []
    for ulimit_str in ulimits_list:
        spec = ulimit_str.strip()
        if not spec:
            continue

        parts = [part.strip() for part in spec.split(":")]
        if len(parts) != 3 or not parts[0]:
            raise ValueError(
                f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
            )

        name, soft_limit, hard_limit = parts
        try:
            # int() already maps "-1" to -1, so no special-casing is needed.
            soft, hard = int(soft_limit), int(hard_limit)
        except ValueError as e:
            # Re-raise with the expected format so the user sees which spec
            # was wrong rather than a bare int() conversion error.
            raise ValueError(
                "ulimit softLimit and hardLimit must be integers in format "
                f"name:softLimit:hardLimit, got: {ulimit_str}"
            ) from e

        ulimits.append({"name": name, "softLimit": soft, "hardLimit": hard})

    return ulimits
131+
132+
102133
if TYPE_CHECKING:
103134
from docker import DockerClient
104135

@@ -177,7 +208,8 @@ def _role_to_node_properties(
177208
privileged: bool = False,
178209
job_role_arn: Optional[str] = None,
179210
execution_role_arn: Optional[str] = None,
180-
) -> Dict[str, object]:
211+
ulimits: Optional[List[Dict[str, Any]]] = None,
212+
) -> Dict[str, Any]:
181213
role.mounts += get_device_mounts(role.resource.devices)
182214

183215
mount_points = []
@@ -239,6 +271,7 @@ def _role_to_node_properties(
239271
"environment": [{"name": k, "value": v} for k, v in role.env.items()],
240272
"privileged": privileged,
241273
"resourceRequirements": resource_requirements_from_resource(role.resource),
274+
**({"ulimits": ulimits} if ulimits else {}),
242275
"linuxParameters": {
243276
# To support PyTorch dataloaders we need to set /dev/shm to larger
244277
# than the 64M default.
@@ -361,6 +394,7 @@ class AWSBatchOpts(TypedDict, total=False):
361394
priority: int
362395
job_role_arn: Optional[str]
363396
execution_role_arn: Optional[str]
397+
ulimits: Optional[list[str]]
364398

365399

366400
class AWSBatchScheduler(
@@ -514,6 +548,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
514548
privileged=cfg["privileged"],
515549
job_role_arn=cfg.get("job_role_arn"),
516550
execution_role_arn=cfg.get("execution_role_arn"),
551+
ulimits=parse_ulimits(cfg.get("ulimits") or []),
517552
)
518553
)
519554
node_idx += role.num_replicas
@@ -599,6 +634,11 @@ def _run_opts(self) -> runopts:
599634
type_=str,
600635
help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
601636
)
637+
opts.add(
638+
"ulimits",
639+
type_=List[str],
640+
help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
641+
)
602642
return opts
603643

604644
def _get_job_id(self, app_id: str) -> Optional[str]:

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AWSBatchScheduler,
2424
create_scheduler,
2525
ENV_TORCHX_ROLE_NAME,
26+
parse_ulimits,
2627
resource_from_resource_requirements,
2728
resource_requirements_from_resource,
2829
to_millis_since_epoch,
@@ -311,7 +312,6 @@ def test_volume_mounts(self) -> None:
311312
)
312313
props = _role_to_node_properties(role, 0)
313314
self.assertEqual(
314-
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
315315
props["container"]["volumes"],
316316
[
317317
{
@@ -350,7 +350,6 @@ def test_device_mounts(self) -> None:
350350
)
351351
props = _role_to_node_properties(role, 0)
352352
self.assertEqual(
353-
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
354353
props["container"]["linuxParameters"]["devices"],
355354
[
356355
{
@@ -375,7 +374,6 @@ def test_resource_devices(self) -> None:
375374
)
376375
props = _role_to_node_properties(role, 0)
377376
self.assertEqual(
378-
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
379377
props["container"]["linuxParameters"]["devices"],
380378
[
381379
{
@@ -396,6 +394,46 @@ def test_resource_devices(self) -> None:
396394
],
397395
)
398396

397+
def test_role_to_node_properties_ulimits(self) -> None:
    """Ulimits are forwarded to the container properties only when given."""
    role = specs.Role(
        name="test",
        image="test:latest",
        entrypoint="test",
        args=["test"],
        resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
    )
    ulimits = [
        {"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
        {"name": "memlock", "softLimit": -1, "hardLimit": -1},
    ]

    # Supplied ulimits are passed through unchanged.
    props = _role_to_node_properties(role, 0, ulimits=ulimits)
    self.assertEqual(
        props["container"]["ulimits"],
        ulimits,
    )

    # Without ulimits the key is left out of the container properties
    # entirely rather than being set to an empty value.
    props_without = _role_to_node_properties(role, 0)
    self.assertNotIn("ulimits", props_without["container"])
414+
415+
def test_parse_ulimits(self) -> None:
    """Exercise ``parse_ulimits`` on valid, empty, and malformed input."""
    nofile = {"name": "nofile", "softLimit": 65536, "hardLimit": 65536}
    memlock = {"name": "memlock", "softLimit": -1, "hardLimit": -1}

    # (label, input list, expected parsed output)
    valid_cases = [
        ("single ulimit", ["nofile:65536:65536"], [nofile]),
        (
            "multiple ulimits",
            ["nofile:65536:65536", "memlock:-1:-1"],
            [nofile, memlock],
        ),
        ("empty list", [], []),
        ("blank entries are skipped", ["", "   "], []),
    ]
    for label, raw, expected in valid_cases:
        # subTest keeps one failing case from hiding the others.
        with self.subTest(label):
            self.assertEqual(parse_ulimits(raw), expected)

    # Wrong field count or non-numeric limits must be rejected.
    invalid_specs = ["invalid", "nofile:1", "nofile:1:2:3", "nofile:abc:1"]
    for spec in invalid_specs:
        with self.subTest(invalid=spec):
            with self.assertRaises(ValueError):
                parse_ulimits([spec])
436+
399437
def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
400438
scheduler = AWSBatchScheduler(
401439
"test",

0 commit comments

Comments
 (0)