8
8
import tempfile
9
9
import urllib .request
10
10
import warnings
11
- from contextlib import contextmanager
12
11
from dataclasses import fields , replace
13
12
from enum import Enum
14
13
from pathlib import Path
15
14
from typing import (
16
15
Any ,
17
16
Callable ,
18
17
Dict ,
19
- Iterator ,
20
18
List ,
21
19
Optional ,
22
20
Sequence ,
@@ -864,40 +862,52 @@ def is_platform_supported_for_trtllm(platform: str) -> bool:
864
862
return True
865
863
866
864
867
- @contextmanager
868
- def download_plugin_lib_path (platform : str ) -> Iterator [str ]:
869
- """
870
- Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
871
- then yields the path to the extracted shared library (.so or .dll).
865
def _cache_root() -> Path:
    """Return the per-user cache directory used for TensorRT-LLM artifacts.

    The directory is ``<system temp dir>/torch_tensorrt_<username>``, e.g.
    ``/tmp/torch_tensorrt_alice``. It is only computed here, not created.

    Returns:
        Path: The cache root for the current user.
    """
    try:
        username = getpass.getuser()
    except (KeyError, ImportError, OSError):
        # getpass.getuser() raises when no login name can be determined, e.g.
        # a container running under a UID with no passwd entry and no
        # LOGNAME/USER variables. Fall back to a UID-derived name so the
        # cache path can still be computed instead of crashing.
        import os

        getuid = getattr(os, "getuid", None)  # not available on Windows
        username = f"uid{getuid()}" if getuid is not None else "unknown"
    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
872
868
873
- The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
874
- Extraction happens in a temporary directory that is cleaned up after use.
875
869
876
- Args :
877
- platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
870
def _extracted_dir_trtllm(platform: str) -> Path:
    """Return the directory into which the TensorRT-LLM wheel is extracted.

    The location is ``<cache root>/trtllm/<trtllm version>_<platform>``, e.g.
    ``/tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64``.
    The directory is only computed here, not created.

    Args:
        platform (str): Platform identifier (e.g., 'linux_x86_64').

    Returns:
        Path: Version- and platform-specific extraction directory.
    """
    # Keying the directory on version and platform keeps extractions of
    # different wheels from sharing a directory.
    return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"
878
872
879
- Yields:
880
- str: The full path to the extracted TensorRT-LLM shared library file.
881
873
882
- Raises:
883
- ImportError: If the 'zipfile' module is not available.
884
- RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
874
+ def download_and_get_plugin_lib_path (platform : str ) -> Optional [str ]:
885
875
"""
886
- plugin_lib_path = None
887
- username = getpass .getuser ()
888
- torchtrt_cache_dir = Path (tempfile .gettempdir ()) / f"torch_tensorrt_{ username } "
889
- torchtrt_cache_dir .mkdir (parents = True , exist_ok = True )
890
- file_name = f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
891
- torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
892
- downloaded_file_path = torchtrt_cache_trtllm_whl
893
-
894
- if not torchtrt_cache_trtllm_whl .exists ():
895
- # Downloading TRT-LLM lib
876
+ Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
877
+
878
+ Args:
879
+ platform (str): Platform identifier (e.g., 'linux_x86_64')
880
+
881
+ Returns:
882
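+ Optional[str]: Path to the shared library, or None if the library is not found at the expected location after extraction. A missing/corrupted wheel or an extraction failure raises RuntimeError instead of returning None.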
883
+ """
884
+ wheel_filename = (
885
+ f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -"
886
+ f"{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
887
+ )
888
+ wheel_path = _cache_root () / wheel_filename
889
+ extract_dir = _extracted_dir_trtllm (platform )
890
+ # NOTE: the non-Linux (.dll) fallback below is never expected to be taken
891
+ lib_filename = (
892
+ "libnvinfer_plugin_tensorrt_llm.so"
893
+ if "linux" in platform
894
+ else "libnvinfer_plugin_tensorrt_llm.dll"
895
+ )
896
+ # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
897
+ plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename
898
+
899
+ if plugin_lib_path .exists ():
900
+ return str (plugin_lib_path )
901
+
902
+ wheel_path .parent .mkdir (parents = True , exist_ok = True )
903
+ extract_dir .mkdir (parents = True , exist_ok = True )
904
+
905
+ if not wheel_path .exists ():
896
906
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
897
- download_url = base_url + file_name
907
+ download_url = base_url + wheel_filename
898
908
try :
899
- logger .debug (f "Downloading { download_url } ..." )
900
- urllib .request .urlretrieve (download_url , downloaded_file_path )
909
+ logger .debug ("Downloading %s ..." , download_url )
910
+ urllib .request .urlretrieve (download_url , wheel_path )
901
911
logger .debug ("Download succeeded and TRT-LLM wheel is now present" )
902
912
except urllib .error .HTTPError as e :
903
913
logger .error (
@@ -910,41 +920,45 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
910
920
except OSError as e :
911
921
logger .error (f"Local file write error: { e } " )
912
922
913
- # Proceeding with the unzip of the wheel file in tmpdir
914
- if "linux" in platform :
915
- lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
916
- else :
917
- # This condition is never met though
918
- lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
923
+ try :
924
+ import zipfile
925
+ except ImportError as e :
926
+ raise ImportError (
927
+ "zipfile module is required but not found. Please install zipfile"
928
+ )
929
+ try :
930
+ with zipfile .ZipFile (wheel_path ) as zip_ref :
931
+ zip_ref .extractall (extract_dir )
932
+ logger .debug (f"Extracted wheel to { extract_dir } " )
933
+ except FileNotFoundError as e :
934
+ # This should capture the errors in the download failure above
935
+ logger .error (f"Wheel file not found at { wheel_path } : { e } " )
936
+ raise RuntimeError (
937
+ f"Failed to find downloaded wheel file at { wheel_path } "
938
+ ) from e
939
+ except zipfile .BadZipFile as e :
940
+ logger .error (f"Invalid or corrupted wheel file: { e } " )
941
+ raise RuntimeError (
942
+ "Downloaded wheel file is corrupted or not a valid zip archive"
943
+ ) from e
944
+ except Exception as e :
945
+ logger .error (f"Unexpected error while extracting wheel: { e } " )
946
+ raise RuntimeError (
947
+ "Unexpected error during extraction of TensorRT-LLM wheel"
948
+ ) from e
919
949
920
- with tempfile .TemporaryDirectory () as tmpdir :
921
- try :
922
- import zipfile
923
- except ImportError :
924
- raise ImportError (
925
- "zipfile module is required but not found. Please install zipfile"
926
- )
927
- try :
928
- with zipfile .ZipFile (downloaded_file_path , "r" ) as zip_ref :
929
- zip_ref .extractall (tmpdir ) # Extract to a folder named 'tensorrt_llm'
930
- except FileNotFoundError as e :
931
- # This should capture the errors in the download failure above
932
- logger .error (f"Wheel file not found at { downloaded_file_path } : { e } " )
933
- raise RuntimeError (
934
- f"Failed to find downloaded wheel file at { downloaded_file_path } "
935
- ) from e
936
- except zipfile .BadZipFile as e :
937
- logger .error (f"Invalid or corrupted wheel file: { e } " )
938
- raise RuntimeError (
939
- "Downloaded wheel file is corrupted or not a valid zip archive"
940
- ) from e
941
- except Exception as e :
942
- logger .error (f"Unexpected error while extracting wheel: { e } " )
943
- raise RuntimeError (
944
- "Unexpected error during extraction of TensorRT-LLM wheel"
945
- ) from e
946
- plugin_lib_path = os .path .join (tmpdir , "tensorrt_llm/libs" , lib_filename )
947
- yield plugin_lib_path
950
+ try :
951
+ wheel_path .unlink (missing_ok = True )
952
+ logger .debug (f"Deleted wheel file: { wheel_path } " )
953
+ except Exception as e :
954
+ logger .warning (f"Could not delete wheel file { wheel_path } : { e } " )
955
+ if not plugin_lib_path .exists ():
956
+ logger .error (
957
+ f"Plugin library not found at expected location: { plugin_lib_path } "
958
+ )
959
+ return None
960
+
961
+ return str (plugin_lib_path )
948
962
949
963
950
964
def load_and_initialize_trtllm_plugin (plugin_lib_path : str ) -> bool :
@@ -1034,6 +1048,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
1034
1048
)
1035
1049
return False
1036
1050
1037
- with download_plugin_lib_path (platform ) as plugin_lib_path :
1038
- return load_and_initialize_trtllm_plugin (plugin_lib_path )
1051
+ plugin_lib_path = download_and_get_plugin_lib_path (platform )
1052
+ return load_and_initialize_trtllm_plugin (plugin_lib_path ) # type: ignore[arg-type]
1039
1053
return False
0 commit comments