@@ -750,7 +750,8 @@ def __init__(
         # tensor parallel
         config = config or ModelConfig()
         if mapping_with_cp is not None:
-            print("[MLA::__init__] OVERRIDING MAPPING WITH CP DETECTED.")
+            logger.warning(
+                "[MLA::__init__] Overriding mapping with CP detected.")
             self.mapping = mapping_with_cp
         else:
             self.mapping = config.mapping
@@ -762,7 +763,8 @@ def __init__(
         if self.mapping.has_cp_ulysses():
             raise NotImplementedError("MLA doesn't support CP Ulysses yet")
         if self.mapping.cp_size > 1:
-            assert self.mapping.cp_config['cp_type'] == CpType.HELIX, f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."
+            assert self.mapping.cp_config[
+                'cp_type'] == CpType.HELIX, f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

         mapping = Mapping(
             world_size=tp_size * pp_size * cp_size,
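Note: a minimal sketch of the CP guard this hunk line-wraps, assuming the Mapping/CpType
interfaces visible in the diff; the helper name check_mla_cp_support is hypothetical and
not part of the change.

    # Sketch only: MLA rejects Ulysses CP and requires HELIX whenever cp_size > 1.
    # CpType and the mapping object are assumed to be the ones this module already imports.
    def check_mla_cp_support(mapping):
        if mapping.has_cp_ulysses():
            raise NotImplementedError("MLA doesn't support CP Ulysses yet")
        if mapping.cp_size > 1:
            cp_type = mapping.cp_config['cp_type']
            assert cp_type == CpType.HELIX, (
                f"CP type must be HELIX for MLA, but got {cp_type}.")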
@@ -1727,20 +1729,19 @@ def forward_absorption_generation(
             maybe_execute_in_parallel(
                 lambda: torch.ops.trtllm.bmm_out(
                     q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out),
-                lambda: self.mqa.mla_rope_generation(fused_q,
-                                                     q_pe,
-                                                     latent_cache,
-                                                     attn_metadata,
-                                                     cu_q_seqlens,
-                                                     cu_kv_seqlens,
-                                                     fmha_scheduler_counter,
-                                                     mla_bmm1_scale,
-                                                     mla_bmm2_scale,
-                                                     quant_q_buffer,
-                                                     helix_position_offsets=
-                                                     helix_position_offsets,
-                                                     helix_is_inactive_rank=
-                                                     helix_is_inactive_rank),
+                lambda: self.mqa.mla_rope_generation(
+                    fused_q,
+                    q_pe,
+                    latent_cache,
+                    attn_metadata,
+                    cu_q_seqlens,
+                    cu_kv_seqlens,
+                    fmha_scheduler_counter,
+                    mla_bmm1_scale,
+                    mla_bmm2_scale,
+                    quant_q_buffer,
+                    helix_position_offsets=helix_position_offsets,
+                    helix_is_inactive_rank=helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1758,20 +1759,19 @@ def forward_absorption_generation(
                     q_nope_out,
                     self.k_b_proj_trans_dequant,
                 ),
-                lambda: self.mqa.mla_rope_generation(fused_q,
-                                                     q_pe,
-                                                     latent_cache,
-                                                     attn_metadata,
-                                                     cu_q_seqlens,
-                                                     cu_kv_seqlens,
-                                                     fmha_scheduler_counter,
-                                                     mla_bmm1_scale,
-                                                     mla_bmm2_scale,
-                                                     quant_q_buffer,
-                                                     helix_position_offsets=
-                                                     helix_position_offsets,
-                                                     helix_is_inactive_rank=
-                                                     helix_is_inactive_rank),
+                lambda: self.mqa.mla_rope_generation(
+                    fused_q,
+                    q_pe,
+                    latent_cache,
+                    attn_metadata,
+                    cu_q_seqlens,
+                    cu_kv_seqlens,
+                    fmha_scheduler_counter,
+                    mla_bmm1_scale,
+                    mla_bmm2_scale,
+                    quant_q_buffer,
+                    helix_position_offsets=helix_position_offsets,
+                    helix_is_inactive_rank=helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
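Note: both hunks only re-indent the second lambda passed to maybe_execute_in_parallel.
A minimal sketch of that dual-stream pattern follows, assuming the
maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream) shape visible at the
call sites; the body below is illustrative, not the library's implementation.

    import torch

    # Sketch only: overlap fn_b (here mla_rope_generation) on an auxiliary CUDA
    # stream with fn_a (here the bmm) on the current stream, ordered by two events.
    def maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream=None):
        if aux_stream is None:
            return fn_a(), fn_b()          # no extra stream: run sequentially
        event_a.record()                   # mark the work both branches depend on
        result_a = fn_a()                  # main-stream branch (bmm_out)
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event_a)  # start only after shared inputs are ready
            result_b = fn_b()               # overlapped branch (mla_rope_generation)
            event_b.record(aux_stream)
        torch.cuda.current_stream().wait_event(event_b)  # rejoin before results are used
        return result_a, result_b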