|
8 | 8 | import shlex
|
9 | 9 | import threading
|
10 | 10 | from datetime import datetime, timedelta
|
11 |
| -from typing import Dict, List, Optional |
| 11 | +from typing import Dict, List, Optional, Union |
12 | 12 |
|
13 | 13 | from cachetools import TTLCache, cached
|
14 | 14 | from oci.data_science.models import ModelDeploymentShapeSummary
|
15 | 15 | from pydantic import ValidationError
|
| 16 | +from rich.table import Table |
16 | 17 |
|
17 | 18 | from ads.aqua.app import AquaApp, logger
|
18 | 19 | from ads.aqua.common.entities import (
|
|
44 | 45 | AQUA_MODEL_TYPE_SERVICE,
|
45 | 46 | AQUA_MULTI_MODEL_CONFIG,
|
46 | 47 | MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
| 48 | + MODEL_GROUP, |
47 | 49 | MODEL_NAME_DELIMITER,
|
| 50 | + SINGLE_MODEL_FLEX, |
48 | 51 | UNKNOWN_DICT,
|
| 52 | + UNKNOWN_ENUM_VALUE, |
49 | 53 | )
|
50 | 54 | from ads.aqua.data import AquaResourceIdentifier
|
51 | 55 | from ads.aqua.model import AquaModelApp
|
|
64 | 68 | ModelDeploymentConfigSummary,
|
65 | 69 | MultiModelDeploymentConfigLoader,
|
66 | 70 | )
|
67 |
| -from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME |
| 71 | +from ads.aqua.modeldeployment.constants import ( |
| 72 | + DEFAULT_POLL_INTERVAL, |
| 73 | + DEFAULT_WAIT_TIME, |
| 74 | +) |
68 | 75 | from ads.aqua.modeldeployment.entities import (
|
69 | 76 | AquaDeployment,
|
70 | 77 | AquaDeploymentDetail,
|
71 | 78 | ConfigValidationError,
|
72 | 79 | CreateModelDeploymentDetails,
|
73 | 80 | )
|
74 | 81 | from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
|
| 82 | +from ads.aqua.shaperecommend.recommend import AquaShapeRecommend |
| 83 | +from ads.aqua.shaperecommend.shape_report import ( |
| 84 | + RequestRecommend, |
| 85 | + ShapeRecommendationReport, |
| 86 | +) |
75 | 87 | from ads.common.object_storage_details import ObjectStorageDetails
|
76 | 88 | from ads.common.utils import UNKNOWN, get_log_links
|
77 | 89 | from ads.common.work_request import DataScienceWorkRequest
|
@@ -864,21 +876,26 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
|
864 | 876 |
|
865 | 877 | if oci_aqua:
|
866 | 878 | # skipping the AQUA model deployments that are created from model group
|
867 |
| - # TODO: remove this checker after AQUA deployment is integrated with model group |
868 |
| - aqua_model_id = model_deployment.freeform_tags.get( |
869 |
| - Tags.AQUA_MODEL_ID_TAG, UNKNOWN |
870 |
| - ) |
871 | 879 | if (
|
872 |
| - "datasciencemodelgroup" in aqua_model_id |
873 |
| - or model_deployment.model_deployment_configuration_details.deployment_type |
874 |
| - == "UNKNOWN_ENUM_VALUE" |
| 880 | + model_deployment.model_deployment_configuration_details.deployment_type |
| 881 | + in [UNKNOWN_ENUM_VALUE, MODEL_GROUP, SINGLE_MODEL_FLEX] |
875 | 882 | ):
|
876 | 883 | continue
|
877 |
| - results.append( |
878 |
| - AquaDeployment.from_oci_model_deployment( |
879 |
| - model_deployment, self.region |
| 884 | + try: |
| 885 | + results.append( |
| 886 | + AquaDeployment.from_oci_model_deployment( |
| 887 | + model_deployment, self.region |
| 888 | + ) |
880 | 889 | )
|
881 |
| - ) |
| 890 | + except Exception as e: |
| 891 | + logger.error( |
| 892 | + f"There was an issue processing the list of model deployments . Error: {str(e)}", |
| 893 | + exc_info=True, |
| 894 | + ) |
| 895 | + raise AquaRuntimeError( |
| 896 | + f"There was an issue processing the list of model deployments . Error: {str(e)}" |
| 897 | + ) from e |
| 898 | + |
882 | 899 | # log telemetry if MD is in active or failed state
|
883 | 900 | deployment_id = model_deployment.id
|
884 | 901 | state = model_deployment.lifecycle_state.upper()
|
@@ -1249,6 +1266,50 @@ def validate_deployment_params(
|
1249 | 1266 | )
|
1250 | 1267 | return {"valid": True}
|
1251 | 1268 |
|
| 1269 | + def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]: |
| 1270 | + """ |
| 1271 | + For the CLI (set generate_table = True), generates the table (in rich diff) with valid |
| 1272 | + GPU deployment shapes for the provided model and configuration. |
| 1273 | +
|
| 1274 | + For the API (set generate_table = False), generates the JSON with valid |
| 1275 | + GPU deployment shapes for the provided model and configuration. |
| 1276 | +
|
| 1277 | + Validates if recommendations are generated, calls method to construct the rich diff |
| 1278 | + table with the recommendation data. |
| 1279 | +
|
| 1280 | + Parameters |
| 1281 | + ---------- |
| 1282 | + model_ocid : str |
| 1283 | + OCID of the model to recommend feasible compute shapes. |
| 1284 | +
|
| 1285 | + Returns |
| 1286 | + ------- |
| 1287 | + Table (generate_table = True) |
| 1288 | + A table format for the recommendation report with compatible deployment shapes |
| 1289 | + or troubleshooting info citing the largest shapes if no shape is suitable. |
| 1290 | +
|
| 1291 | + ShapeRecommendationReport (generate_table = False) |
| 1292 | + A recommendation report with compatible deployment shapes, or troubleshooting info |
| 1293 | + citing the largest shapes if no shape is suitable. |
| 1294 | +
|
| 1295 | + Raises |
| 1296 | + ------ |
| 1297 | + AquaValueError |
| 1298 | + If model type is unsupported by tool (no recommendation report generated) |
| 1299 | + """ |
| 1300 | + try: |
| 1301 | + request = RequestRecommend(**kwargs) |
| 1302 | + except ValidationError as e: |
| 1303 | + custom_error = build_pydantic_error_message(e) |
| 1304 | + raise AquaValueError( # noqa: B904 |
| 1305 | + f"Failed to request shape recommendation due to invalid input parameters: {custom_error}" |
| 1306 | + ) |
| 1307 | + |
| 1308 | + shape_recommend = AquaShapeRecommend() |
| 1309 | + shape_recommend_report = shape_recommend.which_shapes(request) |
| 1310 | + |
| 1311 | + return shape_recommend_report |
| 1312 | + |
1252 | 1313 | @telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua")
|
1253 | 1314 | @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
|
1254 | 1315 | def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
|
|
0 commit comments