Skip to content

Commit f396d68

Browse files
committed
fix: test
1 parent 874ede4 commit f396d68

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(
5454
shutdown_after_job_finishes: Optional[bool] = None,
5555
ttl_seconds_after_finished: int = 0,
5656
active_deadline_seconds: Optional[int] = None,
57+
entrypoint_num_cpus: Optional[int] = None,
58+
entrypoint_num_gpus: Optional[int] = None,
5759
):
5860
"""
5961
Initialize a RayJob instance.
@@ -100,6 +102,8 @@ def __init__(
100102
self.runtime_env = runtime_env
101103
self.ttl_seconds_after_finished = ttl_seconds_after_finished
102104
self.active_deadline_seconds = active_deadline_seconds
105+
self.entrypoint_num_cpus = entrypoint_num_cpus
106+
self.entrypoint_num_gpus = entrypoint_num_gpus
103107

104108
# Auto-set shutdown_after_job_finishes based on cluster_config presence
105109
# If cluster_config is provided, we want to clean up the cluster after job finishes
@@ -189,6 +193,16 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
189193
if self.active_deadline_seconds:
190194
rayjob_cr["spec"]["activeDeadlineSeconds"] = self.active_deadline_seconds
191195

196+
# Add entrypoint resource requirements if specified
197+
entrypoint_resources = {}
198+
if self.entrypoint_num_cpus is not None:
199+
entrypoint_resources["cpu"] = str(self.entrypoint_num_cpus)
200+
if self.entrypoint_num_gpus is not None:
201+
entrypoint_resources["gpu"] = str(self.entrypoint_num_gpus)
202+
203+
if entrypoint_resources:
204+
rayjob_cr["spec"]["entrypointResources"] = entrypoint_resources
205+
192206
# Add runtime environment if specified
193207
if self.runtime_env:
194208
rayjob_cr["spec"]["runtimeEnvYAML"] = str(self.runtime_env)

tests/e2e/rayjob_existing_cluster_kind_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def assert_rayjob_submit_against_existing_cluster(
9696
"env_vars": get_setup_env_variables(ACCELERATOR=accelerator),
9797
},
9898
shutdown_after_job_finishes=False,
99-
entrypoint_num_gpus=number_of_gpus,
99+
entrypoint_num_gpus=number_of_gpus if number_of_gpus > 0 else None,
100100
)
101101

102102
# Submit the job

0 commit comments

Comments
 (0)