Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7a3d5f5

Browse files
committedJul 9, 2024·
simplify function calls and add option for custom resources
Signed-off-by: Kevin <[email protected]>
1 parent 2a85469 commit 7a3d5f5

19 files changed

+289
-205
lines changed
 

‎src/codeflare_sdk/cluster/cluster.py

Lines changed: 42 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ..utils import pretty_print
3030
from ..utils.generate_yaml import (
3131
generate_appwrapper,
32+
head_worker_gpu_count_from_cluster,
3233
)
3334
from ..utils.kube_api_helpers import _kube_api_error_handling
3435
from ..utils.generate_yaml import is_openshift_cluster
@@ -118,48 +119,7 @@ def create_app_wrapper(self):
118119
f"Namespace {self.config.namespace} is of type {type(self.config.namespace)}. Check your Kubernetes Authentication."
119120
)
120121

121-
# Before attempting to create the cluster AW, let's evaluate the ClusterConfig
122-
123-
name = self.config.name
124-
namespace = self.config.namespace
125-
head_cpus = self.config.head_cpus
126-
head_memory = self.config.head_memory
127-
num_head_gpus = self.config.num_head_gpus
128-
worker_cpu_requests = self.config.worker_cpu_requests
129-
worker_cpu_limits = self.config.worker_cpu_limits
130-
worker_memory_requests = self.config.worker_memory_requests
131-
worker_memory_limits = self.config.worker_memory_limits
132-
num_worker_gpus = self.config.num_worker_gpus
133-
workers = self.config.num_workers
134-
template = self.config.template
135-
image = self.config.image
136-
appwrapper = self.config.appwrapper
137-
env = self.config.envs
138-
image_pull_secrets = self.config.image_pull_secrets
139-
write_to_file = self.config.write_to_file
140-
local_queue = self.config.local_queue
141-
labels = self.config.labels
142-
return generate_appwrapper(
143-
name=name,
144-
namespace=namespace,
145-
head_cpus=head_cpus,
146-
head_memory=head_memory,
147-
num_head_gpus=num_head_gpus,
148-
worker_cpu_requests=worker_cpu_requests,
149-
worker_cpu_limits=worker_cpu_limits,
150-
worker_memory_requests=worker_memory_requests,
151-
worker_memory_limits=worker_memory_limits,
152-
num_worker_gpus=num_worker_gpus,
153-
workers=workers,
154-
template=template,
155-
image=image,
156-
appwrapper=appwrapper,
157-
env=env,
158-
image_pull_secrets=image_pull_secrets,
159-
write_to_file=write_to_file,
160-
local_queue=local_queue,
161-
labels=labels,
162-
)
122+
return generate_appwrapper(self)
163123

