diff --git a/tensorrt_llm/_torch/memory_buffer_utils.py b/tensorrt_llm/_torch/memory_buffer_utils.py
index 2491af07cda..33a8ddf756f 100644
--- a/tensorrt_llm/_torch/memory_buffer_utils.py
+++ b/tensorrt_llm/_torch/memory_buffer_utils.py
@@ -4,6 +4,10 @@
 
 import torch
 
+from tensorrt_llm.logger import logger
+
+from .utils import get_shared_pool
+
 
 @dataclass
 class BufferBlock:
@@ -80,9 +84,22 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
 
         # No suitable buffer was found, so allocate a new one.
         # The new buffer is created with uint8 to represent raw bytes.
-        new_buffer_tensor = torch.zeros((required_memory_size, ),
-                                        device='cuda',
-                                        dtype=torch.uint8)
+        new_buffer_tensor = None
+        try:
+            with torch.cuda.memory.use_mem_pool(get_shared_pool()):
+                new_buffer_tensor = torch.zeros((required_memory_size, ),
+                                                device='cuda',
+                                                dtype=torch.uint8)
+        except Exception as ex:
+            # TODO: check whether this is actually an OOM exception.
+            logger.debug(
+                f"Failed to allocate tensor from the shared memory pool: {str(ex)}"
+            )
+            # Fall back to allocating from the default CUDA caching allocator.
+            new_buffer_tensor = torch.zeros((required_memory_size, ),
+                                            device='cuda',
+                                            dtype=torch.uint8)
+
         new_block = BufferBlock(buffer=new_buffer_tensor,
                                 is_reserved=reserve_buffer)
 
diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 4c097ac0d2a..724fc2e4833 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -194,6 +194,15 @@ def needs_capture(self, key: Tuple[int, int, int]):
 
         return key not in self.graph_outputs
 
+    def get_graph_pool(self):
+        """Returns the CUDA memory pool used by this graph runner.
+
+        Returns:
+            The CUDA memory pool associated with captured graphs, or None if
+            no graphs have been captured yet.
+ """ + return self.memory_pool + def capture(self, key: Tuple[int, int, int], forward_fn: Callable, @@ -255,6 +264,7 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int], capture_inputs) if postprocess_fn is not None: postprocess_fn(capture_inputs) + with torch.cuda.graph(graph, pool=self.memory_pool): output = _setup_spec_decoding_and_forward( key, forward_fn, capture_inputs) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a0f650f6aff..c54ea3bb264 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -48,7 +48,8 @@ from ..speculative.mtp import SampleStateTensorsMTP from ..utils import (get_model_extra_attrs, set_per_request_piecewise_cuda_graph_flag, - set_torch_compiling, with_model_extra_attrs) + set_shared_mem_pool, set_torch_compiling, + with_model_extra_attrs) from .config import PyTorchConfig from .config_utils import is_mla from .cuda_graph_runner import CUDAGraphRunner @@ -2186,35 +2187,35 @@ def forward( new_tensors_device, cache_indirection_buffer) self.iter_counter += 1 + with set_shared_mem_pool(self.cuda_graph_runner.get_graph_pool()): + if not maybe_graph: + # Fallback to eager execution if graph was not used + with MoeLoadBalancerIterContext(moe_load_balancer): + outputs = self._forward_step(inputs, gather_ids, + gather_context_logits) + else: + if self.cuda_graph_runner.needs_capture(key): - if not maybe_graph: - # Fallback to eager execution if graph was not used - with MoeLoadBalancerIterContext(moe_load_balancer): - outputs = self._forward_step(inputs, gather_ids, - gather_context_logits) - else: - if self.cuda_graph_runner.needs_capture(key): - - def capture_forward_fn(inputs: Dict[str, Any]): - with MoeLoadBalancerIterContext(moe_load_balancer): - return self._forward_step( - inputs, - gather_ids=gather_ids, - gather_context_logits=gather_context_logits) + def capture_forward_fn(inputs: Dict[str, Any]): + with MoeLoadBalancerIterContext(moe_load_balancer): + return self._forward_step( + inputs, + gather_ids=gather_ids, + gather_context_logits=gather_context_logits) - def capture_postprocess_fn(inputs: Dict[str, Any]): - self._postprocess_inputs(inputs) + def capture_postprocess_fn(inputs: Dict[str, Any]): + self._postprocess_inputs(inputs) - self.cuda_graph_runner.capture(key, capture_forward_fn, - inputs, - capture_postprocess_fn) + self.cuda_graph_runner.capture(key, capture_forward_fn, + inputs, + capture_postprocess_fn) - # here we don't need to use context since cuda graph capture didn't run kernel. - # maybe we need a cleaner way to do this. - outputs = self.cuda_graph_runner.replay(key, inputs) - else: - with MoeLoadBalancerIterContext(moe_load_balancer): + # here we don't need to use context since cuda graph capture didn't run kernel. + # maybe we need a cleaner way to do this. outputs = self.cuda_graph_runner.replay(key, inputs) + else: + with MoeLoadBalancerIterContext(moe_load_balancer): + outputs = self.cuda_graph_runner.replay(key, inputs) self._execute_logit_post_processors(scheduled_requests, outputs) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 35057d12103..70ed974835f 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -312,3 +312,53 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping: # It's here so that unit tests can mock it and turn it off. 
 def _get_allow_chain_drafter() -> bool:
     return True
+
+
+_buffer_pool = None
+
+
+def set_shared_pool(buffer_pool):
+    """Sets the global memory pool for buffer allocation.
+
+    Args:
+        buffer_pool: A CUDA memory pool object to use for allocations.
+    """
+    global _buffer_pool
+    _buffer_pool = buffer_pool
+
+
+def get_shared_pool():
+    """Retrieves the current global memory pool.
+
+    Returns:
+        The current memory pool, or None if not set.
+    """
+    global _buffer_pool
+    return _buffer_pool
+
+
+@contextlib.contextmanager
+def set_shared_mem_pool(mem_pool) -> contextlib.AbstractContextManager:
+    """Temporarily sets a preferred memory pool and restores the previous one on exit.
+
+    This context manager allows temporarily switching to a different memory pool
+    for CUDA graph operations, ensuring the original pool is restored even if
+    an exception occurs.
+
+    Args:
+        mem_pool: The memory pool to use within the context.
+
+    Yields:
+        None
+
+    Example:
+        >>> with set_shared_mem_pool(graph_pool):
+        ...     # Allocations within this block use graph_pool
+        ...     tensor = allocate_buffer(...)
+    """
+    old_buffer_pool = get_shared_pool()
+    set_shared_pool(mem_pool)
+    try:
+        yield
+    finally:
+        set_shared_pool(old_buffer_pool)
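
Note (not part of the patch): the pieces above compose as follows. model_engine.forward() enters set_shared_mem_pool() with the CUDAGraphRunner's pool, and memory_buffer_utils.get_buffer() then allocates inside torch.cuda.memory.use_mem_pool(get_shared_pool()), falling back to the default allocator if that fails. The sketch below only illustrates that flow; it assumes a CUDA build of PyTorch recent enough to provide torch.cuda.MemPool and the use_mem_pool API (the same API the patch calls). The helper name alloc_raw_bytes, the torch.cuda.MemPool() stand-in for the graph runner's actual pool, and the 1 MiB size are invented for the example.

# Illustrative sketch only -- mirrors the allocation path added by this patch.
import torch

from tensorrt_llm._torch.utils import get_shared_pool, set_shared_mem_pool


def alloc_raw_bytes(required_memory_size: int) -> torch.Tensor:
    # Same pattern as get_buffer(): prefer the shared pool, fall back to the
    # default CUDA caching allocator if the pool is unset or exhausted.
    try:
        with torch.cuda.memory.use_mem_pool(get_shared_pool()):
            return torch.zeros((required_memory_size, ),
                               device='cuda',
                               dtype=torch.uint8)
    except Exception:
        return torch.zeros((required_memory_size, ),
                           device='cuda',
                           dtype=torch.uint8)


# torch.cuda.MemPool() stands in for CUDAGraphRunner.get_graph_pool().
pool = torch.cuda.MemPool()
with set_shared_mem_pool(pool):
    buf = alloc_raw_bytes(1 << 20)  # served from the shared pool when possible
buf = alloc_raw_bytes(1 << 20)  # pool restored; default allocator path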