Skip to content

Commit 5965fbe

Browse files
authored
Merge branch 'main' into ODSC-76209/GPU-Shape-Recommendation
2 parents 8593b02 + 60b0693 commit 5965fbe

File tree

6 files changed

+368
-84
lines changed

6 files changed

+368
-84
lines changed

ads/aqua/common/entities.py

Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import re
6-
from typing import Any, Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Union
77

88
from oci.data_science.models import Model
99
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -245,55 +245,71 @@ class AquaMultiModelRef(Serializable):
245245
"""
246246
Lightweight model descriptor used for multi-model deployment.
247247
248-
This class only contains essential details
249-
required to fetch complete model metadata and deploy models.
248+
This class holds essential details required to fetch model metadata and deploy
249+
individual models as part of a multi-model deployment group.
250250
251251
Attributes
252252
----------
253253
model_id : str
254-
The unique identifier of the model.
254+
The unique identifier (OCID) of the base model.
255255
model_name : Optional[str]
256-
The name of the model.
256+
Optional name for the model.
257257
gpu_count : Optional[int]
258-
Number of GPUs required for deployment.
258+
Number of GPUs required to allocate for this model during deployment.
259259
model_task : Optional[str]
260-
The task that model operates on. Supported tasks are in MultiModelSupportedTaskType
260+
The machine learning task this model performs (e.g., text-generation, summarization).
261+
Supported values are listed in `MultiModelSupportedTaskType`.
261262
env_var : Optional[Dict[str, Any]]
262-
Optional environment variables to override during deployment.
263+
Optional dictionary of environment variables to inject into the runtime environment
264+
of the model container.
265+
params : Optional[Dict[str, Any]]
266+
Optional dictionary of container-specific inference parameters to override.
267+
These are typically framework-level flags required by the runtime backend.
268+
For example, in vLLM containers, valid params may include:
269+
`--tensor-parallel-size`, `--enforce-eager`, `--max-model-len`, etc.
263270
artifact_location : Optional[str]
264-
Artifact path of model in the multimodel group.
271+
Relative path or URI of the model artifact inside the multi-model group folder.
265272
fine_tune_weights : Optional[List[LoraModuleSpec]]
266-
For fine tuned models, the artifact path of the modified model weights
273+
List of fine-tuned weight artifacts (e.g., LoRA modules) associated with this model.
267274
"""
268275

269276
model_id: str = Field(..., description="The model OCID to deploy.")
270-
model_name: Optional[str] = Field(None, description="The name of model.")
277+
model_name: Optional[str] = Field(None, description="The name of the model.")
271278
gpu_count: Optional[int] = Field(
272-
None, description="The gpu count allocation for the model."
279+
None, description="The number of GPUs allocated for the model."
273280
)
274281
model_task: Optional[str] = Field(
275282
None,
276-
description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType",
283+
description="The task this model performs. See `MultiModelSupportedTaskType` for supported values.",
277284
)
278285
env_var: Optional[dict] = Field(
279-
default_factory=dict, description="The environment variables of the model."
286+
default_factory=dict,
287+
description="Environment variables to override during container startup.",
288+
)
289+
params: Optional[dict] = Field(
290+
default_factory=dict,
291+
description=(
292+
"Framework-specific startup parameters required by the container runtime. "
293+
"For example, vLLM models may use flags like `--tensor-parallel-size`, `--enforce-eager`, etc."
294+
),
280295
)
281296
artifact_location: Optional[str] = Field(
282-
None, description="Artifact path of model in the multimodel group."
297+
None,
298+
description="Path to the model artifact relative to the multi-model base folder.",
283299
)
284300
fine_tune_weights: Optional[List[LoraModuleSpec]] = Field(
285301
None,
286-
description="For fine tuned models, the artifact path of the modified model weights",
302+
description="List of fine-tuned weight modules (e.g., LoRA) associated with this base model.",
287303
)
288304

