
Commit 20edf93

[https://nvbugs/5637012][fix] Fix helix unit tests
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 134446a commit 20edf93

2 files changed: +27 −9 lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 27 additions & 8 deletions
@@ -1724,6 +1724,9 @@ def forward_absorption_generation(
             device=q.device,
         )
 
+        # Compute helix_position_offsets for helix parallelism.
+        helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None
+
         rope_stream = self.aux_stream if not has_fp8_kv_cache else None
         if self.k_b_proj_trans.dtype == torch.bfloat16:
             # [num_heads, num_tokens, self.qk_nope_head_dim]
@@ -1737,10 +1740,18 @@ def forward_absorption_generation(
             maybe_execute_in_parallel(
                 lambda: torch.ops.trtllm.bmm_out(
                     q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out),
-                lambda: self.mqa.mla_rope_generation(
-                    fused_q, q_pe, latent_cache, attn_metadata, cu_q_seqlens,
-                    cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale,
-                    mla_bmm2_scale, quant_q_buffer),
+                lambda: self.mqa.mla_rope_generation(fused_q,
+                                                     q_pe,
+                                                     latent_cache,
+                                                     attn_metadata,
+                                                     cu_q_seqlens,
+                                                     cu_kv_seqlens,
+                                                     fmha_scheduler_counter,
+                                                     mla_bmm1_scale,
+                                                     mla_bmm2_scale,
+                                                     quant_q_buffer,
+                                                     helix_position_offsets=
+                                                     helix_position_offsets),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1758,10 +1769,18 @@ def forward_absorption_generation(
                     q_nope_out,
                     self.k_b_proj_trans_dequant,
                 ),
-                lambda: self.mqa.mla_rope_generation(
-                    fused_q, q_pe, latent_cache, attn_metadata, cu_q_seqlens,
-                    cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale,
-                    mla_bmm2_scale, quant_q_buffer),
+                lambda: self.mqa.mla_rope_generation(fused_q,
+                                                     q_pe,
+                                                     latent_cache,
+                                                     attn_metadata,
+                                                     cu_q_seqlens,
+                                                     cu_kv_seqlens,
+                                                     fmha_scheduler_counter,
+                                                     mla_bmm1_scale,
+                                                     mla_bmm2_scale,
+                                                     quant_q_buffer,
+                                                     helix_position_offsets=
+                                                     helix_position_offsets),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
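
Note on the change above: helix_position_offsets is derived from position_ids only when context parallelism is active (self.mapping.cp_size > 1) and is otherwise left as None, presumably so the single-rank generation path behaves as before. Below is a minimal, self-contained sketch of that gating pattern; Mapping, mla_rope_generation_stub, and run are hypothetical stand-ins for the real TensorRT-LLM objects, not the actual implementation.

# Minimal sketch (assumed names, not TensorRT-LLM code): illustrates how position
# offsets are forwarded only when helix/context parallelism is enabled.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mapping:
    cp_size: int = 1  # context-parallel world size; >1 means helix parallelism


def mla_rope_generation_stub(q: torch.Tensor,
                             helix_position_offsets: Optional[torch.Tensor] = None
                             ) -> torch.Tensor:
    # Stand-in for mla_rope_generation: applies an offset-dependent shift only
    # when offsets are provided; otherwise leaves q untouched.
    if helix_position_offsets is None:
        return q
    return q + helix_position_offsets.unsqueeze(-1).to(q.dtype)


def run(mapping: Mapping, q: torch.Tensor,
        position_ids: torch.Tensor) -> torch.Tensor:
    # Same gating as the diff: offsets come from position_ids under helix
    # parallelism, otherwise None.
    helix_position_offsets = position_ids if mapping.cp_size > 1 else None
    return mla_rope_generation_stub(q, helix_position_offsets=helix_position_offsets)


if __name__ == "__main__":
    q = torch.zeros(4, 8)
    position_ids = torch.arange(4)
    print(run(Mapping(cp_size=1), q, position_ids).abs().sum())  # 0: offsets unused
    print(run(Mapping(cp_size=2), q, position_ids).abs().sum())  # >0: offsets applied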

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -349,7 +349,6 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten
 test_e2e.py::test_openai_chat_multimodal_example SKIP (https://nvbugs/5636894)
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] SKIP (https://nvbugs/5637220)
 llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857)
-unittest/_torch/modules/test_mla_helix.py::test_mla_helix_distributed SKIP (https://nvbugspro.nvidia.com/bug/5637012)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass] SKIP (https://nvbugs/5636916)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5616182)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-small-128k-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
