@@ -51,6 +51,11 @@ def _view_as(buffer: torch.Tensor, target_shape: list[int],
5151
5252 def get_buffer (self , tensor_shape : list [int ], dtype : torch .dtype ,
5353 buffer_name : str , reserve_buffer : bool ):
54+ """Return a reusable buffer view for the requested shape/dtype.
55+ The returned tensor is a view over an underlying `torch.uint8` buffer. When
56+ no suitable buffer exists in the pool, a new one is allocated via
57+ `torch.empty`, so its contents are uninitialized; overwrite the data before reading it.
58+ """
5459
5560 # all buffers are allocated with 1 byte element size
5661 required_memory_size = math .prod (tensor_shape ) * dtype .itemsize
@@ -91,7 +96,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
9196 new_buffer_tensor = None
9297 try :
9398 with torch .cuda .memory .use_mem_pool (get_shared_pool ()):
94- new_buffer_tensor = torch .zeros ((required_memory_size , ),
99+ new_buffer_tensor = torch .empty ((required_memory_size , ),
95100 device = 'cuda' ,
96101 dtype = torch .uint8 )
97102 except Exception as ex :
@@ -101,7 +106,7 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
101106 )
102107 # If an exception occurs while allocating from the shared pool, retry
103108 # the allocation from the default pool.
104- new_buffer_tensor = torch .zeros ((required_memory_size , ),
109+ new_buffer_tensor = torch .empty ((required_memory_size , ),
105110 device = 'cuda' ,
106111 dtype = torch .uint8 )
107112
0 commit comments