diff --git a/pynvjitlink/patch.py b/pynvjitlink/patch.py index ccaca93..06bac38 100644 --- a/pynvjitlink/patch.py +++ b/pynvjitlink/patch.py @@ -2,6 +2,7 @@ import os import pathlib from functools import partial +import importlib.util from pynvjitlink.api import NvJitLinker, NvJitLinkError @@ -49,31 +50,12 @@ Linker = object -try: - import numba_cuda - +spec = importlib.util.find_spec("numba_cuda") +if spec is not None: _numba_cuda_in_use = True - numba_cuda_ver = tuple(int(x) for x in numba_cuda.__version__.split(".")) - if numba_cuda_ver < max_numba_cuda_ver: - _numba_cuda_version_ok = True - else: - _numba_cuda_version_ok = False - _numba_cuda_error = ( - f"Only `numba_cuda` version below {max_numba_cuda_ver} is supported. " - f"Current version is {numba_cuda.__version__}. " - ) - - if numba_cuda_ver < (0, 2, 0): - suggestion = "Please enable pynvjitlink via NUMBA_CUDA_ENABLE_PYNVJITLINK environment variable." - else: - suggestion = ( - "Please enable pynvjitlink via config.CUDA_ENABLE_PYNVJITLINK option." - ) - - _numba_cuda_error += suggestion -except ImportError: + _numba_cuda_error = "`numba_cuda` includes patches from pynvjitlink, so no further patches are needed." +else: _numba_cuda_in_use = False - _numba_cuda_error = "`numba_cuda` package is not installed. Pynvjitlink should only be used with `numba_cuda` package." class LinkableCode: @@ -289,7 +271,7 @@ def patch_numba_linker(*, lto=False): msg = f"Cannot patch Numba: {_numba_error}" raise RuntimeError(msg) - if not _numba_cuda_in_use or not _numba_cuda_version_ok: + if _numba_cuda_in_use: msg = f"Cannot patch Numba: {_numba_cuda_error}" raise RuntimeError(msg)