
Commit 18a92c4

initial code for GPU Shape Recommender

1 parent 9c1095e · commit 18a92c4

File tree

11 files changed (+1282, -5 lines)

ads/aqua/cli.py
Lines changed: 9 additions & 5 deletions

@@ -15,6 +15,7 @@
 from ads.aqua.model import AquaModelApp
 from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.aqua.verify_policies import AquaVerifyPoliciesApp
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
 from ads.common.utils import LOG_LEVELS


@@ -31,6 +32,7 @@ class AquaCommand:
     deployment = AquaDeploymentApp
     evaluation = AquaEvaluationApp
     verify_policies = AquaVerifyPoliciesApp
+    recommend = AquaRecommendApp

     def __init__(
         self,
@@ -96,18 +98,20 @@ def _validate_value(flag, value):
                 "If you intend to chain a function call to the result, please separate the "
                 "flag and the subsequent function call with separator `-`."
             )
-
+
     @staticmethod
     def install():
         """Install ADS Aqua Extension from wheel file. Set environment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.

-        Return
+        Return
         ------
         int:
             Installation status.
         """
         import subprocess

-        wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
-        status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
-        return status.check_returncode
+        wheel_file_path = os.environ.get(
+            "AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
+        )
+        status = subprocess.run(f"pip install {wheel_file_path}", shell=True)
+        return status.check_returncode

ads/aqua/common/entities.py
Lines changed: 17 additions & 0 deletions

@@ -46,6 +46,17 @@ class Config:
         arbitrary_types_allowed = True
         protected_namespaces = ()

+class ComputeRank(Serializable):
+    """
+    Represents the cost and performance ranking for a compute shape.
+    """
+    cost: int = Field(
+        None, description="The relative rank of the cost of the shape. Range is [10 (cost-effective), 100 (most expensive)]."
+    )
+
+    performance: int = Field(
+        None, description="The relative rank of the performance of the shape. Range is [10 (lowest performance), 110 (highest performance)]."
+    )

 class GPUSpecs(Serializable):
     """
@@ -61,6 +72,12 @@ class GPUSpecs(Serializable):
     gpu_type: Optional[str] = Field(
         default=None, description="The type of GPU (e.g., 'V100, A100, H100')."
    )
+    quantization: Optional[List[str]] = Field(
+        default_factory=list, description="The quantization formats supported by the shape (e.g., bitsandbytes, fp8)."
+    )
+    ranking: Optional[ComputeRank] = Field(
+        None, description="The relative rank of the cost and performance of the shape."
+    )


 class GPUShapesIndex(Serializable):
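
A minimal sketch of how the new entities compose, with values mirroring the BM.GPU.L40S.4 entry in ads/aqua/resources/shapes.json below; it assumes GPUSpecs also carries the pre-existing gpu_count and gpu_memory_in_gbs fields, which sit outside this hunk:

    from ads.aqua.common.entities import ComputeRank, GPUSpecs

    # Values mirror the BM.GPU.L40S.4 entry in shapes.json.
    l40s = GPUSpecs(
        gpu_count=4,            # assumed pre-existing field
        gpu_memory_in_gbs=192,  # assumed pre-existing field
        gpu_type="L40S",
        quantization=["awq", "gptq", "fp8", "gguf"],
        ranking=ComputeRank(cost=60, performance=80),
    )
    print(l40s.ranking.performance)  # 80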

ads/aqua/extension/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 )
 from ads.aqua.extension.evaluation_handler import __handlers__ as __eval_handlers__
 from ads.aqua.extension.finetune_handler import __handlers__ as __finetune_handlers__
+from ads.aqua.extension.gpu_recommend_handler import __handlers__ as __gpu_handlers__
 from ads.aqua.extension.model_handler import __handlers__ as __model_handlers__
 from ads.aqua.extension.ui_handler import __handlers__ as __ui_handlers__
 from ads.aqua.extension.ui_websocket_handler import __handlers__ as __ws_handlers__
@@ -24,6 +25,7 @@
     + __ui_handlers__
     + __eval_handlers__
     + __ws_handlers__
+    + __gpu_handlers__
 )


ads/aqua/extension/gpu_recommend_handler.py
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+
+from tornado.web import HTTPError
+
+from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.extension.base_handler import AquaAPIhandler
+from ads.aqua.extension.errors import Errors
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
+from ads.config import COMPARTMENT_OCID
+
+
+class AquaRecommendHandler(AquaAPIhandler):
+    """
+    Handler for Aqua GPU Recommendation REST APIs.
+
+    Methods
+    -------
+    get(self, id: Union[str, List[str]])
+        Retrieves a list of AQUA deployments or model info or logs by ID.
+    post(self, *args, **kwargs)
+        Obtains the eligible compute shapes that would fit the specified model, context length, model weights, and quantization level.
+
+    Raises
+    ------
+    HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
+    """
+
+    @handle_exceptions
+    def post(self, *args, **kwargs):  # noqa: ARG002
+        """
+        Lists the eligible GPU compute shapes for the specified model.
+
+        Returns
+        -------
+        List[ComputeShapeSummary]:
+            The list of the model deployment shapes.
+        """
+        try:
+            input_data = self.get_json_body()
+            # input_data["compartment_id"] = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
+        except Exception as ex:
+            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
+
+        if not input_data:
+            raise HTTPError(400, Errors.NO_INPUT_DATA)
+
+        self.finish(AquaRecommendApp().which_gpu(**input_data))
+
+__handlers__ = [
+    ("gpu-shape-recommendation/?([^/]*)", AquaRecommendHandler),
+]
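
For reference, a hedged sketch of exercising this endpoint once the extension is mounted in Jupyter; the host/route prefix and the request keys are assumptions, since the JSON body is forwarded verbatim via which_gpu(**input_data) and the signature of which_gpu lives in ads/aqua/shaperecommend/recommend.py, which is not shown in this diff view:

    import requests

    # Hypothetical body: the key names must match the (not-shown) which_gpu signature.
    payload = {"model_id": "ocid1.datasciencemodel.oc1..<unique_id>"}

    # "gpu-shape-recommendation" is the registered route; the full URL depends on
    # where the Aqua extension is served (a local JupyterLab is assumed here).
    resp = requests.post(
        "http://localhost:8888/aqua/gpu-shape-recommendation",
        json=payload,
        timeout=30,
    )
    resp.raise_for_status()
    print(resp.json())  # eligible compute shapes for the model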

ads/aqua/resources/shapes.json
Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+{
+  "shapes": {
+    "BM.GPU.H200.8": {
+      "gpu_count": 8,
+      "gpu_memory_in_gbs": 1128,
+      "gpu_type": "H200",
+      "quantization": ["awq", "gptq", "marlin", "fp8", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 100,
+        "performance": 110
+      }
+    },
+    "BM.GPU.H100.8": {
+      "gpu_count": 8,
+      "gpu_memory_in_gbs": 640,
+      "gpu_type": "H100",
+      "quantization": ["awq", "gptq", "marlin", "fp8", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 100,
+        "performance": 100
+      }
+    },
+    "BM.GPU.MI300X.8": {
+      "gpu_count": 8,
+      "gpu_memory_in_gbs": 1536,
+      "gpu_type": "MI300X",
+      "quantization": ["fp8", "gguf"],
+      "ranking": {
+        "cost": 90,
+        "performance": 90
+      }
+    },
+    "BM.GPU.A100-V2.8": {
+      "gpu_count": 8,
+      "gpu_memory_in_gbs": 640,
+      "gpu_type": "A100",
+      "quantization": ["awq", "gptq", "marlin", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 80,
+        "performance": 70
+      }
+    },
+    "BM.GPU.B4.8": {
+      "gpu_count": 8,
+      "gpu_memory_in_gbs": 320,
+      "gpu_type": "A100",
+      "quantization": ["awq", "gptq", "marlin", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 70,
+        "performance": 60
+      }
+    },
+    "BM.GPU.L40S-NC.4": {
+      "gpu_count": 4,
+      "gpu_memory_in_gbs": 192,
+      "gpu_type": "L40S",
+      "quantization": ["awq", "gptq", "marlin", "fp8", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 60,
+        "performance": 80
+      }
+    },
+    "BM.GPU.L40S.4": {
+      "gpu_count": 4,
+      "gpu_memory_in_gbs": 192,
+      "gpu_type": "L40S",
+      "quantization": ["awq", "gptq", "marlin", "fp8", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 60,
+        "performance": 80
+      }
+    },
+    "VM.GPU.A10.1": {
+      "gpu_count": 1,
+      "gpu_memory_in_gbs": 24,
+      "gpu_type": "A10",
+      "quantization": ["awq", "gptq", "marlin", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 20,
+        "performance": 30
+      }
+    },
+    "VM.GPU.A10.2": {
+      "gpu_count": 2,
+      "gpu_memory_in_gbs": 48,
+      "gpu_type": "A10",
+      "quantization": ["awq", "gptq", "marlin", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 40,
+        "performance": 40
+      }
+    },
+    "BM.GPU.A10.4": {
+      "gpu_count": 4,
+      "gpu_memory_in_gbs": 96,
+      "gpu_type": "A10",
+      "quantization": ["awq", "gptq", "marlin", "int8", "bitblas", "aqlm", "bitsandbytes", "deepspeedfp", "gguf"],
+      "ranking": {
+        "cost": 50,
+        "performance": 50
+      }
+    },
+    "BM.GPU2.2": {
+      "gpu_count": 2,
+      "gpu_memory_in_gbs": 32,
+      "gpu_type": "P100",
+      "quantization": ["fp16"],
+      "ranking": {
+        "cost": 30,
+        "performance": 20
+      }
+    },
+    "VM.GPU2.1": {
+      "gpu_count": 1,
+      "gpu_memory_in_gbs": 16,
+      "gpu_type": "P100",
+      "quantization": ["fp16"],
+      "ranking": {
+        "cost": 10,
+        "performance": 10
+      }
+    }
+  }
+}
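
A minimal sketch of how this metadata could drive a recommendation: filter shapes by aggregate GPU memory and supported quantization, then sort by the cost rank. This is illustrative only and is not the algorithm implemented in AquaRecommendApp:

    import json

    # Load the shape index shipped with the package (path relative to the repo root).
    with open("ads/aqua/resources/shapes.json") as f:
        shapes = json.load(f)["shapes"]

    def eligible_shapes(required_memory_gb, quant_format=None):
        """Shapes with enough total GPU memory (and quantization support), cheapest first."""
        matches = [
            (name, spec)
            for name, spec in shapes.items()
            if spec["gpu_memory_in_gbs"] >= required_memory_gb
            and (quant_format is None or quant_format in spec["quantization"])
        ]
        return sorted(matches, key=lambda kv: kv[1]["ranking"]["cost"])

    # Example: roughly 140 GB of weights served with fp8 quantization.
    for name, spec in eligible_shapes(140, "fp8"):
        print(name, spec["gpu_memory_in_gbs"], spec["ranking"])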

ads/aqua/shaperecommend/__init__.py
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+from ads.aqua.shaperecommend.recommend import AquaRecommendApp
+
+__all__ = ["AquaRecommendApp"]

ads/aqua/shaperecommend/constants.py
Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+"""
+aqua.shaperecommend.constants
+~~~~~~~~~~~~~~
+
+This module contains constants used in Aqua GPU Recommendation for Models.
+
+LLAMA_REQUIRED_FIELDS refer to fields necessary for calculating model memory for GQA Architecture Models
+
+MOE_REQUIRED_FIELDS refer to fields necessary for Mixture of Experts (MoE) Architecture Models
+
+NEXT_QUANT suggests the next quantization level based on the current quantization (if applied) or the model weights (if no quantization yet)
+"""
+LLAMA_REQUIRED_FIELDS = [
+    "num_hidden_layers", "hidden_size", "num_attention_heads",
+    "num_key_value_heads", "head_dim", "intermediate_size", "vocab_size"
+]
+
+MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + [
+    "num_local_experts", "intermediate_size"
+]
+
+NEXT_QUANT = {
+    "float32": ["4bit", "8bit"],  # bits and bytes does not support bfloat16, pytorch responsibility
+    "bfloat16": ["4bit", "8bit"],
+    "float16": ["4bit", "8bit"],
+    "int8": ["4bit"],
+    "fp8": ["4bit", "8bit"],
+    "8bit": ["4bit"],
+    "int4": ["No smaller quantization available"],
+    "4bit": ["No smaller quantization available"]
+}
+
+# TODO:
+SHAPES_METADATA = "/Users/elizjo/tmp/accelerated-data-science/ads/aqua/resources/shapes.json"
+
+TEXT_MODEL = "text-generation"
+
+QUANT_MAPPING = {
+    "float32": 4,
+    "bfloat16": 2,
+    "float16": 2,
+    "fp16": 2,
+    "half": 2,
+    "int8": 1,
+    "fp8": 1,
+    "8bit": 1,
+    "4bit": 0.5,
+    "int4": 0.5,
+}
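
QUANT_MAPPING reads as bytes per parameter, so a rough weight-memory estimate follows directly from parameter count times bytes per parameter; a sketch of that arithmetic (KV-cache and activation overhead, which the recommender also needs to account for, are deliberately omitted):

    from ads.aqua.shaperecommend.constants import QUANT_MAPPING

    def estimate_weight_memory_gb(num_params, dtype="bfloat16"):
        """Approximate GPU memory needed for the model weights alone, in GB."""
        return num_params * QUANT_MAPPING[dtype] / 1e9

    # Example: a 7B-parameter model.
    print(estimate_weight_memory_gb(7e9, "bfloat16"))  # ~14 GB
    print(estimate_weight_memory_gb(7e9, "4bit"))      # ~3.5 GB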
