py/torch_tensorrt/dynamo/conversion/impl/shuffle.py (45 additions, 3 deletions)
@@ -25,14 +25,56 @@ def reshape(
     input: TRTTensor,
     shape: Sequence[int],
 ) -> TRTTensor:
+    # Count dynamic dimensions and check for an inferred dimension (-1)
+    num_dynamic_dims = 0
+    has_inferred_dim = False
+    inferred_dim_index = -1
+
+    # Create a mutable copy of the shape for modification
+    new_shape = list(shape)
+
+    # Special case: handle a dynamic shape with an inferred dimension (-1).
+    # This is required for ops like dynamic_block_quantize_op that require
+    # the dimension to be known at compile time rather than at runtime.
+    for i, s in enumerate(new_shape):
+        if isinstance(s, TRTTensor):
+            num_dynamic_dims += 1
+        elif s == -1:
+            has_inferred_dim = True
+            inferred_dim_index = i
+
+    # Only process if we have exactly one dynamic dimension and one inferred
+    # dimension. This is a common pattern in quantization, where one dimension
+    # is dynamic and another must be inferred to preserve the total element count.
+    if has_inferred_dim and num_dynamic_dims == 1:
+        # Calculate the inferred dimension size.
+        # Total elements = product of all static input dimensions
+        # (the dynamic dimension, reported as -1, is excluded).
+        total_elements = 1
+        for s in input.shape:
+            if s != -1:
+                total_elements *= s
+
+        # Divide by the known dimensions in new_shape to find the inferred
+        # dimension. This keeps the total number of elements unchanged.
+        for s in new_shape:
+            if isinstance(s, int) and s != -1:
+                if total_elements % s != 0:
+                    raise ValueError(
+                        f"Cannot infer dimension: {total_elements} elements not divisible by {s}"
+                    )
+                total_elements //= s
+
+        # Replace -1 with the calculated inferred dimension
+        new_shape[inferred_dim_index] = total_elements
Comment on lines +52 to +68

Collaborator:
could you please give an example here?

Collaborator (author):
@peri044 I encountered an issue when using the TinyLlama/TinyLlama-1.1B-Chat-v1.0 model with nvfp4 quantization. Although the issue is still reproducible in the mini repo (#3745), it was not observed during LLM testing. I think this change is not necessary.
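For illustration (the shapes below are made up, not taken from the PR), here is a standalone pure-Python sketch of the inferred-dimension arithmetic in the hunk above: reshaping an input of shape (-1, 1024), where -1 marks the dynamic batch dimension, into (N, -1, 16). The -1 in the target shape must resolve to 64 so that 64 * 16 == 1024 holds for any runtime batch size.

# A standalone sketch of the inferred-dimension arithmetic; the shapes are
# hypothetical, and the string "dyn" stands in for the TRTTensor that
# carries the runtime batch dimension.
def infer_dim(input_shape, new_shape):
    # Product of the input's static dimensions (dynamic dims appear as -1).
    total_elements = 1
    for s in input_shape:
        if s != -1:
            total_elements *= s

    # Divide out the static dimensions already fixed in the target shape.
    for s in new_shape:
        if isinstance(s, int) and s != -1:
            if total_elements % s != 0:
                raise ValueError(
                    f"Cannot infer dimension: {total_elements} not divisible by {s}"
                )
            total_elements //= s

    # Substitute the computed size for the single -1.
    return [total_elements if s == -1 else s for s in new_shape]

print(infer_dim((-1, 1024), ["dyn", -1, 16]))  # ['dyn', 64, 16]
print(infer_dim((-1, 1024), ["dyn", -1, 32]))  # ['dyn', 32, 32]

A target such as ["dyn", -1, 48] would raise the same ValueError as the converter code above, since 1024 is not divisible by 48.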


     layer = ctx.net.add_shuffle(input)
-    if all(isinstance(s, int) for s in shape):
-        layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in new_shape):
+        layer.reshape_dims = tuple(new_shape)
     else:
         # Convert all the dimensions to trt Tensors.
         trt_shape = []

-        for i, s in enumerate(shape):
+        for i, s in enumerate(new_shape):
             if isinstance(s, TRTTensor):
                 dim_int32 = cast_trt_tensor(
                     ctx,
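The else-branch above (truncated in this view) converts each target dimension to a TRT tensor; in TensorRT, a reshape whose dimensions are only known at runtime is expressed by feeding an assembled shape tensor to the shuffle layer as its second input. A minimal sketch of that general pattern against TensorRT's Python API follows; the builder setup, tensor name, and the (-1, 1024) to (N, 64, 16) shapes are illustrative assumptions, not torch_tensorrt's own helper code.

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)  # TRT >= 10; older releases need the EXPLICIT_BATCH flag

x = network.add_input("x", trt.float32, (-1, 1024))  # first dim is dynamic

# shape(x)[0:1]: the runtime batch dimension as a 1-D shape tensor.
full_shape = network.add_shape(x).get_output(0)
dyn_dim = network.add_slice(full_shape, (0,), (1,), (1,)).get_output(0)

# The static tail of the target shape as a constant; its dtype must match
# the shape tensor's (int64 on TRT 10, int32 on earlier releases).
tail = network.add_constant((2,), np.array([64, 16], dtype=np.int64)).get_output(0)

# Concatenate into the full target shape (N, 64, 16) and feed it to the
# shuffle layer as its second input, which makes the reshape dynamic.
concat = network.add_concatenation([dyn_dim, tail])
concat.axis = 0

shuffle = network.add_shuffle(x)
shuffle.set_input(1, concat.get_output(0))
network.mark_output(shuffle.get_output(0))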