-
Couldn't load subscription status.
- Fork 6.8k
[core][autoscaler][IPPR] Initial implementation for resizing pods in-place to the maximum configured by the user #55961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -264,7 +264,12 @@ def get(self, path: str) -> Dict[str, Any]: | |
| pass | ||
|
|
||
| @abstractmethod | ||
| def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]: | ||
| def patch( | ||
| self, | ||
| path: str, | ||
| payload: List[Dict[str, Any]], | ||
| content_type: str = "application/json-patch+json", | ||
| ) -> Dict[str, Any]: | ||
| """Wrapper for REST PATCH of resource with proper headers.""" | ||
| pass | ||
|
|
||
|
|
@@ -316,12 +321,18 @@ def get(self, path: str) -> Dict[str, Any]: | |
| result.raise_for_status() | ||
| return result.json() | ||
|
|
||
| def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]: | ||
| def patch( | ||
| self, | ||
| path: str, | ||
| payload: List[Dict[str, Any]], | ||
| content_type: str = "application/json-patch+json", | ||
| ) -> Dict[str, Any]: | ||
| """Wrapper for REST PATCH of resource with proper headers | ||
|
|
||
| Args: | ||
| path: The part of the resource path that starts with the resource type. | ||
| payload: The JSON patch payload. | ||
| content_type: The content type of the merge strategy. | ||
|
|
||
| Returns: | ||
| The JSON response of the PATCH request. | ||
|
|
@@ -338,7 +349,7 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| result = requests.patch( | ||
| url, | ||
| json.dumps(payload), | ||
| headers={**headers, "Content-type": "application/json-patch+json"}, | ||
| headers={**headers, "Content-type": content_type}, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make content-type adjustable for different patch strategies. |
||
| timeout=KUBERAY_REQUEST_TIMEOUT_S, | ||
| verify=verify, | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
|
|
||
| import requests | ||
|
|
||
| from ray._raylet import GcsClient | ||
|
|
||
| # TODO(rickyx): We should eventually remove these imports | ||
| # when we deprecate the v1 kuberay node provider. | ||
| from ray.autoscaler._private.kuberay.node_provider import ( | ||
|
|
@@ -24,6 +26,9 @@ | |
| worker_delete_patch, | ||
| worker_replica_patch, | ||
| ) | ||
| from ray.autoscaler.v2.instance_manager.cloud_providers.kuberay.ippr_provider import ( | ||
| KubeRayIPPRProvider, | ||
| ) | ||
| from ray.autoscaler.v2.instance_manager.node_provider import ( | ||
| CloudInstance, | ||
| CloudInstanceId, | ||
|
|
@@ -33,7 +38,7 @@ | |
| NodeKind, | ||
| TerminateNodeError, | ||
| ) | ||
| from ray.autoscaler.v2.schema import NodeType | ||
| from ray.autoscaler.v2.schema import IPPRSpecs, IPPRStatus, NodeType | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -51,14 +56,19 @@ def __init__( | |
| self, | ||
| cluster_name: str, | ||
| provider_config: Dict[str, Any], | ||
| gcs_client: GcsClient, | ||
| k8s_api_client: Optional[IKubernetesHttpApiClient] = None, | ||
| ): | ||
| """ | ||
| Initializes a new KubeRayProvider. | ||
|
|
||
| Args: | ||
| cluster_name: The name of the RayCluster resource. | ||
| provider_config: The namespace of the RayCluster. | ||
| k8s_api_client: The client to the Kubernetes API server. | ||
| This could be used to mock the Kubernetes API server for testing. | ||
| provider_config: The configuration dictionary | ||
| for the RayCluster (e.g., namespace and provider-specific settings). | ||
| gcs_client: The client to the GCS server. | ||
| k8s_api_client: The client to the Kubernetes | ||
| API server. This can be used to mock the Kubernetes API server for testing. | ||
rueian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| self._cluster_name = cluster_name | ||
| self._namespace = provider_config["namespace"] | ||
|
|
@@ -75,6 +85,9 @@ def __init__( | |
| # Below are states that are fetched from the Kubernetes API server. | ||
| self._ray_cluster = None | ||
| self._cached_instances: Dict[CloudInstanceId, CloudInstance] | ||
| self._ippr_provider = KubeRayIPPRProvider( | ||
| gcs_client=gcs_client, k8s_api_client=self._k8s_api_client | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The KubeRayIPPRProvider needs a gcs_client to query the port and the address of a Raylet, and it also needs a k8s_api_client to patch pods. |
||
| ) | ||
rueian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @dataclass | ||
| class ScaleRequest: | ||
|
|
@@ -183,6 +196,31 @@ def poll_errors(self) -> List[CloudInstanceProviderError]: | |
| self._terminate_errors_queue = [] | ||
| return errors | ||
|
|
||
| def get_ippr_specs(self) -> IPPRSpecs: | ||
| """Return the cached, validated IPPR specs for the cluster. | ||
|
|
||
| The IPPR specs are refreshed during the provider's periodic sync with the | ||
| API server by reading the RayCluster annotation and validating it against | ||
| the IPPR schema. | ||
| """ | ||
| return self._ippr_provider.get_ippr_specs() | ||
|
|
||
| def get_ippr_statuses(self) -> Dict[str, IPPRStatus]: | ||
| """Return the latest per-pod IPPR statuses keyed by pod name. | ||
|
|
||
| These statuses are refreshed from the current pod list during the provider's | ||
| periodic sync with the API server. | ||
| """ | ||
| return self._ippr_provider.get_ippr_statuses() | ||
|
|
||
| def do_ippr_requests(self, resizes: List[IPPRStatus]) -> None: | ||
| """Execute IPPR resize requests via the underlying IPPR provider. | ||
|
|
||
| Args: | ||
| resizes: The list of per-pod IPPR actions produced by the scheduler. | ||
| """ | ||
| self._ippr_provider.do_ippr_requests(resizes) | ||
|
|
||
| ############################ | ||
| # Private | ||
| ############################ | ||
|
|
@@ -416,7 +454,9 @@ def _add_terminate_errors( | |
| def _sync_with_api_server(self) -> None: | ||
| """Fetches the RayCluster resource from the Kubernetes API server.""" | ||
| self._ray_cluster = self._get(f"rayclusters/{self._cluster_name}") | ||
| self._ippr_provider.validate_and_set_ippr_specs(self._ray_cluster) | ||
| self._cached_instances = self._fetch_instances() | ||
| self._ippr_provider.sync_with_raylets() | ||
|
|
||
| @property | ||
| def ray_cluster(self) -> Dict[str, Any]: | ||
|
|
@@ -522,6 +562,9 @@ def _fetch_instances(self) -> Dict[CloudInstanceId, CloudInstance]: | |
| cloud_instance = self._cloud_instance_from_pod(pod) | ||
| if cloud_instance: | ||
| cloud_instances[pod_name] = cloud_instance | ||
|
|
||
| self._ippr_provider.sync_ippr_status_from_pods(pod_list["items"]) | ||
|
|
||
| return cloud_instances | ||
|
|
||
| @staticmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix lint.