Skip to content

Commit 6dd2fcd

Browse files
authored
[https://nvbugs/5629833][fix] Don't fill tensors with 0 (#9296)
Signed-off-by: Hui Gao <[email protected]>
1 parent cddc754 commit 6dd2fcd

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def _view_as(buffer: torch.Tensor, target_shape: list[int],
5151

5252
def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
5353
buffer_name: str, reserve_buffer: bool):
54+
"""Return a reusable buffer view for the requested shape/dtype.
55+
The returned tensor is backed by an underlying `torch.uint8` buffer. When
56+
no suitable buffer exists in the pool, a new tensor is created via
57+
`torch.empty`, so its contents are uninitialized; callers that read the buffer before writing it must overwrite the data first.
58+
"""
5459

5560
# all buffers are allocated with 1 byte element size
5661
required_memory_size = math.prod(tensor_shape) * dtype.itemsize
@@ -91,7 +96,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
9196
new_buffer_tensor = None
9297
try:
9398
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
94-
new_buffer_tensor = torch.zeros((required_memory_size, ),
99+
new_buffer_tensor = torch.empty((required_memory_size, ),
95100
device='cuda',
96101
dtype=torch.uint8)
97102
except Exception as ex:
@@ -101,7 +106,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
101106
)
102107
# if exception happens during allocating memory from shared pool, retry
103108
# to allocate from default pool
104-
new_buffer_tensor = torch.zeros((required_memory_size, ),
109+
new_buffer_tensor = torch.empty((required_memory_size, ),
105110
device='cuda',
106111
dtype=torch.uint8)
107112

0 commit comments

Comments (0)