Skip to content

Commit dc1cf86

Browse files
HuiGao-NVyufeiwu-nv
authored andcommitted
[None][feat] reuse cudagraph memory pool in normal forward flow (NVIDIA#8095)
Signed-off-by: Hui Gao <[email protected]> Signed-off-by: yufeiwu-nv <[email protected]>
1 parent 4656c4a commit dc1cf86

File tree

3 files changed

+104
-27
lines changed

3 files changed

+104
-27
lines changed

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import contextlib
12
import math
23
from dataclasses import dataclass
34
from typing import Optional
45

56
import torch
67

8+
from tensorrt_llm.logger import logger
9+
710

811
@dataclass
912
class BufferBlock:
@@ -80,9 +83,23 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
8083

8184
# No suitable buffer was found, so allocate a new one.
8285
# The new buffer is created with uint8 to represent raw bytes.
83-
new_buffer_tensor = torch.zeros((required_memory_size, ),
84-
device='cuda',
85-
dtype=torch.uint8)
86+
new_buffer_tensor = None
87+
try:
88+
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
89+
new_buffer_tensor = torch.zeros((required_memory_size, ),
90+
device='cuda',
91+
dtype=torch.uint8)
92+
except Exception as ex:
93+
# Need to check if this is an OOM exception
94+
logger.debug(
95+
f"Exception happened to create tensor from given memory pool: {str(ex)}"
96+
)
97+
# if exception happens during allocating memory from shared pool, retry
98+
# to allocate from default pool
99+
new_buffer_tensor = torch.zeros((required_memory_size, ),
100+
device='cuda',
101+
dtype=torch.uint8)
102+
86103
new_block = BufferBlock(buffer=new_buffer_tensor,
87104
is_reserved=reserve_buffer)
88105

@@ -97,3 +114,52 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
97114
def get_memory_buffers():
98115
global _buffer
99116
return _buffer
117+
118+
119+
_shared_pool = None
120+
121+
122+
def set_shared_pool(shared_pool):
123+
"""Sets the global memory pool for buffer allocation.
124+
125+
Args:
126+
shared_pool: A CUDA memory pool object to use for allocations.
127+
"""
128+
global _shared_pool
129+
_shared_pool = shared_pool
130+
131+
132+
def get_shared_pool():
133+
"""Retrieves the current global memory pool.
134+
135+
Returns:
136+
The current memory pool, or None if not set.
137+
"""
138+
return _shared_pool
139+
140+
141+
@contextlib.contextmanager
142+
def with_shared_pool(shared_pool) -> contextlib.AbstractContextManager:
143+
"""Temporarily sets a preferred memory pool and restores the previous one on exit.
144+
145+
This context manager allows temporarily switching to a different memory pool
146+
for CUDA graph operations, ensuring the original pool is restored even if
147+
an exception occurs.
148+
149+
Args:
150+
shared_pool: The memory pool to use within the context.
151+
152+
Yields:
153+
None
154+
155+
Example:
156+
>>> with with_shared_pool(shared_pool):
157+
... # Allocations within this block use shared_pool
158+
... tensor = allocate_buffer(...)
159+
"""
160+
old_shared_pool = get_shared_pool()
161+
set_shared_pool(shared_pool)
162+
try:
163+
yield
164+
finally:
165+
set_shared_pool(old_shared_pool)

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,15 @@ def needs_capture(self, key: Tuple[int, int, int]):
194194

195195
return key not in self.graph_outputs
196196

197+
def get_graph_pool(self):
198+
"""Returns the CUDA memory pool used by this graph runner.
199+
200+
Returns:
201+
The CUDA memory pool associated with captured graphs, or None if
202+
no graphs have been captured yet.
203+
"""
204+
return self.memory_pool
205+
197206
def capture(self,
198207
key: Tuple[int, int, int],
199208
forward_fn: Callable,
@@ -255,6 +264,7 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
255264
capture_inputs)
256265
if postprocess_fn is not None:
257266
postprocess_fn(capture_inputs)
267+
258268
with torch.cuda.graph(graph, pool=self.memory_pool):
259269
output = _setup_spec_decoding_and_forward(
260270
key, forward_fn, capture_inputs)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ..distributed import MPIDist
3535
from ..distributed.communicator import init_pp_comm
3636
from ..expert_statistic import ExpertStatistic
37+
from ..memory_buffer_utils import with_shared_pool
3738
from ..metadata import KVCacheParams
3839
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
3940
from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
@@ -2280,35 +2281,35 @@ def forward(
22802281
new_tensors_device, cache_indirection_buffer)
22812282

22822283
self.iter_counter += 1
2284+
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
2285+
if not maybe_graph:
2286+
# Fallback to eager execution if graph was not used
2287+
with MoeLoadBalancerIterContext(moe_load_balancer):
2288+
outputs = self._forward_step(inputs, gather_ids,
2289+
gather_context_logits)
2290+
else:
2291+
if self.cuda_graph_runner.needs_capture(key):
22832292

2284-
if not maybe_graph:
2285-
# Fallback to eager execution if graph was not used
2286-
with MoeLoadBalancerIterContext(moe_load_balancer):
2287-
outputs = self._forward_step(inputs, gather_ids,
2288-
gather_context_logits)
2289-
else:
2290-
if self.cuda_graph_runner.needs_capture(key):
2291-
2292-
def capture_forward_fn(inputs: Dict[str, Any]):
2293-
with MoeLoadBalancerIterContext(moe_load_balancer):
2294-
return self._forward_step(
2295-
inputs,
2296-
gather_ids=gather_ids,
2297-
gather_context_logits=gather_context_logits)
2293+
def capture_forward_fn(inputs: Dict[str, Any]):
2294+
with MoeLoadBalancerIterContext(moe_load_balancer):
2295+
return self._forward_step(
2296+
inputs,
2297+
gather_ids=gather_ids,
2298+
gather_context_logits=gather_context_logits)
22982299

2299-
def capture_postprocess_fn(inputs: Dict[str, Any]):
2300-
self._postprocess_inputs(inputs)
2300+
def capture_postprocess_fn(inputs: Dict[str, Any]):
2301+
self._postprocess_inputs(inputs)
23012302

2302-
self.cuda_graph_runner.capture(key, capture_forward_fn,
2303-
inputs,
2304-
capture_postprocess_fn)
2303+
self.cuda_graph_runner.capture(key, capture_forward_fn,
2304+
inputs,
2305+
capture_postprocess_fn)
23052306

2306-
# here we don't need to use context since cuda graph capture didn't run kernel.
2307-
# maybe we need a cleaner way to do this.
2308-
outputs = self.cuda_graph_runner.replay(key, inputs)
2309-
else:
2310-
with MoeLoadBalancerIterContext(moe_load_balancer):
2307+
# here we don't need to use context since cuda graph capture didn't run kernel.
2308+
# maybe we need a cleaner way to do this.
23112309
outputs = self.cuda_graph_runner.replay(key, inputs)
2310+
else:
2311+
with MoeLoadBalancerIterContext(moe_load_balancer):
2312+
outputs = self.cuda_graph_runner.replay(key, inputs)
23122313

23132314
self._execute_logit_post_processors(scheduled_requests, outputs)
23142315

0 commit comments

Comments
 (0)