Skip to content

Commit 335acc3

Browse files
committed
Reuse memory_pool
Signed-off-by: Hui Gao <[email protected]>
1 parent 9298f1b commit 335acc3

File tree

5 files changed

+272
-28
lines changed

5 files changed

+272
-28
lines changed

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
import torch
66

7+
from tensorrt_llm.logger import logger
8+
9+
from .utils import get_shared_pool
10+
711

812
@dataclass
913
class BufferBlock:
@@ -80,9 +84,22 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
8084

8185
# No suitable buffer was found, so allocate a new one.
8286
# The new buffer is created with uint8 to represent raw bytes.
83-
new_buffer_tensor = torch.zeros((required_memory_size, ),
84-
device='cuda',
85-
dtype=torch.uint8)
87+
new_buffer_tensor = None
88+
try:
89+
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
90+
new_buffer_tensor = torch.zeros((required_memory_size, ),
91+
device='cuda',
92+
dtype=torch.uint8)
93+
except Exception as ex:
94+
# Need to check if this is an OOM exception
95+
logger.debug(
96+
f"Exception happened to create tensor from given memory pool: {str(ex)}"
97+
)
98+
# if exception happens during allocating memory from
99+
new_buffer_tensor = torch.zeros((required_memory_size, ),
100+
device='cuda',
101+
dtype=torch.uint8)
102+
86103
new_block = BufferBlock(buffer=new_buffer_tensor,
87104
is_reserved=reserve_buffer)
88105

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,15 @@ def needs_capture(self, key: Tuple[int, int, int]):
194194

195195
return key not in self.graph_outputs
196196

197+
def get_graph_pool(self):
198+
"""Returns the CUDA memory pool used by this graph runner.
199+
200+
Returns:
201+
The CUDA memory pool associated with captured graphs, or None if
202+
no graphs have been captured yet.
203+
"""
204+
return self.memory_pool
205+
197206
def capture(self,
198207
key: Tuple[int, int, int],
199208
forward_fn: Callable,
@@ -255,6 +264,7 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
255264
capture_inputs)
256265
if postprocess_fn is not None:
257266
postprocess_fn(capture_inputs)
267+
258268
with torch.cuda.graph(graph, pool=self.memory_pool):
259269
output = _setup_spec_decoding_and_forward(
260270
key, forward_fn, capture_inputs)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@
4848
from ..speculative.mtp import SampleStateTensorsMTP
4949
from ..utils import (get_model_extra_attrs,
5050
set_per_request_piecewise_cuda_graph_flag,
51-
set_torch_compiling, with_model_extra_attrs)
51+
set_shared_mem_pool, set_torch_compiling,
52+
with_model_extra_attrs)
5253
from .config import PyTorchConfig
5354
from .config_utils import is_mla
5455
from .cuda_graph_runner import CUDAGraphRunner
@@ -2196,35 +2197,35 @@ def forward(
21962197
new_tensors_device, cache_indirection_buffer)
21972198

21982199
self.iter_counter += 1
2200+
with set_shared_mem_pool(self.cuda_graph_runner.get_graph_pool()):
2201+
if not maybe_graph:
2202+
# Fallback to eager execution if graph was not used
2203+
with MoeLoadBalancerIterContext(moe_load_balancer):
2204+
outputs = self._forward_step(inputs, gather_ids,
2205+
gather_context_logits)
2206+
else:
2207+
if self.cuda_graph_runner.needs_capture(key):
21992208

2200-
if not maybe_graph:
2201-
# Fallback to eager execution if graph was not used
2202-
with MoeLoadBalancerIterContext(moe_load_balancer):
2203-
outputs = self._forward_step(inputs, gather_ids,
2204-
gather_context_logits)
2205-
else:
2206-
if self.cuda_graph_runner.needs_capture(key):
2207-
2208-
def capture_forward_fn(inputs: Dict[str, Any]):
2209-
with MoeLoadBalancerIterContext(moe_load_balancer):
2210-
return self._forward_step(
2211-
inputs,
2212-
gather_ids=gather_ids,
2213-
gather_context_logits=gather_context_logits)
2209+
def capture_forward_fn(inputs: Dict[str, Any]):
2210+
with MoeLoadBalancerIterContext(moe_load_balancer):
2211+
return self._forward_step(
2212+
inputs,
2213+
gather_ids=gather_ids,
2214+
gather_context_logits=gather_context_logits)
22142215

2215-
def capture_postprocess_fn(inputs: Dict[str, Any]):
2216-
self._postprocess_inputs(inputs)
2216+
def capture_postprocess_fn(inputs: Dict[str, Any]):
2217+
self._postprocess_inputs(inputs)
22172218

