Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/memory_buffer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _view_as(buffer: torch.Tensor, target_shape: list[int],

def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
buffer_name: str, reserve_buffer: bool):
"""Return a reusable buffer view for the requested shape/dtype.

The returned tensor is backed by an underlying `torch.uint8` buffer. When
no suitable buffer exists in the pool, a new tensor is created via
`torch.empty`, so its contents are uninitialized. Overwrite the data before use if needed.
"""

# all buffers are allocated with 1 byte element size
required_memory_size = math.prod(tensor_shape) * dtype.itemsize
Expand Down Expand Up @@ -91,7 +97,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
new_buffer_tensor = None
try:
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
new_buffer_tensor = torch.zeros((required_memory_size, ),
new_buffer_tensor = torch.empty((required_memory_size, ),
device='cuda',
dtype=torch.uint8)
except Exception as ex:
Expand All @@ -101,7 +107,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
)
# if exception happens during allocating memory from shared pool, retry
# to allocate from default pool
new_buffer_tensor = torch.zeros((required_memory_size, ),
new_buffer_tensor = torch.empty((required_memory_size, ),
device='cuda',
dtype=torch.uint8)

Expand Down
Loading