164124
# creates a new cluster with the provided or default spec
165125
def up(self):
@@ -305,7 +265,7 @@ def status(
305265

306266
if print_to_console:
307267
# overriding the number of gpus with requested
308-
cluster.worker_gpu = self.config.num_worker_gpus
268+
_, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
309269
pretty_print.print_cluster_status(cluster)
310270
elif print_to_console:
311271
if status == CodeFlareClusterStatus.UNKNOWN:
@@ -443,6 +403,29 @@ def job_logs(self, job_id: str) -> str:
443403
"""
444404
return self.job_client.get_job_logs(job_id)
445405

406+
@staticmethod
407+
def _head_worker_extended_resources_from_rc_dict(rc: Dict) -> Tuple[dict, dict]:
408+
head_extended_resources, worker_extended_resources = {}, {}
409+
for resource in rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
410+
"containers"
411+
][0]["resources"]["limits"].keys():
412+
if resource in ["memory", "cpu"]:
413+
continue
414+
worker_extended_resources[resource] = rc["spec"]["workerGroupSpecs"][0][
415+
"template"
416+
]["spec"]["containers"][0]["resources"]["limits"][resource]
417+
418+
for resource in rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
419+
0
420+
]["resources"]["limits"].keys():
421+
if resource in ["memory", "cpu"]:
422+
continue
423+
head_extended_resources[resource] = rc["spec"]["headGroupSpec"]["template"][
424+
"spec"
425+
]["containers"][0]["resources"]["limits"][resource]
426+
427+
return head_extended_resources, worker_extended_resources
428+
446429
def from_k8_cluster_object(
447430
rc,
448431
appwrapper=True,
@@ -456,6 +439,11 @@ def from_k8_cluster_object(
456439
else []
457440
)
458441

442+
(
443+
head_extended_resources,
444+
worker_extended_resources,
445+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
446+
459447
cluster_config = ClusterConfiguration(
460448
name=rc["metadata"]["name"],
461449
namespace=rc["metadata"]["namespace"],
@@ -473,11 +461,8 @@ def from_k8_cluster_object(
473461
worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
474462
"containers"
475463
][0]["resources"]["limits"]["memory"],
476-
num_worker_gpus=int(
477-
rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
478-
"resources"
479-
]["limits"]["nvidia.com/gpu"]
480-
),
464+
worker_extended_resource_requests=worker_extended_resources,
465+
head_extended_resource_requests=head_extended_resources,
481466
image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
482467
0
483468
]["image"],
@@ -858,6 +843,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
858843
protocol = "https"
859844
dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"
860845

846+
(
847+
head_extended_resources,
848+
worker_extended_resources,
849+
) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
850+
861851
return RayCluster(
862852
name=rc["metadata"]["name"],
863853
status=status,
@@ -872,17 +862,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
872862
worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
873863
0
874864
]["resources"]["limits"]["cpu"],
875-
worker_gpu=0, # hard to detect currently how many gpus, can override it with what the user asked for
865+
worker_extended_resources=worker_extended_resources,
876866
namespace=rc["metadata"]["namespace"],
877867
head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
878868
"resources"
879869
]["limits"]["cpu"],
880870
head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
881871
"resources"
882872
]["limits"]["memory"],
883-
head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
884-
"resources"
885-
]["limits"]["nvidia.com/gpu"],
873+
head_extended_resources=head_extended_resources,
886874
dashboard=dashboard_url,
887875
)
888876

@@ -907,12 +895,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
907895
worker_mem_min=cluster.config.worker_memory_requests,
908896
worker_mem_max=cluster.config.worker_memory_limits,
909897
worker_cpu=cluster.config.worker_cpu_requests,
910-
worker_gpu=cluster.config.num_worker_gpus,
898+
worker_extended_resources=cluster.config.worker_extended_resource_requests,
911899
namespace=cluster.config.namespace,
912900
dashboard=cluster.cluster_dashboard_uri(),
913901
head_cpus=cluster.config.head_cpus,
914902
head_mem=cluster.config.head_memory,
915-
head_gpu=cluster.config.num_head_gpus,
903+
head_extended_resources=cluster.config.head_extended_resource_requests,
916904
)
917905
if ray.status == CodeFlareClusterStatus.READY:
918906
ray.status = RayClusterStatus.READY

‎src/codeflare_sdk/cluster/config.py

Lines changed: 98 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,51 @@
2525

2626
dir = pathlib.Path(__file__).parent.parent.resolve()
2727

28+
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
29+
DEFAULT_RESOURCE_MAPPING = {
30+
"nvidia.com/gpu": "GPU",
31+
"intel.com/gpu": "GPU",
32+
"amd.com/gpu": "GPU",
33+
"aws.amazon.com/neuroncore": "neuron_cores",
34+
"google.com/tpu": "TPU",
35+
"habana.ai/gaudi": "HPU",
36+
"huawei.com/Ascend910": "NPU",
37+
"huawei.com/Ascend310": "NPU",
38+
}
39+
2840

2941
@dataclass
3042
class ClusterConfiguration:
3143
"""
3244
This dataclass is used to specify resource requirements and other details, and
3345
is passed in as an argument when creating a Cluster object.
46+
47+
Attributes:
48+
- name: The name of the cluster.
49+
- namespace: The namespace in which the cluster should be created.
50+
- head_info: A list of strings containing information about the head node.
51+
- head_cpus: The number of CPUs to allocate to the head node.
52+
- head_memory: The amount of memory to allocate to the head node.
53+
- head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
54+
- head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
55+
- machine_types: A list of machine types to use for the cluster.
56+
- min_cpus: The minimum number of CPUs to allocate to each worker.
57+
- max_cpus: The maximum number of CPUs to allocate to each worker.
58+
- num_workers: The number of workers to create.
59+
- min_memory: The minimum amount of memory to allocate to each worker.
60+
- max_memory: The maximum amount of memory to allocate to each worker.
61+
- num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
62+
- template: The path to the template file to use for the cluster.
63+
- appwrapper: A boolean indicating whether to use an AppWrapper.
64+
- envs: A dictionary of environment variables to set for the cluster.
65+
- image: The image to use for the cluster.
66+
- image_pull_secrets: A list of image pull secrets to use for the cluster.
67+
- write_to_file: A boolean indicating whether to write the cluster configuration to a file.
68+
- verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
69+
- labels: A dictionary of labels to apply to the cluster.
70+
- worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
71+
- extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
72+
- overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
3473
"""
3574

3675
name: str
@@ -39,7 +78,7 @@ class ClusterConfiguration:
3978
head_cpus: typing.Union[int, str] = 2
4079
head_memory: typing.Union[int, str] = 8
4180
head_gpus: int = None # Deprecating
42-
num_head_gpus: int = 0
81+
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
4382
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
4483
worker_cpu_requests: typing.Union[int, str] = 1
4584
worker_cpu_limits: typing.Union[int, str] = 1
@@ -50,7 +89,6 @@ class ClusterConfiguration:
5089
worker_memory_limits: typing.Union[int, str] = 2
5190
min_memory: typing.Union[int, str] = None # Deprecating
5291
max_memory: typing.Union[int, str] = None # Deprecating
53-
num_worker_gpus: int = 0
5492
num_gpus: int = None # Deprecating
5593
template: str = f"{dir}/templates/base-template.yaml"
5694
appwrapper: bool = False
@@ -60,6 +98,11 @@ class ClusterConfiguration:
6098
write_to_file: bool = False
6199
verify_tls: bool = True
62100
labels: dict = field(default_factory=dict)
101+
worker_extended_resource_requests: typing.Dict[str, int] = field(
102+
default_factory=dict
103+
)
104+
extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
105+
overwrite_default_resource_mapping: bool = False
63106

64107
def __post_init__(self):
65108
if not self.verify_tls:
@@ -70,8 +113,60 @@ def __post_init__(self):
70113
self._memory_to_string()
71114
self._str_mem_no_unit_add_GB()
72115
self._memory_to_resource()
73-
self._gpu_to_resource()
74116
self._cpu_to_resource()
117+
self._gpu_to_resource()
118+
self._combine_extended_resource_mapping()
119+
self._validate_extended_resource_requests(self.head_extended_resource_requests)
120+
self._validate_extended_resource_requests(
121+
self.worker_extended_resource_requests
122+
)
123+
124+
def _combine_extended_resource_mapping(self):
125+
if overwritten := set(self.extended_resource_mapping.keys()).intersection(
126+
DEFAULT_RESOURCE_MAPPING.keys()
127+
):
128+
if self.overwrite_default_resource_mapping:
129+
warnings.warn(
130+
f"Overwriting default resource mapping for {overwritten}",
131+
UserWarning,
132+
)
133+
else:
134+
raise ValueError(
135+
f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
136+
)
137+
self.extended_resource_mapping = {
138+
**DEFAULT_RESOURCE_MAPPING,
139+
**self.extended_resource_mapping,
140+
}
141+
142+
def _validate_extended_resource_requests(
143+
self, extended_resources: typing.Dict[str, int]
144+
):
145+
for k in extended_resources.keys():
146+
if k not in self.extended_resource_mapping.keys():
147+
raise ValueError(
148+
f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
149+
)
150+
151+
def _gpu_to_resource(self):
152+
if self.head_gpus:
153+
warnings.warn(
154+
f"head_gpus is being deprecated, replacing with head_extended_resource_requests['nvidia.com/gpu'] = {self.head_gpus}"
155+
)
156+
if "nvidia.com/gpu" in self.head_extended_resource_requests:
157+
raise ValueError(
158+
"nvidia.com/gpu already exists in head_extended_resource_requests"
159+
)
160+
self.head_extended_resource_requests["nvidia.com/gpu"] = self.head_gpus
161+
if self.num_gpus:
162+
warnings.warn(
163+
f"num_gpus is being deprecated, replacing with worker_extended_resource_requests['nvidia.com/gpu'] = {self.num_gpus}"
164+
)
165+
if "nvidia.com/gpu" in self.worker_extended_resource_requests:
166+
raise ValueError(
167+
"nvidia.com/gpu already exists in worker_extended_resource_requests"
168+
)
169+
self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_gpus
75170

76171
def _str_mem_no_unit_add_GB(self):
77172
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
@@ -95,14 +190,6 @@ def _memory_to_string(self):
95190
if isinstance(self.worker_memory_limits, int):
96191
self.worker_memory_limits = f"{self.worker_memory_limits}G"
97192

98-
def _gpu_to_resource(self):
99-
if self.head_gpus:
100-
warnings.warn("head_gpus is being deprecated, use num_head_gpus")
101-
self.num_head_gpus = self.head_gpus
102-
if self.num_gpus:
103-
warnings.warn("num_gpus is being deprecated, use num_worker_gpus")
104-
self.num_worker_gpus = self.num_gpus
105-
106193
def _cpu_to_resource(self):
107194
if self.min_cpus:
108195
warnings.warn("min_cpus is being deprecated, use worker_cpu_requests")

‎src/codeflare_sdk/cluster/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
dataclasses to store information for Ray clusters and AppWrappers.
1919
"""
2020

21-
from dataclasses import dataclass
21+
from dataclasses import dataclass, field
2222
from enum import Enum
23+
import typing
2324

2425

2526
class RayClusterStatus(Enum):
@@ -74,14 +75,14 @@ class RayCluster:
7475
status: RayClusterStatus
7576
head_cpus: int
7677
head_mem: str
77-
head_gpu: int
7878
workers: int
7979
worker_mem_min: str
8080
worker_mem_max: str
8181
worker_cpu: int
82-
worker_gpu: int
8382
namespace: str
8483
dashboard: str
84+
worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
85+
head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
8586

8687

8788
@dataclass

‎src/codeflare_sdk/templates/base-template.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ spec:
8686
limits:
8787
cpu: 2
8888
memory: "8G"
89-
nvidia.com/gpu: 0
9089
requests:
9190
cpu: 2
9291
memory: "8G"
93-
nvidia.com/gpu: 0
9492
volumeMounts:
9593
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
9694
name: odh-trusted-ca-cert
@@ -163,11 +161,9 @@ spec:
163161
limits:
164162
cpu: "2"
165163
memory: "12G"
166-
nvidia.com/gpu: "1"
167164
requests:
168165
cpu: "2"
169166
memory: "12G"
170-
nvidia.com/gpu: "1"
171167
volumeMounts:
172168
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
173169
name: odh-trusted-ca-cert

‎src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 99 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
(in the cluster sub-module) for AppWrapper generation.
1818
"""
1919

20+
import json
2021
from typing import Optional
2122
import typing
2223
import yaml
@@ -31,6 +32,7 @@
3132
from base64 import b64encode
3233
from urllib3.util import parse_url
3334
from kubernetes.client.exceptions import ApiException
35+
import codeflare_sdk
3436

3537

3638
def read_template(template):
@@ -78,10 +80,13 @@ def is_kind_cluster():
7880
return False
7981

8082

81-
def update_names(cluster_yaml, cluster_name, namespace):
82-
meta = cluster_yaml.get("metadata")
83-
meta["name"] = cluster_name
84-
meta["namespace"] = namespace
83+
def update_names(
84+
cluster_yaml: dict,
85+
cluster: "codeflare_sdk.cluster.Cluster",
86+
):
87+
metadata = cluster_yaml.get("metadata")
88+
metadata["name"] = cluster.config.name
89+
metadata["namespace"] = cluster.config.namespace
8590

8691

8792
def update_image(spec, image):
@@ -114,67 +119,111 @@ def update_resources(
114119
worker_cpu_limits,
115120
worker_memory_requests,
116121
worker_memory_limits,
117-
num_worker_gpus,
122+
custom_resources,
118123
):
119124
container = spec.get("containers")
120125
for resource in container:
121126
requests = resource.get("resources").get("requests")
122127
if requests is not None:
123128
requests["cpu"] = worker_cpu_requests
124129
requests["memory"] = worker_memory_requests
125-
requests["nvidia.com/gpu"] = num_worker_gpus
126130
limits = resource.get("resources").get("limits")
127131
if limits is not None:
128132
limits["cpu"] = worker_cpu_limits
129133
limits["memory"] = worker_memory_limits
130-
limits["nvidia.com/gpu"] = num_worker_gpus
134+
for k in custom_resources.keys():
135+
limits[k] = custom_resources[k]
136+
requests[k] = custom_resources[k]
137+
138+
139+
def head_worker_gpu_count_from_cluster(
140+
cluster: "codeflare_sdk.cluster.Cluster",
141+
) -> typing.Tuple[int, int]:
142+
head_gpus = 0
143+
worker_gpus = 0
144+
for k in cluster.config.head_extended_resource_requests.keys():
145+
resource_type = cluster.config.extended_resource_mapping[k]
146+
if resource_type == "GPU":
147+
head_gpus += int(cluster.config.head_extended_resource_requests[k])
148+
for k in cluster.config.worker_extended_resource_requests.keys():
149+
resource_type = cluster.config.extended_resource_mapping[k]
150+
if resource_type == "GPU":
151+
worker_gpus += int(cluster.config.worker_extended_resource_requests[k])
152+
153+
return head_gpus, worker_gpus
154+
155+
156+
FORBIDDEN_CUSTOM_RESOURCE_TYPES = ["GPU", "CPU", "memory"]
157+
158+
159+
def head_worker_resources_from_cluster(
160+
cluster: "codeflare_sdk.cluster.Cluster",
161+
) -> typing.Tuple[dict, dict]:
162+
to_return = {}, {}
163+
for k in cluster.config.head_extended_resource_requests.keys():
164+
resource_type = cluster.config.extended_resource_mapping[k]
165+
if resource_type in FORBIDDEN_CUSTOM_RESOURCE_TYPES:
166+
continue
167+
to_return[0][resource_type] = cluster.config.head_extended_resource_requests[
168+
k
169+
] + to_return[0].get(resource_type, 0)
170+
171+
for k in cluster.config.worker_extended_resource_requests.keys():
172+
resource_type = cluster.config.extended_resource_mapping[k]
173+
if resource_type in FORBIDDEN_CUSTOM_RESOURCE_TYPES:
174+
continue
175+
to_return[1][resource_type] = cluster.config.worker_extended_resource_requests[
176+
k
177+
] + to_return[1].get(resource_type, 0)
178+
return to_return
131179

132180

133181
def update_nodes(
134-
cluster_yaml,
135-
appwrapper_name,
136-
worker_cpu_requests,
137-
worker_cpu_limits,
138-
worker_memory_requests,
139-
worker_memory_limits,
140-
num_worker_gpus,
141-
workers,
142-
image,
143-
env,
144-
image_pull_secrets,
145-
head_cpus,
146-
head_memory,
147-
num_head_gpus,
182+
ray_cluster_dict: dict,
183+
cluster: "codeflare_sdk.cluster.Cluster",
148184
):
149-
head = cluster_yaml.get("spec").get("headGroupSpec")
150-
head["rayStartParams"]["num-gpus"] = str(int(num_head_gpus))
185+
head = ray_cluster_dict.get("spec").get("headGroupSpec")
186+
worker = ray_cluster_dict.get("spec").get("workerGroupSpecs")[0]
187+
head_gpus, worker_gpus = head_worker_gpu_count_from_cluster(cluster)
188+
head_resources, worker_resources = head_worker_resources_from_cluster(cluster)
189+
head_resources = json.dumps(head_resources).replace('"', '\\"')
190+
head_resources = f'"{head_resources}"'
191+
worker_resources = json.dumps(worker_resources).replace('"', '\\"')
192+
worker_resources = f'"{worker_resources}"'
193+
head["rayStartParams"]["num-gpus"] = str(head_gpus)
194+
head["rayStartParams"]["resources"] = head_resources
151195

152-
worker = cluster_yaml.get("spec").get("workerGroupSpecs")[0]
153196
# Head counts as first worker
154-
worker["replicas"] = workers
155-
worker["minReplicas"] = workers
156-
worker["maxReplicas"] = workers
157-
worker["groupName"] = "small-group-" + appwrapper_name
158-
worker["rayStartParams"]["num-gpus"] = str(int(num_worker_gpus))
197+
worker["replicas"] = cluster.config.num_workers
198+
worker["minReplicas"] = cluster.config.num_workers
199+
worker["maxReplicas"] = cluster.config.num_workers
200+
worker["groupName"] = "small-group-" + cluster.config.name
201+
worker["rayStartParams"]["num-gpus"] = str(worker_gpus)
202+
worker["rayStartParams"]["resources"] = worker_resources
159203

160204
for comp in [head, worker]:
161205
spec = comp.get("template").get("spec")
162-
update_image_pull_secrets(spec, image_pull_secrets)
163-
update_image(spec, image)
164-
update_env(spec, env)
206+
update_image_pull_secrets(spec, cluster.config.image_pull_secrets)
207+
update_image(spec, cluster.config.image)
208+
update_env(spec, cluster.config.envs)
165209
if comp == head:
166210
# TODO: Eventually add head node configuration outside of template
167211
update_resources(
168-
spec, head_cpus, head_cpus, head_memory, head_memory, num_head_gpus
212+
spec,
213+
cluster.config.head_cpus,
214+
cluster.config.head_cpus,
215+
cluster.config.head_memory,
216+
cluster.config.head_memory,
217+
cluster.config.head_extended_resource_requests,
169218
)
170219
else:
171220
update_resources(
172221
spec,
173-
worker_cpu_requests,
174-
worker_cpu_limits,
175-
worker_memory_requests,
176-
worker_memory_limits,
177-
num_worker_gpus,
222+
cluster.config.worker_cpu_requests,
223+
cluster.config.worker_cpu_limits,
224+
cluster.config.worker_memory_requests,
225+
cluster.config.worker_memory_limits,
226+
cluster.config.worker_extended_resource_requests,
178227
)
179228

180229

@@ -278,63 +327,30 @@ def write_user_yaml(user_yaml, output_file_name):
278327
print(f"Written to: {output_file_name}")
279328

280329

281-
def generate_appwrapper(
282-
name: str,
283-
namespace: str,
284-
head_cpus: int,
285-
head_memory: int,
286-
num_head_gpus: int,
287-
worker_cpu_requests: int,
288-
worker_cpu_limits: int,
289-
worker_memory_requests: int,
290-
worker_memory_limits: int,
291-
num_worker_gpus: int,
292-
workers: int,
293-
template: str,
294-
image: str,
295-
appwrapper: bool,
296-
env,
297-
image_pull_secrets: list,
298-
write_to_file: bool,
299-
local_queue: Optional[str],
300-
labels,
301-
):
302-
cluster_yaml = read_template(template)
303-
appwrapper_name, cluster_name = gen_names(name)
304-
update_names(cluster_yaml, cluster_name, namespace)
305-
update_nodes(
330+
def generate_appwrapper(cluster: "codeflare_sdk.cluster.Cluster"):
331+
cluster_yaml = read_template(cluster.config.template)
332+
appwrapper_name, _ = gen_names(cluster.config.name)
333+
update_names(
306334
cluster_yaml,
307-
appwrapper_name,
308-
worker_cpu_requests,
309-
worker_cpu_limits,
310-
worker_memory_requests,
311-
worker_memory_limits,
312-
num_worker_gpus,
313-
workers,
314-
image,
315-
env,
316-
image_pull_secrets,
317-
head_cpus,
318-
head_memory,
319-
num_head_gpus,
335+
cluster,
320336
)
321-
augment_labels(cluster_yaml, labels)
337+
update_nodes(cluster_yaml, cluster)
338+
augment_labels(cluster_yaml, cluster.config.labels)
322339
notebook_annotations(cluster_yaml)
323-
324340
user_yaml = (
325-
wrap_cluster(cluster_yaml, appwrapper_name, namespace)
326-
if appwrapper
341+
wrap_cluster(cluster_yaml, appwrapper_name, cluster.config.namespace)
342+
if cluster.config.appwrapper
327343
else cluster_yaml
328344
)
329345

330-
add_queue_label(user_yaml, namespace, local_queue)
346+
add_queue_label(user_yaml, cluster.config.namespace, cluster.config.local_queue)
331347

332-
if write_to_file:
348+
if cluster.config.write_to_file:
333349
directory_path = os.path.expanduser("~/.codeflare/resources/")
334350
outfile = os.path.join(directory_path, appwrapper_name + ".yaml")
335351
write_user_yaml(user_yaml, outfile)
336352
return outfile
337353
else:
338354
user_yaml = yaml.dump(user_yaml)
339-
print(f"Yaml resources loaded for {name}")
355+
print(f"Yaml resources loaded for {cluster.config.name}")
340356
return user_yaml

‎src/codeflare_sdk/utils/pretty_print.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def print_clusters(clusters: List[RayCluster]):
138138
workers = str(cluster.workers)
139139
memory = f"{cluster.worker_mem_min}~{cluster.worker_mem_max}"
140140
cpu = str(cluster.worker_cpu)
141-
gpu = str(cluster.worker_gpu)
141+
gpu = str(cluster.worker_extended_resources.get("nvidia.com/gpu", 0))
142142

143143
#'table0' to display the cluster name, status, url, and dashboard link
144144
table0 = Table(box=None, show_header=False)

‎tests/e2e/local_interactive_sdk_kind_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run_local_interactives(self):
4343
worker_cpu_limits=1,
4444
worker_memory_requests=1,
4545
worker_memory_limits=2,
46-
num_worker_gpus=0,
4746
image=ray_image,
4847
write_to_file=True,
4948
verify_tls=False,

‎tests/e2e/local_interactive_sdk_oauth_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def run_local_interactives(self):
4848
worker_cpu_limits=1,
4949
worker_memory_requests=4,
5050
worker_memory_limits=4,
51-
num_worker_gpus=0,
5251
image=ray_image,
5352
verify_tls=False,
5453
)

‎tests/e2e/mnist_raycluster_sdk_kind_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def run_mnist_raycluster_sdk_kind(self):
4141
worker_cpu_limits=1,
4242
worker_memory_requests=1,
4343
worker_memory_limits=2,
44-
num_worker_gpus=0,
4544
image=ray_image,
4645
write_to_file=True,
4746
verify_tls=False,

‎tests/e2e/mnist_raycluster_sdk_oauth_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def run_mnist_raycluster_sdk_oauth(self):
4848
worker_cpu_limits=1,
4949
worker_memory_requests=1,
5050
worker_memory_limits=2,
51-
num_worker_gpus=0,
5251
image=ray_image,
5352
write_to_file=True,
5453
verify_tls=False,

‎tests/e2e/start_ray_cluster.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
worker_cpu_limits=1,
2020
worker_memory_requests=1,
2121
worker_memory_limits=2,
22-
num_worker_gpus=0,
2322
image=ray_image,
2423
appwrapper=True,
2524
)

‎tests/test-case-bad.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ spec:
3333
block: 'true'
3434
dashboard-host: 0.0.0.0
3535
num-gpus: '0'
36+
resources: '"{}"'
3637
serviceType: ClusterIP
3738
template:
3839
spec:
@@ -63,11 +64,9 @@ spec:
6364
limits:
6465
cpu: 2
6566
memory: 8G
66-
nvidia.com/gpu: 0
6767
requests:
6868
cpu: 2
6969
memory: 8G
70-
nvidia.com/gpu: 0
7170
rayVersion: 2.23.0
7271
workerGroupSpecs:
7372
- groupName: small-group-unit-test-cluster
@@ -76,6 +75,7 @@ spec:
7675
rayStartParams:
7776
block: 'true'
7877
num-gpus: '7'
78+
resources: '"{}"'
7979
replicas: 2
8080
template:
8181
metadata:

‎tests/test-case-no-kueue-no-aw.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ spec:
2626
block: 'true'
2727
dashboard-host: 0.0.0.0
2828
num-gpus: '0'
29+
resources: '"{}"'
2930
serviceType: ClusterIP
3031
template:
3132
spec:
@@ -51,11 +52,9 @@ spec:
5152
limits:
5253
cpu: 2
5354
memory: 8G
54-
nvidia.com/gpu: 0
5555
requests:
5656
cpu: 2
5757
memory: 8G
58-
nvidia.com/gpu: 0
5958
volumeMounts:
6059
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
6160
name: odh-trusted-ca-cert
@@ -94,6 +93,7 @@ spec:
9493
rayStartParams:
9594
block: 'true'
9695
num-gpus: '7'
96+
resources: '"{}"'
9797
replicas: 2
9898
template:
9999
metadata:

‎tests/test-case-no-mcad.yamls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ spec:
2929
block: 'true'
3030
dashboard-host: 0.0.0.0
3131
num-gpus: '0'
32+
resources: '"{}"'
3233
serviceType: ClusterIP
3334
template:
3435
spec:
@@ -54,11 +55,9 @@ spec:
5455
limits:
5556
cpu: 2
5657
memory: 8G
57-
nvidia.com/gpu: 0
5858
requests:
5959
cpu: 2
6060
memory: 8G
61-
nvidia.com/gpu: 0
6261
volumeMounts:
6362
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
6463
name: odh-trusted-ca-cert
@@ -97,6 +96,7 @@ spec:
9796
rayStartParams:
9897
block: 'true'
9998
num-gpus: '7'
99+
resources: '"{}"'
100100
replicas: 2
101101
template:
102102
metadata:

‎tests/test-case.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ spec:
3434
block: 'true'
3535
dashboard-host: 0.0.0.0
3636
num-gpus: '0'
37+
resources: '"{}"'
3738
serviceType: ClusterIP
3839
template:
3940
spec:
@@ -59,11 +60,9 @@ spec:
5960
limits:
6061
cpu: 2
6162
memory: 8G
62-
nvidia.com/gpu: 0
6363
requests:
6464
cpu: 2
6565
memory: 8G
66-
nvidia.com/gpu: 0
6766
volumeMounts:
6867
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
6968
name: odh-trusted-ca-cert
@@ -102,6 +101,7 @@ spec:
102101
rayStartParams:
103102
block: 'true'
104103
num-gpus: '7'
104+
resources: '"{}"'
105105
replicas: 2
106106
template:
107107
metadata:

‎tests/test-default-appwrapper.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ spec:
3434
block: 'true'
3535
dashboard-host: 0.0.0.0
3636
num-gpus: '0'
37+
resources: '"{}"'
3738
serviceType: ClusterIP
3839
template:
3940
spec:
41+
imagePullSecrets: []
4042
containers:
4143
- image: quay.io/rhoai/ray:2.23.0-py39-cu121
4244
imagePullPolicy: Always
@@ -59,11 +61,9 @@ spec:
5961
limits:
6062
cpu: 2
6163
memory: 8G
62-
nvidia.com/gpu: 0
6364
requests:
6465
cpu: 2
6566
memory: 8G
66-
nvidia.com/gpu: 0
6767
volumeMounts:
6868
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
6969
name: odh-trusted-ca-cert
@@ -77,7 +77,6 @@ spec:
7777
- mountPath: /etc/ssl/certs/odh-ca-bundle.crt
7878
name: odh-ca-cert
7979
subPath: odh-ca-bundle.crt
80-
imagePullSecrets: []
8180
volumes:
8281
- configMap:
8382
items:
@@ -101,6 +100,7 @@ spec:
101100
rayStartParams:
102101
block: 'true'
103102
num-gpus: '0'
103+
resources: '"{}"'
104104
replicas: 1
105105
template:
106106
metadata:
@@ -109,6 +109,7 @@ spec:
109109
labels:
110110
key: value
111111
spec:
112+
imagePullSecrets: []
112113
containers:
113114
- image: quay.io/rhoai/ray:2.23.0-py39-cu121
114115
lifecycle:
@@ -123,11 +124,9 @@ spec:
123124
limits:
124125
cpu: 1
125126
memory: 2G
126-
nvidia.com/gpu: 0
127127
requests:
128128
cpu: 1
129129
memory: 2G
130-
nvidia.com/gpu: 0
131130
volumeMounts:
132131
- mountPath: /etc/pki/tls/certs/odh-trusted-ca-bundle.crt
133132
name: odh-trusted-ca-cert
@@ -141,7 +140,6 @@ spec:
141140
- mountPath: /etc/ssl/certs/odh-ca-bundle.crt
142141
name: odh-ca-cert
143142
subPath: odh-ca-bundle.crt
144-
imagePullSecrets: []
145143
volumes:
146144
- configMap:
147145
items:

‎tests/unit_test.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def test_config_creation():
260260
assert config.num_workers == 2
261261
assert config.worker_cpu_requests == 3 and config.worker_cpu_limits == 4
262262
assert config.worker_memory_requests == "5G" and config.worker_memory_limits == "6G"
263-
assert config.num_worker_gpus == 7
263+
assert config.worker_extended_resource_requests == {"nvidia.com/gpu": 7}
264264
assert config.image == "quay.io/rhoai/ray:2.23.0-py39-cu121"
265265
assert config.template == f"{parent}/src/codeflare_sdk/templates/base-template.yaml"
266266
assert config.machine_types == ["cpu.small", "gpu.large"]
@@ -406,7 +406,7 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
406406
worker_cpu_limits=4,
407407
worker_memory_requests=5,
408408
worker_memory_limits=6,
409-
num_worker_gpus=7,
409+
worker_extended_resource_requests={"nvidia.com/gpu": 7},
410410
machine_types=["cpu.small", "gpu.large"],
411411
image_pull_secrets=["unit-test-pull-secret"],
412412
image="quay.io/rhoai/ray:2.23.0-py39-cu121",
@@ -883,12 +883,10 @@ def test_ray_details(mocker, capsys):
883883
worker_mem_min="2G",
884884
worker_mem_max="2G",
885885
worker_cpu=1,
886-
worker_gpu=0,
887886
namespace="ns",
888887
dashboard="fake-uri",
889888
head_cpus=2,
890889
head_mem=8,
891-
head_gpu=0,
892890
)
893891
mocker.patch(
894892
"codeflare_sdk.cluster.cluster.Cluster.status",
@@ -922,7 +920,7 @@ def test_ray_details(mocker, capsys):
922920
assert ray1.worker_mem_min == ray2.worker_mem_min
923921
assert ray1.worker_mem_max == ray2.worker_mem_max
924922
assert ray1.worker_cpu == ray2.worker_cpu
925-
assert ray1.worker_gpu == ray2.worker_gpu
923+
assert ray1.worker_extended_resources == ray2.worker_extended_resources
926924
try:
927925
print_clusters([ray1, ray2])
928926
print_cluster_status(ray1)
@@ -1129,12 +1127,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
11291127
"limits": {
11301128
"cpu": 2,
11311129
"memory": "8G",
1132-
"nvidia.com/gpu": 0,
11331130
},
11341131
"requests": {
11351132
"cpu": 2,
11361133
"memory": "8G",
1137-
"nvidia.com/gpu": 0,
11381134
},
11391135
},
11401136
"volumeMounts": [
@@ -1198,7 +1194,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
11981194
"groupName": "small-group-quicktest",
11991195
"maxReplicas": 1,
12001196
"minReplicas": 1,
1201-
"rayStartParams": {"block": "true", "num-gpus": "0"},
1197+
"rayStartParams": {
1198+
"block": "true",
1199+
"num-gpus": "0",
1200+
},
12021201
"replicas": 1,
12031202
"scaleStrategy": {},
12041203
"template": {
@@ -1249,12 +1248,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
12491248
"limits": {
12501249
"cpu": 1,
12511250
"memory": "2G",
1252-
"nvidia.com/gpu": 0,
12531251
},
12541252
"requests": {
12551253
"cpu": 1,
12561254
"memory": "2G",
1257-
"nvidia.com/gpu": 0,
12581255
},
12591256
},
12601257
"volumeMounts": [
@@ -1413,12 +1410,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
14131410
"limits": {
14141411
"cpu": 2,
14151412
"memory": "8G",
1416-
"nvidia.com/gpu": 0,
14171413
},
14181414
"requests": {
14191415
"cpu": 2,
14201416
"memory": "8G",
1421-
"nvidia.com/gpu": 0,
14221417
},
14231418
},
14241419
}
@@ -1432,7 +1427,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
14321427
"groupName": "small-group-quicktest2",
14331428
"maxReplicas": 1,
14341429
"minReplicas": 1,
1435-
"rayStartParams": {"block": "true", "num-gpus": "0"},
1430+
"rayStartParams": {
1431+
"block": "true",
1432+
"num-gpus": "0",
1433+
},
14361434
"replicas": 1,
14371435
"template": {
14381436
"metadata": {
@@ -1469,12 +1467,10 @@ def get_ray_obj(group, version, namespace, plural, cls=None):
14691467
"limits": {
14701468
"cpu": 1,
14711469
"memory": "2G",
1472-
"nvidia.com/gpu": 0,
14731470
},
14741471
"requests": {
14751472
"cpu": 1,
14761473
"memory": "2G",
1477-
"nvidia.com/gpu": 0,
14781474
},
14791475
},
14801476
}
@@ -1591,12 +1587,10 @@ def get_aw_obj(group, version, namespace, plural):
15911587
"limits": {
15921588
"cpu": 2,
15931589
"memory": "8G",
1594-
"nvidia.com/gpu": 0,
15951590
},
15961591
"requests": {
15971592
"cpu": 2,
15981593
"memory": "8G",
1599-
"nvidia.com/gpu": 0,
16001594
},
16011595
},
16021596
}
@@ -1650,12 +1644,10 @@ def get_aw_obj(group, version, namespace, plural):
16501644
"limits": {
16511645
"cpu": 1,
16521646
"memory": "2G",
1653-
"nvidia.com/gpu": 0,
16541647
},
16551648
"requests": {
16561649
"cpu": 1,
16571650
"memory": "2G",
1658-
"nvidia.com/gpu": 0,
16591651
},
16601652
},
16611653
}
@@ -1786,12 +1778,10 @@ def get_aw_obj(group, version, namespace, plural):
17861778
"limits": {
17871779
"cpu": 2,
17881780
"memory": "8G",
1789-
"nvidia.com/gpu": 0,
17901781
},
17911782
"requests": {
17921783
"cpu": 2,
17931784
"memory": "8G",
1794-
"nvidia.com/gpu": 0,
17951785
},
17961786
},
17971787
}
@@ -1845,12 +1835,10 @@ def get_aw_obj(group, version, namespace, plural):
18451835
"limits": {
18461836
"cpu": 1,
18471837
"memory": "2G",
1848-
"nvidia.com/gpu": 0,
18491838
},
18501839
"requests": {
18511840
"cpu": 1,
18521841
"memory": "2G",
1853-
"nvidia.com/gpu": 0,
18541842
},
18551843
},
18561844
}
@@ -2002,7 +1990,7 @@ def custom_side_effect(group, version, namespace, plural, **kwargs):
20021990
cluster_config.worker_memory_requests == "2G"
20031991
and cluster_config.worker_memory_limits == "2G"
20041992
)
2005-
assert cluster_config.num_worker_gpus == 0
1993+
assert cluster_config.worker_extended_resource_requests == {}
20061994
assert (
20071995
cluster_config.image
20081996
== "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
@@ -2044,7 +2032,7 @@ def test_get_cluster(mocker):
20442032
cluster_config.worker_memory_requests == "2G"
20452033
and cluster_config.worker_memory_limits == "2G"
20462034
)
2047-
assert cluster_config.num_worker_gpus == 0
2035+
assert cluster_config.worker_extended_resource_requests == {}
20482036
assert (
20492037
cluster_config.image
20502038
== "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
@@ -2082,7 +2070,7 @@ def test_get_cluster_no_mcad(mocker):
20822070
cluster_config.worker_memory_requests == "2G"
20832071
and cluster_config.worker_memory_limits == "2G"
20842072
)
2085-
assert cluster_config.num_worker_gpus == 0
2073+
assert cluster_config.worker_extended_resource_requests == {}
20862074
assert (
20872075
cluster_config.image
20882076
== "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
@@ -2310,12 +2298,10 @@ def test_cluster_status(mocker):
23102298
worker_mem_min=2,
23112299
worker_mem_max=2,
23122300
worker_cpu=1,
2313-
worker_gpu=0,
23142301
namespace="ns",
23152302
dashboard="fake-uri",
23162303
head_cpus=2,
23172304
head_mem=8,
2318-
head_gpu=0,
23192305
)
23202306
cf = Cluster(
23212307
ClusterConfiguration(
@@ -2806,6 +2792,24 @@ def test_rjc_list_jobs(ray_job_client, mocker):
28062792
assert job_list_jobs == jobs_list
28072793

28082794

2795+
def test_cluster_config_deprecation_conversion(mocker):
2796+
config = ClusterConfiguration(
2797+
name="test",
2798+
num_gpus=2,
2799+
head_gpus=1,
2800+
min_memory=3,
2801+
max_memory=4,
2802+
min_cpus=1,
2803+
max_cpus=2,
2804+
)
2805+
assert config.worker_extended_resource_requests == {"nvidia.com/gpu": 2}
2806+
assert config.head_extended_resource_requests == {"nvidia.com/gpu": 1}
2807+
assert config.worker_memory_requests == "3G"
2808+
assert config.worker_memory_limits == "4G"
2809+
assert config.worker_cpu_requests == 1
2810+
assert config.worker_cpu_limits == 2
2811+
2812+
28092813
# Make sure to always keep this function last
28102814
def test_cleanup():
28112815
os.remove(f"{aw_dir}unit-test-no-kueue.yaml")

‎tests/unit_test_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def createClusterConfig():
1313
worker_cpu_limits=4,
1414
worker_memory_requests=5,
1515
worker_memory_limits=6,
16-
num_worker_gpus=7,
16+
worker_extended_resource_requests={"nvidia.com/gpu": 7},
1717
appwrapper=True,
1818
machine_types=["cpu.small", "gpu.large"],
1919
image_pull_secrets=["unit-test-pull-secret"],

‎tests/upgrade/raycluster_sdk_upgrade_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def run_mnist_raycluster_sdk_oauth(self):
5454
worker_cpu_limits=1,
5555
worker_memory_requests=1,
5656
worker_memory_limits=2,
57-
num_worker_gpus=0,
5857
image=ray_image,
5958
write_to_file=True,
6059
verify_tls=False,

0 commit comments

Comments
 (0)
Please sign in to comment.