# add python accelerator HAL for multi-vendor backends #86
Status: Closed. theap06 wants to merge 24 commits into `facebookresearch:main` from `theap06:feat/accelerator-hal-upstream-clean`.
## Commits (24)
All 24 commits are by theap06 (messages kept verbatim; truncated messages left as-is):

- `f312b33`: fixed the style and the claude feedback
- `31ace9b`: Fix HAL nox setup and health_checks ctx obj handling
- `2eb1cff`: Skip backend probe for injected nvml monitor objects
- `778ca5a`: Initialize health_checks ctx obj when missing
- `14bb10b`: Preserve health_checks obj while storing backend
- `b2ea528`: adding myself to the README fiel
- `08045b9`: fix: align health_checks ctx.obj initialization with gcm CLI pattern
- `7301505`: fix(ci): mkdir -p venv bin before copying Rust cargo binaries
- `b84a29e`: fix(ci): create venv in build_deb if cache miss, source cargo env bef…
- `e870583`: fix(ci): create venv on cache miss in common-setup action
- `e1c318b`: revert(ci): remove pip install fallbacks, keep only cp fix for build_deb
- `0179f52`: fix(health_checks): avoid overriding click obj for subcommands
- `46af65b`: fixed the fragility in the conditional for healthchecks function base…
- `9d3f8b3`: fixed the error handling for healthchecks
- `8755085`: fixed the error handling for healthchecks
- `399e276`: preserved the non-dict objects and applied the defensive init of dict
- `6a78c2f`: fixed the health_checks() for the health_checks.py. the previous ci w…
- `409af19`: fixed the health_checks function in health_checks.py by running pass …
- `333dcc1`: Addressing PR review comments: removing telemetry files and unnecessa…
- `ccc9129`: Restore Rust installation step from main
- `effaef8`: Refactor legacy tools to use Accelerator HAL Adapter
- `c6b0855`: Fix flaky integration test in test_gcm.py and update test plan
- `9ef8a60`: Fix mock patching in test_check_nvidia_smi_hal_parity.py
- `b7298e1`: Merge branch 'main' into feat/accelerator-hal-upstream-clean
New file (+60 lines):

# Test Plan: Accelerator HAL Migration

This document outlines the test plan to verify that the migration to the Accelerator HAL (Hardware Abstraction Layer) preserves existing functionality for NVML-based monitoring and health checks.

## Objective

Ensure that all existing NVML paths (`nvml_monitor` and `check_nvidia_smi`) continue to function identically after being refactored to use the `AcceleratorManager` and `NVMLBackend` interface.

## Coverage Areas

1. **Metric Collection (`nvml_monitor`)**: Verify that GPU metrics (utilization, memory, power, temperature, clocks, ECC) are collected correctly.
2. **Health Checks (`check_nvidia_smi`)**: Verify GPU presence, running processes, and error detection.
3. **Error Handling**: Ensure that backend unavailability or device errors are handled gracefully and logged appropriately.

## Test Cases

### 1. Unit Tests

Run the existing unit tests to verify there are no regressions in logic:

```bash
pytest gcm/tests/test_accelerator_hal.py
pytest gcm/tests/health_checks_tests/test_check_nvidia_smi.py
pytest gcm/tests/test_nvml_monitor.py
```

### 2. Manual Verification (Stubbed)

Since we cannot run on actual GPU hardware in this environment, we rely on the stubbed NVML library used in tests.

#### A. NVML Monitor

**Refactored Logic:**
`nvml_monitor` now instantiates `AcceleratorManager`, probes backends, and uses `AcceleratorTelemetryAdapter` to interact with device handles provided by `NVMLBackend`.

**Verification Step:**
Verify that `nvml_monitor.py` correctly fetches device count and metrics via the adapter. The adapter ensures that underlying `pynvml` calls are routed through the `AcceleratorManager`'s backend instance.

#### B. Health Checks

**Refactored Logic:**
`check_nvidia_smi` now instantiates `AcceleratorManager` and uses `AcceleratorTelemetryAdapter` to perform checks.

**Verification Step:**
Verify that `check_nvidia_smi.py` correctly detects GPU count and running processes via the adapter.
We also verified `gcm/tests/test_gcm.py::test_health_checks_backend_nvml_full_run`, which exercises the full health-check loop; the test was updated to handle potential extra output from check execution.

## Refactoring Status

- **`gcm/accelerator`**: Core HAL interfaces and NVML backend implementation are complete.
- **`nvml_monitor.py`**: Refactored to use `AcceleratorManager` via `AcceleratorTelemetryAdapter`.
- **`check_nvidia_smi.py`**: Refactored to use `AcceleratorManager` via `AcceleratorTelemetryAdapter`.
- **Legacy Shim**: Added `gcm/monitoring/accelerator_adapter.py` to bridge `DeviceTelemetryClient` calls to the HAL backend, ensuring 100% backward compatibility for methods not yet fully exposed in `MetricSet` (e.g., specific ECC error counts).

## Rollout Strategy

1. **Phase 1 (Current PR)**: Introduce the HAL and migrate all NVML usage to `AcceleratorManager` via the adapter shim.
2. **Phase 2 (Future)**: Update `nvml_monitor` logic to use `AcceleratorManager.read_metrics()` directly, removing the dependency on the `DeviceTelemetryClient` interface once `MetricSet` is expanded to cover all needs.

This incremental approach ensures that the new architecture is active immediately while minimizing risk to existing business logic.
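As a sketch of the Phase 2 direction, a monitor loop would read metrics straight from the manager per device. Everything below uses illustrative stand-ins, not the PR's actual API: `FakeManager`, its method signatures, and the trimmed `MetricSet` fields are assumptions modeled on the interfaces in this PR.

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Optional

# Stand-ins mirroring the HAL types in this PR (trimmed; field set is an assumption).
@dataclass(frozen=True)
class DeviceHandle:
    backend: str
    id: str
    vendor: str

@dataclass(frozen=True)
class MetricRequest:
    include_process_info: bool = False

@dataclass(frozen=True)
class MetricSet:
    timestamp: datetime
    core_util_pct: Optional[float] = None
    power_w: Optional[float] = None

class FakeManager:
    """Hypothetical AcceleratorManager stand-in returning canned values."""

    def enumerate_devices(self) -> List[DeviceHandle]:
        return [DeviceHandle(backend="nvml", id="0", vendor="nvidia")]

    def read_metrics(self, device: DeviceHandle, request: MetricRequest) -> MetricSet:
        return MetricSet(
            timestamp=datetime.now(timezone.utc), core_util_pct=42.0, power_w=250.0
        )

def collect_all(manager: FakeManager) -> Dict[str, MetricSet]:
    # Phase 2 monitor loop: one read_metrics call per enumerated device,
    # with no DeviceTelemetryClient in the path.
    request = MetricRequest()
    return {d.id: manager.read_metrics(d, request) for d in manager.enumerate_devices()}

metrics = collect_all(FakeManager())
print(metrics["0"].core_util_pct)  # 42.0
```

The point of the sketch is the shape of the loop, not the values: once `MetricSet` covers everything the monitor needs, the adapter shim drops out entirely.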
New file (+30 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from gcm.accelerator.backend import (
    AcceleratorBackend,
    BackendName,
    DeviceHandle,
    ProbeResult,
)
from gcm.accelerator.errors import (
    AcceleratorError,
    BackendUnavailableError,
    UnsupportedOperationError,
)
from gcm.accelerator.manager import AcceleratorManager
from gcm.accelerator.metrics import MetricRequest, MetricSet
from gcm.accelerator.registry import default_backend_factories


__all__ = [
    "AcceleratorBackend",
    "AcceleratorError",
    "AcceleratorManager",
    "BackendName",
    "BackendUnavailableError",
    "DeviceHandle",
    "MetricRequest",
    "MetricSet",
    "ProbeResult",
    "UnsupportedOperationError",
    "default_backend_factories",
]
```
New file (+49 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, List, Protocol

from gcm.accelerator.metrics import MetricRequest, MetricSet


class BackendName(str, Enum):
    NVML = "nvml"


@dataclass(frozen=True)
class ProbeResult:
    backend: BackendName
    healthy: bool
    reason: str
    library_path: str | None = None
    driver_version: str | None = None
    probed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


@dataclass(frozen=True)
class DeviceHandle:
    backend: BackendName
    id: str
    vendor: str
    model: str | None = None
    bus_id: str | None = None
    serial: str | None = None


class AcceleratorBackend(Protocol):
    def name(self) -> BackendName: ...

    def probe(self) -> ProbeResult: ...

    def enumerate_devices(self) -> List[DeviceHandle]: ...

    def read_metrics(
        self, device: DeviceHandle, request: MetricRequest
    ) -> MetricSet: ...

    def close(self) -> None: ...


BackendFactory = Callable[[], AcceleratorBackend]
```
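Because `AcceleratorBackend` is a `typing.Protocol`, a backend for another vendor only needs to satisfy it structurally; no inheritance is required. The sketch below mirrors the definitions above in trimmed form (`read_metrics` is omitted for brevity, and `@runtime_checkable`, the `FAKE` enum member, and the `FakeBackend`/`"acme"` names are illustrative additions, not part of the PR).

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import List, Protocol, runtime_checkable

class BackendName(str, Enum):
    NVML = "nvml"
    FAKE = "fake"  # hypothetical member added for this sketch

@dataclass(frozen=True)
class ProbeResult:
    backend: BackendName
    healthy: bool
    reason: str
    probed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

@dataclass(frozen=True)
class DeviceHandle:
    backend: BackendName
    id: str
    vendor: str

@runtime_checkable
class AcceleratorBackend(Protocol):
    def name(self) -> BackendName: ...
    def probe(self) -> ProbeResult: ...
    def enumerate_devices(self) -> List[DeviceHandle]: ...
    def close(self) -> None: ...

class FakeBackend:
    """Satisfies AcceleratorBackend structurally; note: no base class."""

    def name(self) -> BackendName:
        return BackendName.FAKE

    def probe(self) -> ProbeResult:
        return ProbeResult(backend=self.name(), healthy=True, reason="ready")

    def enumerate_devices(self) -> List[DeviceHandle]:
        return [DeviceHandle(backend=self.name(), id="0", vendor="acme")]

    def close(self) -> None:
        pass

backend: AcceleratorBackend = FakeBackend()
# runtime_checkable lets isinstance verify the method set at runtime.
assert isinstance(backend, AcceleratorBackend)
print(backend.probe().healthy)  # True
```

This structural approach is what lets `BackendFactory = Callable[[], AcceleratorBackend]` accept any conforming class without the HAL package importing vendor code.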
New file (+2 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
```
New file (+187 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Optional, TypeVar

from gcm.accelerator.backend import BackendName, DeviceHandle, ProbeResult
from gcm.accelerator.errors import BackendUnavailableError, UnsupportedOperationError
from gcm.accelerator.metrics import MetricRequest, MetricSet
from gcm.accelerator.probe import find_and_load_library
from gcm.monitoring.device_telemetry_client import (
    DeviceTelemetryClient,
    DeviceTelemetryException,
)
from gcm.monitoring.utils.error import safe_call
from gcm.schemas.gpu.application_clock import ApplicationClockInfo
from gcm.schemas.gpu.memory import GPUMemory
from gcm.schemas.gpu.utilization import GPUUtilization

_NAMES = ["nvidia-ml"]
_PATHS = [
    "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
    "/usr/lib64/libnvidia-ml.so.1",
    "/usr/lib/libnvidia-ml.so.1",
]

_T = TypeVar("_T")


def _default_nvml_client_factory() -> DeviceTelemetryClient:
    # Keep the import lazy so this package can still be imported in
    # environments where pynvml is unavailable.
    from gcm.monitoring.device_telemetry_nvml import NVMLDeviceTelemetryClient

    return NVMLDeviceTelemetryClient()


@dataclass
class NVMLBackend:
    telemetry_client_factory: Callable[[], DeviceTelemetryClient] = (
        _default_nvml_client_factory
    )
    _client: Optional[DeviceTelemetryClient] = field(
        default=None, init=False, repr=False
    )
    _handles: dict[str, Any] = field(default_factory=dict, init=False, repr=False)

    def name(self) -> BackendName:
        return BackendName.NVML

    def _ensure_client(self) -> DeviceTelemetryClient:
        if self._client is None:
            self._client = self.telemetry_client_factory()
        return self._client

    def probe(self) -> ProbeResult:
        path = find_and_load_library(_NAMES, _PATHS)
        if path is None:
            raise BackendUnavailableError("NVML shared library not found")
        client = self._ensure_client()
        try:
            client.get_device_count()
        except DeviceTelemetryException as e:
            raise BackendUnavailableError("NVML initialization failed") from e
        return ProbeResult(
            backend=self.name(),
            healthy=True,
            reason="ready",
            library_path=path,
            probed_at=datetime.now(timezone.utc),
        )

    def enumerate_devices(self) -> list[DeviceHandle]:
        client = self._ensure_client()
        try:
            device_count = client.get_device_count()
            devices: list[DeviceHandle] = []
            for index in range(device_count):
                model: Optional[str] = None

                # Check cache first or fetch handle
                dev_id = str(index)
                if dev_id in self._handles:
                    handle = self._handles[dev_id]
                else:
                    handle = client.get_device_by_index(index)
                    self._handles[dev_id] = handle

                model_getter = getattr(handle, "get_name", None)
                if callable(model_getter):
                    maybe_model = self._safe_call(model_getter)
                    if isinstance(maybe_model, str):
                        model = maybe_model
                devices.append(
                    DeviceHandle(
                        backend=self.name(),
                        id=dev_id,
                        vendor="nvidia",
                        model=model,
                    )
                )
            return devices
        except DeviceTelemetryException as e:
            raise UnsupportedOperationError("NVML enumerate_devices failed") from e

    @staticmethod
    def _safe_call(func: Callable[[], _T]) -> _T | None:
        return safe_call(func, DeviceTelemetryException, logger_name=__name__)

    def read_metrics(self, device: DeviceHandle, _request: MetricRequest) -> MetricSet:
        # TODO: Wire MetricRequest.include_process_info once process telemetry
        # is available through HAL MetricSet.
        client = self._ensure_client()
        try:
            if device.id in self._handles:
                handle = self._handles[device.id]
            else:
                index = int(device.id)
                handle = client.get_device_by_index(index)
                self._handles[device.id] = handle
        except (ValueError, DeviceTelemetryException) as e:
            raise UnsupportedOperationError(
                f"invalid NVML device id: {device.id}"
            ) from e

        utilization: GPUUtilization | None = self._safe_call(
            handle.get_utilization_rates
        )
        memory: GPUMemory | None = self._safe_call(handle.get_memory_info)
        temperature: int | None = self._safe_call(handle.get_temperature)
        power_usage: int | None = self._safe_call(handle.get_power_usage)
        power_limit: int | None = self._safe_call(handle.get_enforced_power_limit)
        clocks: ApplicationClockInfo | None = self._safe_call(handle.get_clock_freq)
        ecc_corrected: int | None = self._safe_call(
            handle.get_ecc_corrected_volatile_total
        )
        ecc_uncorrected: int | None = self._safe_call(
            handle.get_ecc_uncorrected_volatile_total
        )

        return MetricSet(
            timestamp=datetime.now(timezone.utc),
            core_util_pct=(float(utilization.gpu) if utilization is not None else None),
            mem_util_pct=(
                float(utilization.memory) if utilization is not None else None
            ),
            mem_total_bytes=(int(memory.total) if memory is not None else None),
            mem_used_bytes=(int(memory.used) if memory is not None else None),
            temp_c=(float(temperature) if temperature is not None else None),
            power_w=(float(power_usage) / 1000.0 if power_usage is not None else None),
            power_limit_w=(
                float(power_limit) / 1000.0 if power_limit is not None else None
            ),
            sm_clock_mhz=(int(clocks.graphics_freq) if clocks is not None else None),
            mem_clock_mhz=(int(clocks.memory_freq) if clocks is not None else None),
            ecc_corrected=(int(ecc_corrected) if ecc_corrected is not None else None),
            ecc_uncorrected=(
                int(ecc_uncorrected) if ecc_uncorrected is not None else None
            ),
        )

    def get_raw_handle(self, device_id: str) -> Any:
        client = self._ensure_client()
        if device_id in self._handles:
            return self._handles[device_id]

        try:
            index = int(device_id)
            handle = client.get_device_by_index(index)
            self._handles[device_id] = handle
            return handle
        except (ValueError, DeviceTelemetryException) as e:
            raise UnsupportedOperationError(
                f"invalid NVML device id: {device_id}"
            ) from e

    def close(self) -> None:
        client = self._client
        self._client = None
        if client is None:
            return None

        close_method = getattr(client, "close", None)
        if callable(close_method):
            close_method()
        return None
```
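Because `telemetry_client_factory` is injectable, `NVMLBackend` can be exercised without `pynvml` or GPU hardware. The trimmed stand-in below (a hypothetical `StubTelemetryClient` and `CachingBackend`, not the PR's classes) shows the lazy-client and per-device handle-cache pattern that `_ensure_client` and `get_raw_handle` implement:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

class StubTelemetryClient:
    """Stand-in for DeviceTelemetryClient; records every handle fetch."""

    def __init__(self, count: int = 2) -> None:
        self.count = count
        self.fetches: List[int] = []

    def get_device_count(self) -> int:
        return self.count

    def get_device_by_index(self, index: int) -> Dict[str, Any]:
        self.fetches.append(index)
        return {"index": index}

@dataclass
class CachingBackend:
    """Trimmed sketch of NVMLBackend's lazy-client and handle-cache pattern."""

    client_factory: Callable[[], StubTelemetryClient]
    _client: Optional[StubTelemetryClient] = field(default=None, init=False)
    _handles: Dict[str, Any] = field(default_factory=dict, init=False)

    def _ensure_client(self) -> StubTelemetryClient:
        if self._client is None:  # lazily create the client on first use
            self._client = self.client_factory()
        return self._client

    def get_raw_handle(self, device_id: str) -> Any:
        if device_id in self._handles:  # cache hit: no client round-trip
            return self._handles[device_id]
        handle = self._ensure_client().get_device_by_index(int(device_id))
        self._handles[device_id] = handle
        return handle

stub = StubTelemetryClient()
backend = CachingBackend(client_factory=lambda: stub)
backend.get_raw_handle("0")
backend.get_raw_handle("0")  # served from the cache
print(stub.fetches)  # [0] -- only one underlying fetch
```

Caching handles per device id means repeated metric reads avoid a client round-trip per call. Note that in the PR's `close()`, only `_client` is reset; the handle cache is left intact.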
New file (+26 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from dataclasses import dataclass

from gcm.accelerator.backend import BackendName


class AcceleratorError(Exception):
    """Base exception type for accelerator HAL failures."""


class BackendUnavailableError(AcceleratorError):
    """Raised when backend probe fails due to missing runtime dependencies."""


class UnsupportedOperationError(AcceleratorError):
    """Raised when an operation is not implemented by a backend."""


@dataclass(frozen=True)
class BackendOperationError(AcceleratorError):
    backend: BackendName
    operation: str

    def __str__(self) -> str:
        return f"backend={self.backend.value} operation={self.operation}"
```
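The hierarchy lets callers catch narrowly (one failure mode, e.g. a missing NVML library) or broadly (any HAL failure). A self-contained sketch, mirroring the exception classes above with a minimal `BackendName`; the `probe_or_fallback` helper is hypothetical:

```python
from dataclasses import dataclass
from enum import Enum

class BackendName(str, Enum):  # mirrored from backend.py in this PR
    NVML = "nvml"

class AcceleratorError(Exception):
    """Base exception type for accelerator HAL failures."""

class BackendUnavailableError(AcceleratorError):
    """Raised when backend probe fails due to missing runtime dependencies."""

@dataclass(frozen=True)
class BackendOperationError(AcceleratorError):
    backend: BackendName
    operation: str

    def __str__(self) -> str:
        return f"backend={self.backend.value} operation={self.operation}"

def probe_or_fallback() -> str:
    try:
        raise BackendUnavailableError("NVML shared library not found")
    except BackendUnavailableError:
        return "fallback"  # narrow: this backend is missing, try another
    except AcceleratorError:
        return "abort"  # broad: any other HAL failure

print(probe_or_fallback())  # fallback
print(BackendOperationError(backend=BackendName.NVML, operation="probe"))
# backend=nvml operation=probe
```

Making `BackendOperationError` a frozen dataclass gives it structured `backend`/`operation` fields for free while the custom `__str__` keeps log lines compact.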