diff --git a/ads/aqua/cli.py b/ads/aqua/cli.py
index 07e83c974..2a9db38ce 100644
--- a/ads/aqua/cli.py
+++ b/ads/aqua/cli.py
@@ -14,6 +14,7 @@
 from ads.aqua.finetuning import AquaFineTuningApp
 from ads.aqua.model import AquaModelApp
 from ads.aqua.modeldeployment import AquaDeploymentApp
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
 from ads.common.utils import LOG_LEVELS
 
 
@@ -29,6 +30,7 @@ class AquaCommand:
     fine_tuning = AquaFineTuningApp
     deployment = AquaDeploymentApp
     evaluation = AquaEvaluationApp
+    recommend = AquaRecommendApp
 
     def __init__(
         self,
@@ -94,18 +96,20 @@ def _validate_value(flag, value):
                 "If you intend to chain a function call to the result, please separate the "
                 "flag and the subsequent function call with separator `-`."
             )
-    
+
     @staticmethod
     def install():
         """Install ADS Aqua Extension from wheel file. Set enviroment
         variable `AQUA_EXTENSTION_PATH` to change the wheel file path.
 
-        Return 
+        Return
         ------
         int:
             Installatation status.
         """
         import subprocess
 
-        wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
-        status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
-        return status.check_returncode
\ No newline at end of file
+        wheel_file_path = os.environ.get(
+            "AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
+        )
+        status = subprocess.run(f"pip install {wheel_file_path}", shell=True)
+        return status.check_returncode
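With the CLI wiring above, the recommender becomes a subcommand of `ads aqua`. A minimal usage sketch, not part of the diff: the flag names are inferred by the CLI from `which_gpu`'s keyword arguments, and the model name is only an illustrative example.

# CLI (hypothetical invocation):
#   ads aqua recommend which_gpu --model <hf-model-name> --max_model_len 4096
# Programmatic equivalent:
from ads.aqua.shaperecommend.recommend import AquaRecommendApp

report = AquaRecommendApp().which_gpu(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model name
    max_model_len=4096,
)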
+ """ + try: + input_data = self.get_json_body() + # input_data["compartment_id"] = self.get_argument("compartment_id", default=COMPARTMENT_OCID) + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + self.finish(AquaRecommendApp().which_gpu(**input_data)) + +__handlers__ = [ + ("gpu-shape-recommendation/?([^/]*)", AquaRecommendHandler), +] diff --git a/ads/aqua/shaperecommend/__init__.py b/ads/aqua/shaperecommend/__init__.py new file mode 100644 index 000000000..3297935f2 --- /dev/null +++ b/ads/aqua/shaperecommend/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from ads.aqua.shaperecommend.recommend import AquaGPURecommendApp + +__all__ = ["AquaGPURecommendApp"] diff --git a/ads/aqua/shaperecommend/constants.py b/ads/aqua/shaperecommend/constants.py new file mode 100644 index 000000000..22bc6d556 --- /dev/null +++ b/ads/aqua/shaperecommend/constants.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +""" +aqua.shaperecommend.constants +~~~~~~~~~~~~~~ + +This module contains constants used in Aqua GPU Recommendation for Models. + +LLAMA_REQUIRED_FIELDS refer to fields necessary for calculating model memory for GQA Architecture Models + +MOE_REQUIRED_FIELDS refer to fields necessary for Mixture of Experts (MoE) Architecture Models + +NEXT_QUANT suggests the next quantization level based on the current quantization (if applied) or the model weights (if no quantization yet) +""" +LLAMA_REQUIRED_FIELDS = [ + "num_hidden_layers", "hidden_size", "num_attention_heads", + "num_key_value_heads", "head_dim", "intermediate_size", "vocab_size" +] + +MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + [ + "num_local_experts", "intermediate_size" +] + +NEXT_QUANT = { + "float32": ["bfloat16", "float16", "int8"], + "bfloat16": ["float16", "int8"], + "float16": ["int8"], + "int8": ["8bit", "4bit (Not Recommended)"], + "8bit": ["4bit (Not Recommended)"], + "4bit": ["No smaller quantization available"] +} diff --git a/ads/aqua/shaperecommend/estimator.py b/ads/aqua/shaperecommend/estimator.py new file mode 100644 index 000000000..5456d4122 --- /dev/null +++ b/ads/aqua/shaperecommend/estimator.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from typing import Optional + +from pydantic import BaseModel, Field + +from ads.aqua.app import logger +from ads.aqua.shaperecommend.constants import LLAMA_REQUIRED_FIELDS, MOE_REQUIRED_FIELDS +from ads.aqua.shaperecommend.llm_config import LLMConfig + + +class MemoryEstimator(BaseModel): + """ + The generic estimator for Transformer Architecture models (OPT/ Bloom) + Used as a fallback estimator if model identified is not a MoE or GQA Architecture Model. 
diff --git a/ads/aqua/extension/recommend_handler.py b/ads/aqua/extension/recommend_handler.py
new file mode 100644
index 000000000..1a98453c0
--- /dev/null
+++ b/ads/aqua/extension/recommend_handler.py
@@ -0,0 +1,50 @@
+
+from tornado.web import HTTPError
+
+from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.extension.base_handler import AquaAPIhandler
+from ads.aqua.extension.errors import Errors
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
+from ads.config import COMPARTMENT_OCID
+
+
+class AquaRecommendHandler(AquaAPIhandler):
+    """
+    Handler for Aqua GPU Recommendation REST APIs.
+
+    Methods
+    -------
+    post(self, *args, **kwargs)
+        Obtains the eligible compute shapes that would fit the specified model, context length,
+        model weights, and quantization level.
+
+    Raises
+    ------
+    HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
+    """
+
+    @handle_exceptions
+    def post(self, *args, **kwargs):  # noqa: ARG002
+        """
+        Lists the eligible GPU compute shapes for the specified model.
+
+        Returns
+        -------
+        ShapeRecommendationReport
+            The compatible model deployment shapes, with troubleshooting advice if none fit.
+        """
+        try:
+            input_data = self.get_json_body()
+            # input_data["compartment_id"] = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
+        except Exception as ex:
+            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
+
+        if not input_data:
+            raise HTTPError(400, Errors.NO_INPUT_DATA)
+
+        self.finish(AquaRecommendApp().which_gpu(**input_data))
+
+
+__handlers__ = [
+    ("gpu-shape-recommendation/?([^/]*)", AquaRecommendHandler),
+]
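A hedged sketch of calling the handler above. The JSON field names follow RequestRecommend (defined in shape_report.py later in this diff); the full endpoint URL and any Jupyter server auth headers depend on how the Aqua extension is mounted, so the URL below is only an assumption.

import requests  # assumes the requests package is available

payload = {"model": "meta-llama/Llama-3.1-8B-Instruct", "max_model_len": 4096}
response = requests.post(
    "http://localhost:8888/aqua/gpu-shape-recommendation",  # hypothetical URL prefix
    json=payload,
)
print(response.json())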
+ """ + c = self.llm_config + seq_len = self.seq_len or getattr(c, "max_seq_len", 2048) + kv_cache_dtype_bytes = c.bytes_per_parameter + kv_heads = c.num_key_value_heads + + total_bytes = ( + self.batch_size + * c.num_hidden_layers + * 2 + * kv_heads + * seq_len + * c.head_dim + * kv_cache_dtype_bytes + ) + return total_bytes / 1e9 + + def _calc_attn_embed_params(self) -> tuple: + """ + Returns the embedding parameter count and attention parameter count for Llama-family (GQA) models. + """ + c = self.llm_config + + # Embedding parameters + # assume tied embeddings unless tie_word_embeddings = False + embedding_count = 1 if getattr(c, "tie_word_embeddings", True) else 2 + embedding_params = embedding_count * c.vocab_size * c.hidden_size + + q_proj = c.hidden_size * c.hidden_size + k_proj = c.hidden_size * (c.num_key_value_heads * c.head_dim) + v_proj = c.hidden_size * (c.num_key_value_heads * c.head_dim) + o_proj = c.hidden_size * c.hidden_size + attn_params = q_proj + k_proj + v_proj + o_proj + + return embedding_params, attn_params + + +class MixtureMemoryEstimator(LlamaMemoryEstimator): + """ + Estimator for Mixture-of-Experts (MoE) architectures (e.g., Mixtral, MoE Llama). + Adds extra expert parallelism block parameter count to LlamaMemoryEstimator logic. + """ + + @property + def model_memory(self) -> float: + """ + Accounts for the increase in model parameters due to additional expert MLP blocks in MoE Models. + + Returns the estimated memory size of the MoE Model (in GB). + """ + c = self.llm_config + # Attention parameter count (Llama-style) + embedding_params, attn_params = self._calc_attn_embed_params() + + # MoE MLP params per layer + moe_params_per_layer = ( + c.num_local_experts * 3 * c.hidden_size * c.intermediate_size + ) + total_params = ( + c.num_hidden_layers * (attn_params + moe_params_per_layer) + + embedding_params + ) + + # Convert to GB + return total_params * c.bytes_per_parameter / 1e9 + + +def get_estimator(llm_config, **kwargs) -> MemoryEstimator: + """ + Extracts the correct estimator based on the defined parameters in the config.json + See constants.py for LLMConfig parameters necessary for specific estimators. + Uses MemoryEstimator as a fallback if parameters needed for GQA and MoE Architectures are missing. + + Returns the appropriate MemoryEstimator based on the fields defined by the model's config.json (as represented by LLMConfig). + """ + if all( + hasattr(llm_config, f) and getattr(llm_config, f) is not None + for f in MOE_REQUIRED_FIELDS + ): + return MixtureMemoryEstimator(llm_config=llm_config, **kwargs) + elif all( + hasattr(llm_config, f) and getattr(llm_config, f) is not None + for f in LLAMA_REQUIRED_FIELDS + ): + return LlamaMemoryEstimator(llm_config=llm_config, **kwargs) + else: + logger.warning( + "Falling back to generic GPT estimator: required fields missing from config.json file in model." + ) + return MemoryEstimator(llm_config=llm_config, **kwargs) diff --git a/ads/aqua/shaperecommend/llm_config.py b/ads/aqua/shaperecommend/llm_config.py new file mode 100644 index 000000000..b902280b3 --- /dev/null +++ b/ads/aqua/shaperecommend/llm_config.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python +# Copyright (c) 2025 Oracle and/or its affiliates. 
diff --git a/ads/aqua/shaperecommend/constants.py b/ads/aqua/shaperecommend/constants.py
new file mode 100644
index 000000000..22bc6d556
--- /dev/null
+++ b/ads/aqua/shaperecommend/constants.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+"""
+aqua.shaperecommend.constants
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This module contains constants used in Aqua GPU Recommendation for Models.
+
+LLAMA_REQUIRED_FIELDS refer to fields necessary for calculating model memory for GQA Architecture Models
+
+MOE_REQUIRED_FIELDS refer to fields necessary for Mixture of Experts (MoE) Architecture Models
+
+NEXT_QUANT suggests the next quantization level based on the current quantization (if applied) or the model weights (if no quantization yet)
+"""
+LLAMA_REQUIRED_FIELDS = [
+    "num_hidden_layers", "hidden_size", "num_attention_heads",
+    "num_key_value_heads", "head_dim", "intermediate_size", "vocab_size"
+]
+
+MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + [
+    "num_local_experts", "intermediate_size"
+]
+
+NEXT_QUANT = {
+    "float32": ["bfloat16", "float16", "int8"],
+    "bfloat16": ["float16", "int8"],
+    "float16": ["int8"],
+    "int8": ["8bit", "4bit (Not Recommended)"],
+    "8bit": ["4bit (Not Recommended)"],
+    "4bit": ["No smaller quantization available"]
+}
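A quick illustration of how the quantization ladder above is consumed (see `suggested_quantizations` in llm_config.py later in this diff):

from ads.aqua.shaperecommend.constants import NEXT_QUANT

NEXT_QUANT["bfloat16"]  # ["float16", "int8"] -- next steps for an unquantized bf16 model
NEXT_QUANT["int8"]      # ["8bit", "4bit (Not Recommended)"]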
+ """ + key = (self.quantization or self.weight_dtype or "float32").lower() + return NEXT_QUANT.get(key, []) + + @classmethod + def from_raw_config(cls, raw: dict) -> "LLMConfig": + """ + Instantiates an LLMConfig from a raw Hugging Face config.json file, + using robust key detection (considers multiple possibilities for keys referring to the same model attribute). + """ + + # Field mappings with fallback + num_hidden_layers = ( + raw.get("num_hidden_layers") or raw.get("n_layer") or raw.get("num_layers") + ) + hidden_size = raw.get("hidden_size") or raw.get("n_embd") or raw.get("d_model") + vocab_size = raw.get("vocab_size") + weight_dtype = str(raw.get("torch_dtype", "float32")) + quantization = cls.detect_quantization(raw) + num_key_value_heads = ( + raw.get("num_key_value_heads") # GQA models (ex. Llama-type) + ) + + num_attention_heads = ( + raw.get("num_attention_heads") or raw.get("n_head") or raw.get("num_heads") + ) + + head_dim = raw.get("head_dim") or ( + int(hidden_size) // int(num_attention_heads) + if hidden_size and num_attention_heads + else None + ) + max_seq_len = ( + raw.get("max_position_embeddings") + or raw.get("n_positions") + or raw.get("max_seq_len") + ) + + num_local_experts = ( + raw.get("num_local_experts") + or raw.get("n_routed_experts") + or raw.get("num_experts") + ) + intermediate_size = raw.get("moe_intermediate_size") or raw.get( + "intermediate_size" + ) + + # Type safety: minimal assertion + if None in [ + num_hidden_layers, + hidden_size, + vocab_size, + num_attention_heads, + head_dim, + max_seq_len, + ]: + raise ValueError("Missing required value in model config.") + + return cls( + num_hidden_layers=int(num_hidden_layers), + hidden_size=int(hidden_size), + num_attention_heads=int(num_attention_heads), + num_key_value_heads=num_key_value_heads, + head_dim=int(head_dim), + vocab_size=int(vocab_size), + weight_dtype=weight_dtype, + quantization=quantization, + max_seq_len=int(max_seq_len), + num_local_experts=num_local_experts, + intermediate_size=intermediate_size, + ) diff --git a/ads/aqua/shaperecommend/recommend.py b/ads/aqua/shaperecommend/recommend.py new file mode 100644 index 000000000..9257ced05 --- /dev/null +++ b/ads/aqua/shaperecommend/recommend.py @@ -0,0 +1,381 @@ +import json +from typing import List + +from huggingface_hub import hf_hub_download +from pydantic import ValidationError + +from ads.aqua.app import AquaApp, logger +from ads.aqua.common.entities import ComputeShapeSummary +from ads.aqua.common.errors import AquaValueError +from ads.aqua.common.utils import build_pydantic_error_message, list_hf_models +from ads.aqua.modeldeployment.deployment import AquaDeploymentApp +from ads.aqua.shaperecommend.constants import NEXT_QUANT +from ads.aqua.shaperecommend.estimator import MemoryEstimator, get_estimator +from ads.aqua.shaperecommend.llm_config import LLMConfig +from ads.aqua.shaperecommend.shape_report import ( + DeploymentShapeSummary, + GPUSummary, + RequestRecommend, + ShapeRecommendationReport, + ShapeSummary, + TroubleshootShapeSummary, +) + + +class AquaRecommendApp(AquaApp): + """ + Interface for recommending GPU shapes for machine learning model deployments + on Oracle Cloud Infrastructure Data Science service. + + This class provides methods to recommend deployment shapes based on a model's requirements, + handle recommendation details and troubleshooting, and retrieve specific OCI Machine Learning shapes. + Must be used within a properly configured and authenticated OCI environment. 
diff --git a/ads/aqua/shaperecommend/llm_config.py b/ads/aqua/shaperecommend/llm_config.py
new file mode 100644
index 000000000..b902280b3
--- /dev/null
+++ b/ads/aqua/shaperecommend/llm_config.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import re
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+from ads.aqua.shaperecommend.constants import NEXT_QUANT
+
+
+class LLMConfig(BaseModel):
+    """
+    Standardized configuration object for evaluating the size of Large Language Models (LLMs)
+    based on their config.json file.
+
+    The architecture is determined by which of the optional fields below are defined.
+    """
+
+    num_hidden_layers: int = Field(
+        ...,
+        description="Number of transformer blocks (layers) in the model’s neural network stack.",
+    )
+    hidden_size: int = Field(
+        ..., description="Embedding dimension or hidden size of each layer."
+    )
+    vocab_size: int = Field(..., description="Vocabulary size for input/output tokens.")
+    num_attention_heads: int = Field(
+        ...,
+        description="Number of attention heads (used for queries and to determine head_dim).",
+    )
+
+    head_dim: int = Field(
+        ...,
+        description="Dimension of each attention head. Typically hidden_size // num_attention_heads.",
+    )
+    max_seq_len: int = Field(
+        ..., description="Maximum input sequence length (context window)."
+    )
+    weight_dtype: Optional[str] = Field(
+        "float32", description="Parameter data type: 'float32', 'float16', etc."
+    )
+    quantization: Optional[str] = Field(
+        None,
+        description="Quantization method (e.g., '8bit', '4bit', 'gptq', 'awq') or None if unquantized.",
+    )
+
+    num_key_value_heads: Optional[int] = Field(
+        None,
+        description="Number of key/value heads (for GQA architectures: Llama, Mistral, Falcon, Qwen, etc.). Used to determine the KV cache size.",
+    )
+
+    num_local_experts: Optional[int] = Field(
+        None, description="For MoE architectures, the number of experts per MoE layer."
+    )
+    intermediate_size: Optional[int] = Field(
+        None,
+        description="Size of the MLP intermediate layer (per expert for MoE architectures).",
+    )
+
+    tie_word_embeddings: Optional[bool] = Field(None)
+
+    @property
+    def bytes_per_parameter(self) -> float:
+        """
+        Returns the number of bytes used to store a model parameter,
+        accounting for quantization or weight storage type.
+        """
+        mapping = {
+            "float32": 4,
+            "bfloat16": 2,
+            "float16": 2,
+            "fp16": 2,
+            "half": 2,
+            "int8": 1,
+            "8bit": 1,
+            "4bit": 0.5,
+            "awq": 0.5,
+            "gptq": 0.5,
+        }
+        # Quantization takes precedence
+        q = (self.quantization or "").lower()
+        if q in mapping:
+            return mapping[q]
+        if "bit" in q:
+            m = re.match(r"(\d+)bit", q)
+            if m:
+                bits = int(m[1])
+                return bits / 8  # bytes per parameter
+            # Unknown bit type: fallback
+            return 1
+        # Fallback to weight_dtype mapping
+        dtype = (self.weight_dtype or "float32").lower()
+        if dtype in mapping:
+            return mapping[dtype]
+        return mapping["float32"]  # Default
+
+    @classmethod
+    def detect_quantization(cls, raw: dict) -> Optional[str]:
+        """
+        Detects the main quantization types from a Hugging Face config dict.
+        """
+        if raw.get("load_in_8bit"):
+            return "8bit"
+        if raw.get("load_in_4bit"):
+            return "4bit"
+        if "quantization_config" in raw:
+            qcfg = raw["quantization_config"]
+            if "gptq" in str(qcfg).lower():
+                return "gptq"
+            if "awq" in str(qcfg).lower():
+                return "awq"
+            bits = qcfg.get("bits") or qcfg.get("wbits")
+            if bits:
+                return f"{bits}bit"
+            return "custom-quant"
+        return None
+
+    @property
+    def suggested_quantizations(self) -> List[str]:
+        """
+        Suggests the next quantization level based on the current quantization level, if available,
+        falling back to the model weight dtype if no quantization is currently applied.
+        """
+        key = (self.quantization or self.weight_dtype or "float32").lower()
+        return NEXT_QUANT.get(key, [])
+
+    @classmethod
+    def from_raw_config(cls, raw: dict) -> "LLMConfig":
+        """
+        Instantiates an LLMConfig from a raw Hugging Face config.json file,
+        using robust key detection (considers multiple possible keys for the same model attribute).
+        """
+        # Field mappings with fallback
+        num_hidden_layers = (
+            raw.get("num_hidden_layers") or raw.get("n_layer") or raw.get("num_layers")
+        )
+        hidden_size = raw.get("hidden_size") or raw.get("n_embd") or raw.get("d_model")
+        vocab_size = raw.get("vocab_size")
+        weight_dtype = str(raw.get("torch_dtype", "float32"))
+        quantization = cls.detect_quantization(raw)
+        num_key_value_heads = raw.get("num_key_value_heads")  # GQA models (e.g. Llama-type)
+
+        num_attention_heads = (
+            raw.get("num_attention_heads") or raw.get("n_head") or raw.get("num_heads")
+        )
+
+        head_dim = raw.get("head_dim") or (
+            int(hidden_size) // int(num_attention_heads)
+            if hidden_size and num_attention_heads
+            else None
+        )
+        max_seq_len = (
+            raw.get("max_position_embeddings")
+            or raw.get("n_positions")
+            or raw.get("max_seq_len")
+        )
+
+        num_local_experts = (
+            raw.get("num_local_experts")
+            or raw.get("n_routed_experts")
+            or raw.get("num_experts")
+        )
+        intermediate_size = raw.get("moe_intermediate_size") or raw.get(
+            "intermediate_size"
+        )
+
+        # Type safety: minimal assertion
+        if None in [
+            num_hidden_layers,
+            hidden_size,
+            vocab_size,
+            num_attention_heads,
+            head_dim,
+            max_seq_len,
+        ]:
+            raise ValueError("Missing required value in model config.")
+
+        return cls(
+            num_hidden_layers=int(num_hidden_layers),
+            hidden_size=int(hidden_size),
+            num_attention_heads=int(num_attention_heads),
+            num_key_value_heads=num_key_value_heads,
+            head_dim=int(head_dim),
+            vocab_size=int(vocab_size),
+            weight_dtype=weight_dtype,
+            quantization=quantization,
+            max_seq_len=int(max_seq_len),
+            num_local_experts=num_local_experts,
+            intermediate_size=intermediate_size,
+            tie_word_embeddings=raw.get("tie_word_embeddings"),
+        )
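A small sketch of the key-fallback parsing above, using a hand-written raw config dict (values are illustrative, not from a real model card):

from ads.aqua.shaperecommend.llm_config import LLMConfig

raw = {
    "num_hidden_layers": 32,
    "n_embd": 4096,                   # picked up because "hidden_size" is absent
    "vocab_size": 32000,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,  # mapped to max_seq_len
    "torch_dtype": "bfloat16",
}

cfg = LLMConfig.from_raw_config(raw)
cfg.head_dim                 # 128, derived as hidden_size // num_attention_heads
cfg.bytes_per_parameter      # 2, from the bfloat16 weight dtype
cfg.suggested_quantizations  # ["float16", "int8"], looked up in NEXT_QUANT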
+ """ + return [ + shape + for shape in compute_shapes + if shape.name + and getattr(shape, "gpu_specs", None) + and getattr(shape.gpu_specs, "gpu_memory_in_gbs", None) + and getattr(shape.gpu_specs, "gpu_count", None) + ] + + @staticmethod + def power_of_two_seq_lens(min_len=2048, max_len=16384) -> List[int]: + """ + Calculates the range of valid sequence lengths (power of two) up until + the model's max sequence length as specified in the LLMConfig. + """ + vals = [] + curr = min_len + while curr <= max_len: + vals.append(curr) + curr *= 2 + if vals[-1] != max_len: + vals.append(max_len) + return vals + + def suggest_param_advice(self, estimator: MemoryEstimator, allowed) -> str: + """ + Returns a tailored suggestion on how the user should improve the memory footprint of their model. + Identifies whether the KV cache and/or model size is the bottleneck for memory footprint. + """ + kv_gb = estimator.kv_cache_memory + wt_gb = estimator.model_memory + batch_size = estimator.batch_size + seq_len = estimator.seq_len + weight_size = estimator.config.weight_dtype + suggested_quant_msg = None + + if estimator.config.suggested_quantizations: + quant_advice = ", ".join(estimator.config.suggested_quantizations) + suggested_quant_msg = f"Use the same model with {quant_advice} quantization" + + kv_advice = ( + f"To reduce KV cache memory usage: \n" + f"1. reduce maximum context length (set --max-model-len to less than current max sequence length: {seq_len})\n" + f"2. reduce batch size to less than {batch_size}." + if batch_size > 1 + else ".\n" + ) + + wt_advice = ( + f"To reduce model size:\n" + f"1. Consider using a model with fewer parameters. \n" + f"2. {suggested_quant_msg or 'a quantized version (e.g., INT8 or another supported type)'}, which is smaller than the current quantization/ weight size: {estimator.config.quantization if estimator.config.quantization in NEXT_QUANT.keys() else weight_size}." + ) + + if kv_gb > wt_gb and kv_gb > allowed * 0.5: + # KV cache drives memory usage + main = "KV cache memory usage is the bottleneck of memory use." + advice = kv_advice + + elif wt_gb > kv_gb and wt_gb > allowed * 0.5: + # model weights drives memory usage + main = "The model configuration is the bottleneck of memory use." + advice = wt_advice + + else: + main = "Both model weights and KV cache are significant contributors to memory use." + advice = kv_advice + "\n" + wt_advice + return f"{main} ({kv_gb:.1f}GB KV cache, {wt_gb:.1f}GB weights).\n{advice}" + + def limiting_factor( + self, + estimator: MemoryEstimator, + available_ram, + gpu_utilization: float, + warn_delta=0.9, + ) -> str: + """ + Warns the user if a certain valid compute shape would be close to the memory limit if model w/ current parameters was used. + Uses the suggestions from suggest_param_advice to give tailored warnings. + """ + required = estimator.total_memory + allowed = available_ram * gpu_utilization + + quantization = getattr(estimator.config, "quantization", "None") + weight_size = estimator.config.weight_dtype + batch_size = estimator.batch_size + seq_len = estimator.seq_len + + param_advice = self.suggest_param_advice(estimator, allowed) + + # even if model configuration works, if we are close to the limit, we should warn user + if required > allowed * warn_delta: + advice = ( + f"The selected model configuration is close to GPU Memory Limit ({required:.1f}GB used / {allowed:.1f}GB allowed).\n" + + param_advice + ) + return advice + else: + return ( + f"Model fits well within limits of compute shape. 
+
+    def calc_gpu_report_per_shape(
+        self,
+        estimator: MemoryEstimator,
+        shape: ComputeShapeSummary,
+        gpu_utilization: float,
+    ) -> ShapeSummary:
+        """
+        Generate a summary of GPU memory and compute usage for a specific shape configuration.
+
+        For a given compute shape, evaluates all powers-of-two allocations of available GPUs,
+        and for each valid configuration (where total available GPU memory exceeds model requirements),
+        generates a `GPUSummary` describing per-GPU memory allocation and the system's limiting factor.
+
+        Parameters:
+            estimator (MemoryEstimator): The memory estimator object containing model memory requirements.
+            shape (ComputeShapeSummary): The compute shape configuration, including GPU specs.
+            gpu_utilization (float): The fraction (0.0–1.0) of total GPU memory to consider usable.
+
+        Returns:
+            ShapeSummary: A summary object containing the shape name and a list of valid `GPUSummary`
+            entries. Returns `None` if no valid GPU configurations are possible for the shape
+            and utilization provided.
+        """
+        power = 1
+        gpu_count = shape.gpu_specs.gpu_count
+        num_gpu_cards = []
+
+        # eligible numbers of cards: powers of two up to the shape's GPU count
+        while gpu_count and power <= gpu_count:
+            num_gpu_cards.append(power)
+            power *= 2
+
+        # gpu_memory_in_gbs / gpu_count -> memory per GPU card; memory per card * used_gpus -> available RAM
+        memory_per_gpu = shape.gpu_specs.gpu_memory_in_gbs / shape.gpu_specs.gpu_count
+
+        gpu_reports = []
+
+        for used_gpus in num_gpu_cards:
+            available_ram = used_gpus * memory_per_gpu
+            eligible = available_ram * gpu_utilization > estimator.total_memory
+            if eligible:
+                limit_msg = self.limiting_factor(
+                    estimator, available_ram, gpu_utilization
+                )
+                gpu_reports.append(
+                    GPUSummary(
+                        gpu_count=used_gpus,
+                        gpu_memory_in_gb=available_ram,
+                        limiting_factor=limit_msg,
+                    )
+                )
+
+        return (
+            ShapeSummary(shape=shape.name, gpu_reports=gpu_reports)
+            if gpu_reports
+            else None
+        )
+
+    def summarize_shapes_for_seq_lens(
+        self,
+        config: LLMConfig,
+        shapes: List[ComputeShapeSummary],
+        batch_size: int = 1,
+        user_seq_len: int = 4096,
+        gpu_utilization: float = 0.95,
+    ) -> ShapeRecommendationReport:
+        """
+        Generate a recommendation report for eligible deployment shapes by considering model memory consumption
+        and max model length.
+
+        Parameters
+        ----------
+        config : LLMConfig
+            The loaded model config.
+        shapes : List[ComputeShapeSummary]
+            All candidate deployment shapes.
+        batch_size : int
+            Batch size to evaluate.
+        user_seq_len : int
+            Sequence length (context window) requested by the user.
+        gpu_utilization : float
+            Usable fraction of GPU memory (e.g., 0.8, 0.9).
+
+        Returns
+        -------
+        ShapeRecommendationReport
+        """
+        recs = []
+        shape_reports = []
+        troubleshoot = None
+
+        estimator = get_estimator(
+            llm_config=config, batch_size=batch_size, seq_len=user_seq_len
+        )
+
+        logger.info(f"The {type(estimator).__name__} will be used.")
+
+        max_gpu_memory_size = 0
+        max_gpu_memory_shape = None
+
+        for shape in shapes:
+            shape_report = self.calc_gpu_report_per_shape(
+                estimator, shape, gpu_utilization
+            )
+
+            if shape_report:
+                shape_reports.append(shape_report)
+
+            if (
+                shape.gpu_specs
+                and shape.gpu_specs.gpu_memory_in_gbs > max_gpu_memory_size
+            ):
+                # track the largest shape encountered so far
+                max_gpu_memory_shape = shape
+                max_gpu_memory_size = shape.gpu_specs.gpu_memory_in_gbs
+
+        recs.append(
+            DeploymentShapeSummary(
+                batch_size=batch_size,
+                precision=config.quantization or config.weight_dtype,
+                gb_used_by_model=round(estimator.total_memory, 3),
+                max_seq_len=user_seq_len,
+                shape_reports=shape_reports,
+            )
+        )
+
+        # no compatible shape recommendations, but shapes are available in the environment
+        if shapes and not shape_reports:
+            # suggest the largest shape with actionable advice on how to make the model fit
+            allowed = max_gpu_memory_size * gpu_utilization
+            advice = self.suggest_param_advice(estimator, allowed)
+
+            troubleshoot = TroubleshootShapeSummary(
+                largest_shape=max_gpu_memory_shape.name,
+                gpu_memory_in_gb=max_gpu_memory_size,
+                gb_used_by_model=round(estimator.total_memory, 3),
+                max_seq_len=user_seq_len,
+                batch_size=batch_size,
+                precision=config.quantization or config.weight_dtype,
+                advice=advice,
+            )
+
+        return ShapeRecommendationReport(
+            recommendations=recs, troubleshoot=troubleshoot
+        )
diff --git a/ads/aqua/shaperecommend/shape_report.py b/ads/aqua/shaperecommend/shape_report.py
new file mode 100644
index 000000000..939d468bc
--- /dev/null
+++ b/ads/aqua/shaperecommend/shape_report.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class RequestRecommend(BaseModel):
+    model: str
+    max_model_len: Optional[int] = Field(4096)
+
+
+class GPUSummary(BaseModel):
+    gpu_count: int
+    gpu_memory_in_gb: float
+    limiting_factor: str
+
+
+class ShapeSummary(BaseModel):
+    shape: str
+    gpu_reports: List[GPUSummary]
+
+
+class DeploymentShapeSummary(BaseModel):
+    batch_size: int
+    max_seq_len: int
+    precision: str
+    gb_used_by_model: float
+    shape_reports: List[ShapeSummary]
+
+
+class TroubleshootShapeSummary(BaseModel):
+    largest_shape: str
+    gpu_memory_in_gb: float
+    gb_used_by_model: float
+    batch_size: int
+    max_seq_len: int
+    precision: str
+    advice: str
+
+
+class ShapeRecommendationReport(BaseModel):
+    """
+    Contains shape fit recommendations and an optional troubleshooting summary.
+    """
+
+    # Each entry is: for this batch_size and max_seq_len, here are the valid shapes
+    recommendations: List[DeploymentShapeSummary] = []
+    troubleshoot: Optional[TroubleshootShapeSummary] = None
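A hedged sketch of how a caller might consume the report models above (the model name is only an example; output text comes from limiting_factor/suggest_param_advice):

from ads.aqua.shaperecommend.recommend import AquaRecommendApp

report = AquaRecommendApp().which_gpu(
    model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=4096
)

for rec in report.recommendations:
    print(rec.precision, rec.gb_used_by_model, "GB used by model")
    for shape_summary in rec.shape_reports:
        for gpu in shape_summary.gpu_reports:
            print(f"  {shape_summary.shape} x{gpu.gpu_count} GPU(s): {gpu.limiting_factor}")

if report.troubleshoot:
    # no shape fit: the report carries advice for the largest available shape
    print(report.troubleshoot.advice)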