Skip to content

Commit

Permalink
raise when numba_cuda is in use
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Feb 18, 2025
1 parent b435aba commit a1fdff6
Showing 1 changed file with 6 additions and 24 deletions.
30 changes: 6 additions & 24 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pathlib
from functools import partial
import importlib.util

from pynvjitlink.api import NvJitLinker, NvJitLinkError

Expand Down Expand Up @@ -49,31 +50,12 @@
Linker = object


try:
import numba_cuda

spec = importlib.util.find_spec("numba_cuda")
if spec is not None:
_numba_cuda_in_use = True
numba_cuda_ver = tuple(int(x) for x in numba_cuda.__version__.split("."))
if numba_cuda_ver < max_numba_cuda_ver:
_numba_cuda_version_ok = True
else:
_numba_cuda_version_ok = False
_numba_cuda_error = (
f"Only `numba_cuda` version below {max_numba_cuda_ver} is supported. "
f"Current version is {numba_cuda.__version__}. "
)

if numba_cuda_ver < (0, 2, 0):
suggestion = "Please enable pynvjitlink via NUMBA_CUDA_ENABLE_PYNVJITLINK environment variable."
else:
suggestion = (
"Please enable pynvjitlink via config.CUDA_ENABLE_PYNVJITLINK option."
)

_numba_cuda_error += suggestion
except ImportError:
_numba_cuda_error = "`numba_cuda` includes patches from pynvjitlink, so no further patches are needed."
else:
_numba_cuda_in_use = False
_numba_cuda_error = "`numba_cuda` package is not installed. Pynvjitlink should only be used with `numba_cuda` package."


class LinkableCode:
Expand Down Expand Up @@ -289,7 +271,7 @@ def patch_numba_linker(*, lto=False):
msg = f"Cannot patch Numba: {_numba_error}"
raise RuntimeError(msg)

if not _numba_cuda_in_use or not _numba_cuda_version_ok:
if _numba_cuda_in_use:
msg = f"Cannot patch Numba: {_numba_cuda_error}"
raise RuntimeError(msg)

Expand Down

0 comments on commit a1fdff6

Please sign in to comment.