Commit 0713ec6

Add comments for new code
Signed-off-by: Hui Gao <[email protected]>
1 parent 510eb15

4 files changed: 46 additions, 11 deletions

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 6 additions & 4 deletions
@@ -6,7 +6,7 @@
 
 from tensorrt_llm.logger import logger
 
-from .utils import get_graph_pool
+from .utils import get_shared_pool
 
 
 @dataclass
@@ -86,14 +86,16 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
         # The new buffer is created with uint8 to represent raw bytes.
         new_buffer_tensor = None
         try:
-            with torch.cuda.memory.use_mem_pool(get_graph_pool()):
+            with torch.cuda.memory.use_mem_pool(get_shared_pool()):
                 new_buffer_tensor = torch.zeros((required_memory_size, ),
                                                 device='cuda',
                                                 dtype=torch.uint8)
-        except Exception:
+        except Exception as ex:
             # Need to check if this is an OOM exception
             logger.debug(
-                f"Exception happened to create tensor from given memory pool")
+                f"Exception happened while creating tensor from the given memory pool: {str(ex)}"
+            )
+            # If allocating from the shared pool fails, fall back to the default CUDA memory pool.
             new_buffer_tensor = torch.zeros((required_memory_size, ),
                                             device='cuda',
                                             dtype=torch.uint8)
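
Note: the hunk above implements a try-the-pool-then-fall-back allocation pattern. Below is a self-contained sketch of that pattern under stated assumptions: allocate_with_fallback is a hypothetical helper (not part of this commit), and torch.cuda.MemPool plus torch.cuda.memory.use_mem_pool require a recent PyTorch (>= 2.5) with a CUDA device available.

import torch

def allocate_with_fallback(num_bytes: int, pool) -> torch.Tensor:
    # Prefer the shared pool when one is set; otherwise (or on failure)
    # allocate from the default CUDA caching allocator, mirroring the
    # except branch in the diff above.
    if pool is not None:
        try:
            with torch.cuda.memory.use_mem_pool(pool):
                return torch.zeros((num_bytes, ), device='cuda', dtype=torch.uint8)
        except Exception as ex:
            print(f"Pool allocation failed, falling back to default allocator: {ex}")
    return torch.zeros((num_bytes, ), device='cuda', dtype=torch.uint8)

shared_pool = torch.cuda.MemPool()  # assumed API, PyTorch >= 2.5
buf = allocate_with_fallback(1 << 20, shared_pool)  # 1 MiB of raw bytes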

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 6 additions & 0 deletions
@@ -195,6 +195,12 @@ def needs_capture(self, key: Tuple[int, int, int]):
         return key not in self.graph_outputs
 
     def get_graph_pool(self):
+        """Returns the CUDA memory pool used by this graph runner.
+
+        Returns:
+            The CUDA memory pool associated with captured graphs, or None if
+            no graphs have been captured yet.
+        """
         return self.memory_pool
 
     def capture(self,
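
As a rough illustration of where such a pool handle comes from (not this runner's actual capture path), PyTorch's public API exposes the memory pool of a captured graph via CUDAGraph.pool(); a minimal sketch:

import torch

g = torch.cuda.CUDAGraph()
static_x = torch.randn(8, device='cuda')

# Warm-up pass on a side stream before capture (standard CUDA graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x * 2
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g):
    static_y = static_x * 2

# Opaque handle to the pool backing the captured graph; a runner would cache
# this and return it from get_graph_pool(), or None before any capture.
memory_pool = g.pool()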

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
-                     set_prefer_mem_pool, set_torch_compiling,
+                     set_shared_mem_pool, set_torch_compiling,
                      with_model_extra_attrs)
 from .config import PyTorchConfig
 from .config_utils import is_mla
@@ -2187,7 +2187,7 @@ def forward(
                 new_tensors_device, cache_indirection_buffer)
 
         self.iter_counter += 1
-        with set_prefer_mem_pool(self.cuda_graph_runner.get_graph_pool()):
+        with set_shared_mem_pool(self.cuda_graph_runner.get_graph_pool()):
             if not maybe_graph:
                 # Fallback to eager execution if graph was not used
                 with MoeLoadBalancerIterContext(moe_load_balancer):
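
The call-site change is mechanical: forward() now scopes the runner's graph pool via the renamed context manager so that downstream buffer requests (like get_buffer above) prefer it. A minimal sketch of the pattern, with runner and run_model as hypothetical stand-ins:

from tensorrt_llm._torch.utils import set_shared_mem_pool

def forward_step(runner, run_model):
    # get_graph_pool() may return None before the first graph capture;
    # get_buffer() then simply falls back to the default allocator.
    with set_shared_mem_pool(runner.get_graph_pool()):
        return run_model()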

tensorrt_llm/_torch/utils.py

Lines changed: 32 additions & 5 deletions
@@ -317,20 +317,47 @@ def _get_allow_chain_drafter() -> bool:
 _buffer_pool = None
 
 
-def set_mem_pool(buffer_pool):
+def set_shared_pool(buffer_pool):
+    """Sets the global memory pool for buffer allocation.
+
+    Args:
+        buffer_pool: A CUDA memory pool object to use for allocations.
+    """
     global _buffer_pool
     _buffer_pool = buffer_pool
 
 
-def get_graph_pool():
+def get_shared_pool():
+    """Retrieves the current global memory pool.
+
+    Returns:
+        The current memory pool, or None if not set.
+    """
     global _buffer_pool
     return _buffer_pool
 
 
 @contextlib.contextmanager
-def set_prefer_mem_pool(mem_pool):
-    old_buffer_pool = get_graph_pool()
-    set_mem_pool(mem_pool)
+def set_shared_mem_pool(mem_pool) -> contextlib.AbstractContextManager:
+    """Temporarily sets a preferred memory pool and restores the previous one on exit.
+
+    This context manager allows temporarily switching to a different memory pool
+    for CUDA graph operations, ensuring the original pool is restored even if
+    an exception occurs.
+
+    Args:
+        mem_pool: The memory pool to use within the context.
+
+    Yields:
+        None
+
+    Example:
+        >>> with set_shared_mem_pool(graph_pool):
+        ...     # Allocations within this block use graph_pool
+        ...     tensor = allocate_buffer(...)
+    """
+    old_buffer_pool = get_shared_pool()
+    set_shared_pool(mem_pool)
     try:
         yield
     finally:
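
The hunk's trailing context ends inside the try/finally, but the restore-on-exception behavior the docstring promises can be demonstrated in isolation. A runnable sketch that re-implements the trio, with plain strings standing in for pool objects:

import contextlib

_pool = None

def set_shared_pool(p):
    global _pool
    _pool = p

def get_shared_pool():
    return _pool

@contextlib.contextmanager
def set_shared_mem_pool(p):
    old = get_shared_pool()
    set_shared_pool(p)
    try:
        yield
    finally:
        set_shared_pool(old)  # restored even if the body raised

set_shared_pool("default")
try:
    with set_shared_mem_pool("graph-pool"):
        assert get_shared_pool() == "graph-pool"
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass
assert get_shared_pool() == "default"  # previous pool restored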
