72 changes: 69 additions & 3 deletions tensorrt_llm/_torch/memory_buffer_utils.py
@@ -1,9 +1,12 @@
import contextlib
import math
from dataclasses import dataclass
from typing import Optional

import torch

from tensorrt_llm.logger import logger


@dataclass
class BufferBlock:
@@ -80,9 +83,23 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,

# No suitable buffer was found, so allocate a new one.
# The new buffer is created with uint8 to represent raw bytes.
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)
new_buffer_tensor = None
try:
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)
except Exception as ex:
# TODO: check whether this is specifically an OOM exception.
logger.debug(
f"Failed to allocate tensor from the shared memory pool: {str(ex)}"
)
# If allocation from the shared pool fails, retry the allocation
# from the default CUDA memory pool.
new_buffer_tensor = torch.zeros((required_memory_size, ),
device='cuda',
dtype=torch.uint8)

new_block = BufferBlock(buffer=new_buffer_tensor,
is_reserved=reserve_buffer)

@@ -97,3 +114,52 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
def get_memory_buffers():
global _buffer
return _buffer


_shared_pool = None


def set_shared_pool(shared_pool):
"""Sets the global memory pool for buffer allocation.

Args:
shared_pool: A CUDA memory pool object to use for allocations.
"""
global _shared_pool
_shared_pool = shared_pool


def get_shared_pool():
"""Retrieves the current global memory pool.

Returns:
The current memory pool, or None if not set.
"""
return _shared_pool


@contextlib.contextmanager
def with_shared_pool(shared_pool) -> contextlib.AbstractContextManager:
"""Temporarily sets a preferred memory pool and restores the previous one on exit.

This context manager allows temporarily switching to a different memory pool
for CUDA graph operations, ensuring the original pool is restored even if
an exception occurs.

Args:
shared_pool: The memory pool to use within the context.

Yields:
None

Example:
>>> with with_shared_pool(shared_pool):
... # Allocations within this block use shared_pool
... tensor = allocate_buffer(...)
"""
old_shared_pool = get_shared_pool()
set_shared_pool(shared_pool)
try:
yield
finally:
set_shared_pool(old_shared_pool)
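For reference, a minimal usage sketch of the new helpers. This is illustrative only: `allocate_scratch` is a hypothetical caller, `pool` stands for whatever CUDA memory pool the caller obtains (for example, a graph runner's pool), and the `get_buffer` signature is assumed from the hunk above.

```python
# Illustrative sketch, not part of the PR: shows the save/restore semantics
# of with_shared_pool() and how get_buffer()'s fallback allocation fits in.
import torch

from tensorrt_llm._torch.memory_buffer_utils import (get_memory_buffers,
                                                      get_shared_pool,
                                                      with_shared_pool)


def allocate_scratch(pool, num_elems: int) -> torch.Tensor:
    """Allocate a uint8 scratch buffer while `pool` is the active shared pool."""
    buffers = get_memory_buffers()  # assumed to expose the get_buffer() method shown above
    with with_shared_pool(pool):
        assert get_shared_pool() is pool
        # get_buffer() tries the shared pool first and, if that allocation
        # raises, retries with the default CUDA allocator (see the except
        # branch in the diff). The third argument is assumed to be the
        # reserve flag implied by `is_reserved=reserve_buffer` above.
        scratch = buffers.get_buffer([num_elems], torch.uint8, False)
    # On exit the previously active pool (often None) is restored, even if
    # the allocation raised.
    return scratch
```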
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -194,6 +194,15 @@ def needs_capture(self, key: Tuple[int, int, int]):

return key not in self.graph_outputs

def get_graph_pool(self):
"""Returns the CUDA memory pool used by this graph runner.

Returns:
The CUDA memory pool associated with captured graphs, or None if
no graphs have been captured yet.
"""
return self.memory_pool

def capture(self,
key: Tuple[int, int, int],
forward_fn: Callable,
@@ -255,6 +264,7 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
capture_inputs)
if postprocess_fn is not None:
postprocess_fn(capture_inputs)

with torch.cuda.graph(graph, pool=self.memory_pool):
output = _setup_spec_decoding_and_forward(
key, forward_fn, capture_inputs)
49 changes: 25 additions & 24 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -34,6 +34,7 @@
from ..distributed import MPIDist
from ..distributed.communicator import init_pp_comm
from ..expert_statistic import ExpertStatistic
from ..memory_buffer_utils import with_shared_pool
from ..metadata import KVCacheParams
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
@@ -2209,35 +2210,35 @@ def forward(
new_tensors_device, cache_indirection_buffer)

self.iter_counter += 1
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
if not maybe_graph:
# Fall back to eager execution if the graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if self.cuda_graph_runner.needs_capture(key):

if not maybe_graph:
# Fall back to eager execution if the graph was not used
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self._forward_step(inputs, gather_ids,
gather_context_logits)
else:
if self.cuda_graph_runner.needs_capture(key):

def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)
def capture_forward_fn(inputs: Dict[str, Any]):
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(
inputs,
gather_ids=gather_ids,
gather_context_logits=gather_context_logits)

def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)
def capture_postprocess_fn(inputs: Dict[str, Any]):
self._postprocess_inputs(inputs)

self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)
self.cuda_graph_runner.capture(key, capture_forward_fn,
inputs,
capture_postprocess_fn)

# No MoeLoadBalancerIterContext is needed here because CUDA graph capture does not launch kernels.
# A cleaner way to handle this may be needed.
outputs = self.cuda_graph_runner.replay(key, inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
# No MoeLoadBalancerIterContext is needed here because CUDA graph capture does not launch kernels.
# A cleaner way to handle this may be needed.
outputs = self.cuda_graph_runner.replay(key, inputs)
else:
with MoeLoadBalancerIterContext(moe_load_balancer):
outputs = self.cuda_graph_runner.replay(key, inputs)

self._execute_logit_post_processors(scheduled_requests, outputs)

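Finally, a condensed sketch of the overall pattern these three files establish: the graph runner's memory pool is fetched once per forward call and made the shared pool for the entire path, so eager fallbacks, graph captures, and replays all allocate their scratch buffers from the same pool. Names such as `run_eager` and the simplified control flow are illustrative, not the engine's real helpers.

```python
# Condensed, illustrative sketch of the forward-path pattern above.
# `runner` is a CUDAGraphRunner-like object and `run_eager` stands in for
# ModelEngine._forward_step; neither signature is taken literally from the PR.
from typing import Any, Callable, Dict, Tuple

from tensorrt_llm._torch.memory_buffer_utils import with_shared_pool


def forward_once(runner, key: Tuple[int, int, int], inputs: Dict[str, Any],
                 run_eager: Callable[[Dict[str, Any]], Any], use_graph: bool):
    # Route intermediate-buffer allocations to the CUDA graph pool for the
    # whole forward path. get_graph_pool() returns None before the first
    # capture; get_buffer()'s fallback then uses the default allocator.
    with with_shared_pool(runner.get_graph_pool()):
        if not use_graph:
            return run_eager(inputs)
        if runner.needs_capture(key):
            # The engine also passes a postprocess callback; omitted here.
            runner.capture(key, run_eager, inputs, None)
        return runner.replay(key, inputs)
```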