2218-
self.cuda_graph_runner.capture(key, capture_forward_fn,
2219-
inputs,
2220-
capture_postprocess_fn)
2219+
self.cuda_graph_runner.capture(key, capture_forward_fn,
2220+
inputs,
2221+
capture_postprocess_fn)
22212222

2222-
# here we don't need to use context since cuda graph capture didn't run kernel.
2223-
# maybe we need a cleaner way to do this.
2224-
outputs = self.cuda_graph_runner.replay(key, inputs)
2225-
else:
2226-
with MoeLoadBalancerIterContext(moe_load_balancer):
2223+
# here we don't need to use context since cuda graph capture didn't run kernel.
2224+
# maybe we need a cleaner way to do this.
22272225
outputs = self.cuda_graph_runner.replay(key, inputs)
2226+
else:
2227+
with MoeLoadBalancerIterContext(moe_load_balancer):
2228+
outputs = self.cuda_graph_runner.replay(key, inputs)
22282229

22292230
self._execute_logit_post_processors(scheduled_requests, outputs)
22302231

tensorrt_llm/_torch/utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,53 @@ def get_device_uuid(device_idx: int) -> str:
320320
property = torch.cuda.get_device_properties(device_idx)
321321
uuid = "GPU-" + str(property.uuid)
322322
return uuid
323+
324+
325+
_buffer_pool = None
326+
327+
328+
def set_shared_pool(buffer_pool):
329+
"""Sets the global memory pool for buffer allocation.
330+
331+
Args:
332+
buffer_pool: A CUDA memory pool object to use for allocations.
333+
"""
334+
global _buffer_pool
335+
_buffer_pool = buffer_pool
336+
337+
338+
def get_shared_pool():
339+
"""Retrieves the current global memory pool.
340+
341+
Returns:
342+
The current memory pool, or None if not set.
343+
"""
344+
global _buffer_pool
345+
return _buffer_pool
346+
347+
348+
@contextlib.contextmanager
349+
def set_shared_mem_pool(mem_pool) -> contextlib.AbstractContextManager:
350+
"""Temporarily sets a preferred memory pool and restores the previous one on exit.
351+
352+
This context manager allows temporarily switching to a different memory pool
353+
for CUDA graph operations, ensuring the original pool is restored even if
354+
an exception occurs.
355+
356+
Args:
357+
mem_pool: The memory pool to use within the context.
358+
359+
Yields:
360+
None
361+
362+
Example:
363+
>>> with set_shared_mem_pool(buffer_pool):
364+
... # Allocations within this block use buffer_pool
365+
... tensor = allocate_buffer(...)
366+
"""
367+
old_buffer_pool = get_shared_pool()
368+
set_shared_pool(mem_pool)
369+
try:
370+
yield
371+
finally:
372+
set_shared_pool(old_buffer_pool)
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#!/bin/bash
2+
3+
set -ex
4+
5+
TRT_VER="10.13.2.6"
6+
# Align with the pre-installed cuDNN / cuBLAS / NCCL versions from
7+
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-08.html#rel-25-08
8+
CUDA_VER="13.0" # 13.0.0
9+
# Keep the installation for cuDNN if users want to install PyTorch with source codes.
10+
# PyTorch 2.x can compile with cuDNN v9.
11+
CUDNN_VER="9.12.0.46-1"
12+
# NCCL version 2.26.x used in the NGC PyTorch 25.05 image but has a performance regression issue.
13+
# Use NCCL version 2.27.5 which has the fixes.
14+
NCCL_VER="2.27.7-1+cuda13.0"
15+
# Use cuBLAS version 13.0.0.19 instead.
16+
CUBLAS_VER="13.0.0.19-1"
17+
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
18+
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
19+
NVRTC_VER="13.0.48-1"
20+
CUDA_RUNTIME="13.0.48-1"
21+
CUDA_DRIVER_VERSION="580.65.06-1.el8"
22+
23+
for i in "$@"; do
24+
case $i in
25+
--TRT_VER=?*) TRT_VER="${i#*=}";;
26+
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
27+
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
28+
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
29+
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
30+
*) ;;
31+
esac
32+
shift
33+
done
34+
35+
NVCC_VERSION_OUTPUT=$(nvcc --version)
36+
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
37+
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
38+
fi
39+
40+
install_ubuntu_requirements() {
41+
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
42+
ARCH=$(uname -m)
43+
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
44+
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
45+
46+
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
47+
dpkg -i cuda-keyring_1.1-1_all.deb
48+
rm cuda-keyring_1.1-1_all.deb
49+
50+
apt-get update
51+
if [[ $(apt list --installed | grep libcudnn9) ]]; then
52+
apt-get remove --purge -y libcudnn9*
53+
fi
54+
if [[ $(apt list --installed | grep libnccl) ]]; then
55+
apt-get remove --purge -y --allow-change-held-packages libnccl*
56+
fi
57+
if [[ $(apt list --installed | grep libcublas) ]]; then
58+
apt-get remove --purge -y --allow-change-held-packages libcublas*
59+
fi
60+
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
61+
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
62+
fi
63+
64+
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
65+
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
66+
67+
apt-get install -y --no-install-recommends \
68+
libcudnn9-cuda-13=${CUDNN_VER} \
69+
libcudnn9-dev-cuda-13=${CUDNN_VER} \
70+
libcudnn9-headers-cuda-13=${CUDNN_VER} \
71+
libnccl2=${NCCL_VER} \
72+
libnccl-dev=${NCCL_VER} \
73+
libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} \
74+
libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} \
75+
cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
76+
77+
apt-get clean
78+
rm -rf /var/lib/apt/lists/*
79+
}
80+
81+
install_rockylinux_requirements() {
82+
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
83+
84+
ARCH=$(uname -m)
85+
if [ "$ARCH" = "x86_64" ];then ARCH1="x86_64" && ARCH2="x64" && ARCH3=$ARCH1;fi
86+
if [ "$ARCH" = "aarch64" ];then ARCH1="aarch64" && ARCH2="aarch64sbsa" && ARCH3="sbsa";fi
87+
88+
# Download and install packages
89+
for pkg in \
90+
"libnccl-${NCCL_VER}.${ARCH1}" \
91+
"libnccl-devel-${NCCL_VER}.${ARCH1}" \
92+
"cuda-compat-${CUBLAS_CUDA_VERSION}-${CUDA_DRIVER_VERSION}.${ARCH1}" \
93+
"cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch" \
94+
"cuda-toolkit-13-config-common-${CUDA_RUNTIME}.noarch" \
95+
"cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch" \
96+
"libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}" \
97+
"libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}"; do
98+
wget --retry-connrefused --timeout=180 --tries=10 --continue "https://developer.download.nvidia.com/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm"
99+
done
100+
101+
# Remove old packages
102+
dnf remove -y "libnccl*"
103+
104+
# Install new packages
105+
dnf -y install \
106+
libnccl-${NCCL_VER}.${ARCH1}.rpm \
107+
libnccl-devel-${NCCL_VER}.${ARCH1}.rpm \
108+
cuda-compat-${CUBLAS_CUDA_VERSION}-${CUDA_DRIVER_VERSION}.${ARCH1}.rpm \
109+
cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch.rpm \
110+
cuda-toolkit-13-config-common-${CUDA_RUNTIME}.noarch.rpm \
111+
cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch.rpm \
112+
libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}.rpm \
113+
libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}.rpm
114+
115+
# Clean up
116+
rm -f *.rpm
117+
dnf clean all
118+
nvcc --version
119+
}
120+
121+
install_tensorrt() {
122+
PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
123+
PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
124+
TRT_CUDA_VERSION=${CUDA_VER}
125+
TRT_VER_SHORT=$(echo $TRT_VER | cut -d. -f1-3)
126+
127+
if [ -z "$RELEASE_URL_TRT" ];then
128+
ARCH=${TRT_TARGETARCH}
129+
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
130+
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
131+
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
132+
RELEASE_URL_TRT="https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_SHORT}/tars/TensorRT-${TRT_VER}.Linux.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz"
133+
fi
134+
135+
wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
136+
tar -xf /tmp/TensorRT.tar -C /usr/local/
137+
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
138+
pip3 install --no-cache-dir /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
139+
rm -rf /tmp/TensorRT.tar
140+
echo 'export LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH' >> "${ENV}"
141+
142+
rm -f /usr/local/tensorrt/lib/libnvinfer_vc_plugin_static.a \
143+
/usr/local/tensorrt/lib/libnvinfer_plugin_static.a \
144+
/usr/local/tensorrt/lib/libnvinfer_static.a \
145+
/usr/local/tensorrt/lib/libnvinfer_dispatch_static.a \
146+
/usr/local/tensorrt/lib/libnvinfer_lean_static.a \
147+
/usr/local/tensorrt/lib/libnvonnxparser_static.a \
148+
/usr/local/tensorrt/lib/libnvinfer_builder_resource_win.so.10.10.0
149+
}
150+
151+
# Install base packages depending on the base OS
152+
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
153+
case "$ID" in
154+
ubuntu)
155+
install_ubuntu_requirements
156+
install_tensorrt
157+
;;
158+
rocky)
159+
install_rockylinux_requirements
160+
install_tensorrt
161+
;;
162+
*)
163+
echo "Unable to determine OS..."
164+
exit 1
165+
;;
166+
esac

0 commit comments

Comments
 (0)