@@ -1724,6 +1724,9 @@ def forward_absorption_generation(
17241724 device = q .device ,
17251725 )
17261726
1727+ # Compute helix_position_offsets for helix parallelism.
1728+ helix_position_offsets = position_ids if self .mapping .cp_size > 1 else None
1729+
17271730 rope_stream = self .aux_stream if not has_fp8_kv_cache else None
17281731 if self .k_b_proj_trans .dtype == torch .bfloat16 :
17291732 # [num_heads, num_tokens, self.qk_nope_head_dim]
@@ -1737,10 +1740,18 @@ def forward_absorption_generation(
17371740 maybe_execute_in_parallel (
17381741 lambda : torch .ops .trtllm .bmm_out (
17391742 q_nope_t , self .k_b_proj_trans .transpose (1 , 2 ), q_nope_out ),
1740- lambda : self .mqa .mla_rope_generation (
1741- fused_q , q_pe , latent_cache , attn_metadata , cu_q_seqlens ,
1742- cu_kv_seqlens , fmha_scheduler_counter , mla_bmm1_scale ,
1743- mla_bmm2_scale , quant_q_buffer ),
1743+ lambda : self .mqa .mla_rope_generation (fused_q ,
1744+ q_pe ,
1745+ latent_cache ,
1746+ attn_metadata ,
1747+ cu_q_seqlens ,
1748+ cu_kv_seqlens ,
1749+ fmha_scheduler_counter ,
1750+ mla_bmm1_scale ,
1751+ mla_bmm2_scale ,
1752+ quant_q_buffer ,
1753+ helix_position_offsets =
1754+ helix_position_offsets ),
17441755 self .ln_events [0 ],
17451756 self .ln_events [1 ],
17461757 rope_stream ,
@@ -1758,10 +1769,18 @@ def forward_absorption_generation(
17581769 q_nope_out ,
17591770 self .k_b_proj_trans_dequant ,
17601771 ),
1761- lambda : self .mqa .mla_rope_generation (
1762- fused_q , q_pe , latent_cache , attn_metadata , cu_q_seqlens ,
1763- cu_kv_seqlens , fmha_scheduler_counter , mla_bmm1_scale ,
1764- mla_bmm2_scale , quant_q_buffer ),
1772+ lambda : self .mqa .mla_rope_generation (fused_q ,
1773+ q_pe ,
1774+ latent_cache ,
1775+ attn_metadata ,
1776+ cu_q_seqlens ,
1777+ cu_kv_seqlens ,
1778+ fmha_scheduler_counter ,
1779+ mla_bmm1_scale ,
1780+ mla_bmm2_scale ,
1781+ quant_q_buffer ,
1782+ helix_position_offsets =
1783+ helix_position_offsets ),
17651784 self .ln_events [0 ],
17661785 self .ln_events [1 ],
17671786 rope_stream ,
0 commit comments