diff --git a/src/bridge/gs_madrona/renderer_gs.py b/src/bridge/gs_madrona/renderer_gs.py index cf664ce0..af27522b 100644 --- a/src/bridge/gs_madrona/renderer_gs.py +++ b/src/bridge/gs_madrona/renderer_gs.py @@ -1,5 +1,6 @@ import os import ctypes +from importlib.metadata import distribution, PackageNotFoundError from pathlib import Path from typing import Tuple @@ -62,11 +63,12 @@ def __init__( # Preload Nvidia compiler runtime if available (i.e. torch is not built from source) try: - import nvidia.cuda_nvrtc - nvrtc_dir = Path(nvidia.cuda_nvrtc.__file__).parent.absolute() - libnvrtc_path, *_ = filter(Path.is_file, (nvrtc_dir / "lib").glob("libnvrtc.so.1*")) - ctypes.CDLL(libnvrtc_path, ctypes.RTLD_LOCAL) - except ImportError: + dist = distribution("nvidia_cuda_nvrtc_cu12") + for file in dist.files: + if file.name.startswith("libnvrtc.so.1"): + ctypes.CDLL(dist.locate_file(file), ctypes.RTLD_LOCAL) + break + except PackageNotFoundError: pass self.madrona = MadronaBatchRenderer(