289305
def all_model_ids(self) -> List[str]:
290306
"""
291-
Returns all associated model OCIDs, including the base model and any fine-tuned models.
307+
Returns all model OCIDs associated with this reference, including fine-tuned weights.
292308
293309
Returns
294310
-------
295311
List[str]
296-
A list of all model OCIDs associated with this multi-model reference.
312+
A list containing the base model OCID and any fine-tuned module OCIDs.
297313
"""
298314
ids = {self.model_id}
299315
if self.fine_tune_weights:
@@ -302,8 +318,80 @@ def all_model_ids(self) -> List[str]:
302318
)
303319
return list(ids)
304320

321+
@model_validator(mode="before")
@classmethod
def extract_params_from_env_var(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """
    Moves CLI-style `PARAMS` from `env_var` into the `params` field.

    This keeps backward compatibility with callers that pass container
    parameters through an environment variable, e.g.:
        env_var = { "PARAMS": "--max-model-len 65536 --enable-streaming" }

    The `PARAMS` entry is always removed from `env_var`. Its parsed flags are
    merged into `params`; keys already present in `params` are never overridden.

    Parameters
    ----------
    values : Dict[str, Any]
        Raw input passed to the model before field validation.

    Returns
    -------
    Dict[str, Any]
        The input unchanged when there is nothing to extract; otherwise a
        shallow copy with cleaned `env_var` and merged `params`. The caller's
        dictionaries are never mutated.
    """
    # A "before" validator receives the raw input, which is not guaranteed
    # to be a dict; leave anything else for the regular validation to handle.
    if not isinstance(values, dict):
        return values

    # `env_var` is Optional, so it may be missing, None, or (invalidly) not a dict.
    env = values.get("env_var")
    if not isinstance(env, dict) or "PARAMS" not in env:
        return values

    # Work on copies so the dictionaries supplied by the caller stay intact.
    values = dict(values)
    env = dict(env)
    param_string = env.pop("PARAMS")
    values["env_var"] = env  # cleaned up version without PARAMS

    if param_string:
        parsed_params = cls._parse_params(params=param_string)
        existing_params = values.get("params") or {}
        # A non-dict `params` is invalid input; keep it untouched so field
        # validation reports it instead of failing here with a TypeError.
        if isinstance(existing_params, dict):
            merged_params = dict(existing_params)
            for key, value in parsed_params.items():
                # Explicit params win over values coming from PARAMS.
                merged_params.setdefault(key, value)
            values["params"] = merged_params

    return values
349+
350+
@staticmethod
351+
def _parse_params(params: Union[str, List[str]]) -> Dict[str, str]:
352+
"""
353+
Parses CLI-style parameters into a dictionary format.
354+
355+
This method accepts either:
356+
- A single string of parameters (e.g., "--key1 val1 --key2 val2")
357+
- A list of strings (e.g., ["--key1", "val1", "--key2", "val2"])
358+
359+
Returns a dictionary of the form { "key1": "val1", "key2": "val2" }.
360+
361+
Parameters
362+
----------
363+
params : Union[str, List[str]]
364+
The parameters to parse. Can be a single string or a list of strings.
365+
366+
Returns
367+
-------
368+
Dict[str, str]
369+
Dictionary with parameter names as keys and their corresponding values as strings.
370+
"""
371+
if not params or not isinstance(params, (str, list)):
372+
return {}
373+
374+
# Normalize string to list of "--key value" strings
375+
if isinstance(params, str):
376+
params_list = [
377+
f"--{param.strip()}" for param in params.split("--") if param.strip()
378+
]
379+
else:
380+
params_list = params
381+
382+
parsed = {}
383+
for item in params_list:
384+
parts = item.strip().split()
385+
if not parts:
386+
continue
387+
key = parts[0]
388+
value = " ".join(parts[1:]) if len(parts) > 1 else ""
389+
parsed[key] = value
390+
391+
return parsed
392+
305393
class Config:
306-
extra = "ignore"
394+
extra = "allow"
307395
protected_namespaces = ()
308396

309397

0 commit comments

Comments
 (0)