Skip to content

Commit 18b76a9

Browse files
committed
Address comments
Signed-off-by: Hui Gao <[email protected]>
1 parent bc94a21 commit 18b76a9

File tree

3 files changed

+51
-53
lines changed

3 files changed

+51
-53
lines changed

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from tensorrt_llm.logger import logger
88

9-
from .utils import get_shared_pool
10-
119

1210
@dataclass
1311
class BufferBlock:
@@ -115,3 +113,52 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
115113
def get_memory_buffers():
    """Return the module-level buffer registry.

    Returns:
        The shared ``_buffer`` object this module uses for buffer
        bookkeeping (see the definitions earlier in the file).
    """
    # Reading a module global needs no `global` declaration.
    return _buffer
116+
117+
118+
# Process-wide shared memory pool used by the allocation helpers below.
# Stays None until set_shared_pool() installs a pool — presumably a CUDA
# memory pool object (see set_shared_pool's docstring); confirm with callers.
_shared_pool = None
119+
120+
121+
def set_shared_pool(shared_pool):
    """Install *shared_pool* as the module-wide allocation pool.

    Subsequent calls to :func:`get_shared_pool` return this object until
    it is replaced by another call here.

    Args:
        shared_pool: A CUDA memory pool object to use for allocations.
    """
    global _shared_pool
    _shared_pool = shared_pool
129+
130+
131+
def get_shared_pool():
    """Return the pool installed via :func:`set_shared_pool`.

    Returns:
        The currently installed memory pool, or ``None`` when no pool
        has been set yet.
    """
    return _shared_pool
138+
139+
140+
@contextlib.contextmanager
def with_shared_pool(shared_pool) -> contextlib.AbstractContextManager:
    """Scope *shared_pool* as the active memory pool for the duration of the block.

    On entry the given pool replaces the currently installed one; on exit the
    previous pool is reinstalled — even when the body raises — which makes this
    safe for temporarily switching pools around CUDA graph operations.

    Args:
        shared_pool: The memory pool to use within the context.

    Yields:
        None

    Example:
        >>> with with_shared_pool(shared_pool):
        ...     # Allocations within this block use shared_pool
        ...     tensor = allocate_buffer(...)
    """
    previous_pool = get_shared_pool()
    set_shared_pool(shared_pool)
    try:
        yield
    finally:
        # Restore unconditionally so an exception in the body cannot leak
        # the temporary pool into later allocations.
        set_shared_pool(previous_pool)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ..distributed import MPIDist
3535
from ..distributed.communicator import init_pp_comm
3636
from ..expert_statistic import ExpertStatistic
37+
from ..memory_buffer_utils import with_shared_pool
3738
from ..metadata import KVCacheParams
3839
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
3940
from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
@@ -48,8 +49,7 @@
4849
from ..speculative.mtp import SampleStateTensorsMTP
4950
from ..utils import (get_model_extra_attrs,
5051
set_per_request_piecewise_cuda_graph_flag,
51-
set_torch_compiling, with_model_extra_attrs,
52-
with_shared_pool)
52+
set_torch_compiling, with_model_extra_attrs)
5353
from .config import PyTorchConfig
5454
from .config_utils import is_mla
5555
from .cuda_graph_runner import CUDAGraphRunner

tensorrt_llm/_torch/utils.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -314,52 +314,3 @@ def get_device_uuid(device_idx: int) -> str:
314314
property = torch.cuda.get_device_properties(device_idx)
315315
uuid = "GPU-" + str(property.uuid)
316316
return uuid
317-
318-
319-
# Module-wide memory pool consulted by the helpers below; None until
# set_shared_pool() installs one. NOTE(review): presumably a CUDA memory
# pool object — confirm against callers.
_buffer_pool = None
320-
321-
322-
def set_shared_pool(buffer_pool):
    """Install *buffer_pool* as the module-wide allocation pool.

    Later calls to :func:`get_shared_pool` return this object until it is
    replaced by another call here.

    Args:
        buffer_pool: A CUDA memory pool object to use for allocations.
    """
    global _buffer_pool
    _buffer_pool = buffer_pool
330-
331-
332-
def get_shared_pool():
    """Return the pool installed via :func:`set_shared_pool`.

    Returns:
        The currently installed memory pool, or ``None`` when no pool
        has been set yet.
    """
    return _buffer_pool
339-
340-
341-
@contextlib.contextmanager
def with_shared_pool(buffer_pool) -> contextlib.AbstractContextManager:
    """Temporarily sets a preferred memory pool and restores the previous one on exit.

    This context manager allows temporarily switching to a different memory pool
    for CUDA graph operations, ensuring the original pool is restored even if
    an exception occurs.

    Args:
        buffer_pool: The memory pool to use within the context.

    Yields:
        None

    Example:
        >>> with with_shared_pool(buffer_pool):
        ...     # Allocations within this block use buffer_pool
        ...     tensor = allocate_buffer(...)
    """
    old_buffer_pool = get_shared_pool()
    # Fixed: the original called set_shared_pool(mem_pool), but no `mem_pool`
    # name exists in this scope — the parameter is `buffer_pool` — so entering
    # the context raised NameError. The docstring also documented `mem_pool`.
    set_shared_pool(buffer_pool)
    try:
        yield
    finally:
        # Restore unconditionally so an exception in the body cannot leak
        # the temporary pool into later allocations.
        set_shared_pool(old_buffer_pool)

0 commit comments

Comments (0)