Skip to content

Commit 0ee71d9

Browse files
authored
[https://nvbugs/5606166][fix] AutoDeploy: use tuples for cudagraph shape lookup (#8658)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent a09b38a commit 0ee71d9

File tree

2 files changed: +13 −3 lines changed

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def capture_graph(self, *args, **kwargs):
         args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static)

         # capture graph for truncated inputs
-        combined_shape = sum((input.shape for input in inputs_truncated), start=())
+        combined_shape = sum((tuple(input.shape) for input in inputs_truncated), start=())
         self.cudagraphs[combined_shape] = self._capture_one_graph(*args, **kwargs)

     def forward(self, *args, **kwargs) -> Any:
@@ -157,7 +157,7 @@ def forward(self, *args, **kwargs) -> Any:

         # Calculate rounded-up shapes for each input
         rounded_shapes = [
-            (self.round_to_cuda_batch_size(input.shape[0]),) + input.shape[1:]
+            (self.round_to_cuda_batch_size(input.shape[0]),) + tuple(input.shape[1:])
             for input in args_batched
         ]
         combined_shape = sum(rounded_shapes, start=())

tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_cuda_graph_batch_sizes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
 class TestCudaGraphBatchSizes:
     """Test class for CUDA graph batch size handling."""

+    @staticmethod
+    def _raise_error_for_forward(*args, **kwargs):
+        raise RuntimeError("forward method should not be called")
+
     @pytest.fixture
     def simple_model_and_inputs(self):
         """Create a simple model and inputs for testing."""
@@ -192,7 +196,13 @@ def test_forward_uses_cuda_graph_for_valid_batch_sizes(self, simple_model_and_in
         test_input = data["input_tensor"][:batch_size]

         with torch.inference_mode():
-            output = captured_graph.forward(test_input)
+            # temporarily remove model forward to ensure that the captured graph is used
+            original_forward = captured_graph.model.forward
+            captured_graph.model.forward = self._raise_error_for_forward
+            try:
+                output = captured_graph.forward(test_input)
+            finally:
+                captured_graph.model.forward = original_forward

         # Should get valid output
         assert output is not None

0 commit comments

Comments
 (0)