Skip to content

Commit 4885ac9

Browse files
made the code more modular
Signed-off-by: Jaya Venkatesh <jjayabaskar@nvidia.com>
1 parent 5a44201 commit 4885ac9

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

rapids_cli/doctor/checks/cuda_toolkit.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,10 @@ def _format_mismatch_error(
4545
source = _get_source_label(found_via)
4646

4747
location = f"CUDA {toolkit_major} toolkit"
48-
if source and cudart_path:
49-
location += f" (found via {source} at {cudart_path})"
50-
elif source:
51-
location += f" (found via {source})"
52-
elif cudart_path:
53-
location += f" (at {cudart_path})"
54-
48+
details = [v for v in (f"found via {source}" if source else None,
49+
f"at {cudart_path}" if cudart_path else None) if v]
50+
if details:
51+
location += f" ({', '.join(details)})"
5552
return (
5653
f"{location} is newer than what the GPU driver supports (CUDA {driver_major}). "
5754
f"Either update the GPU driver to one that supports CUDA {toolkit_major}, "
@@ -126,6 +123,16 @@ def _extract_major_from_cuda_path(path: Path) -> int | None:
126123
return int(match.group(1))
127124
return None
128125

126+
def _check_path_version(label: str, path: Path, driver_major: int) -> None:
127+
"""Raise if a CUDA path points to a version newer than the driver supports."""
128+
major = _extract_major_from_cuda_path(path)
129+
if major is not None and major > driver_major:
130+
raise ValueError(
131+
f"{label} points to CUDA {major} but the GPU driver "
132+
f"only supports up to CUDA {driver_major}. "
133+
f"Update {label} to a CUDA {driver_major}.x installation."
134+
)
135+
129136

130137
def cuda_toolkit_check(verbose=False):
131138
"""Check CUDA toolkit library availability and version consistency."""
@@ -177,25 +184,13 @@ def cuda_toolkit_check(verbose=False):
177184
if uses_system_paths:
178185
# Check /usr/local/cuda symlink
179186
if _CUDA_SYMLINK.exists():
180-
sym_major = _extract_major_from_cuda_path(_CUDA_SYMLINK.resolve())
181-
if sym_major is not None and sym_major > driver_major:
182-
raise ValueError(
183-
f"/usr/local/cuda points to CUDA {sym_major} but the GPU driver "
184-
f"only supports up to CUDA {driver_major}. "
185-
f"Update the symlink to a CUDA {driver_major}.x installation."
186-
)
187+
_check_path_version("/usr/local/cuda", _CUDA_SYMLINK.resolve(), driver_major)
187188

188189
# Check CUDA_HOME / CUDA_PATH
189190
for env_var in ("CUDA_HOME", "CUDA_PATH"):
190191
env_val = os.environ.get(env_var)
191192
if env_val:
192-
env_major = _extract_major_from_cuda_path(Path(env_val))
193-
if env_major is not None and env_major > driver_major:
194-
raise ValueError(
195-
f"{env_var}={env_val} (CUDA {env_major}) but the GPU driver "
196-
f"only supports up to CUDA {driver_major}. "
197-
f"Set {env_var} to a CUDA {driver_major}.x path."
198-
)
193+
_check_path_version(f"{env_var}={env_val}", Path(env_val), driver_major)
199194

200195
if verbose:
201196
version_str = f"CUDA {toolkit_major}" if toolkit_major else "unknown version"

0 commit comments

Comments
 (0)