diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.cu b/cpp/tensorrt_llm/kernels/mlaKernels.cu index 02cf4eab83c..d5456b27ae9 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.cu +++ b/cpp/tensorrt_llm/kernels/mlaKernels.cu @@ -365,7 +365,8 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld, int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q, - float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets) + float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets, + bool const* helix_is_inactive_rank) { // Constants. @@ -474,7 +475,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, if (valid_token) { - if (head_idx == head_num) + if (head_idx == head_num && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])) { auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx; @@ -528,7 +529,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, auto local_token_idx = global_token_idx % seq_len; bool valid_token = global_token_idx < total_s_len; - if (valid_token) + if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx])) { if (head_dim_vec_idx == 0) { @@ -1061,7 +1062,7 @@ void invokeMLARopeGeneration(MlaParams& params, KVCacheBuffer kv_cache_buffer params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv, - params.host_bmm1_scale, params.helix_position_offsets); + params.host_bmm1_scale, params.helix_position_offsets, params.helix_is_inactive_rank); } template diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.h b/cpp/tensorrt_llm/kernels/mlaKernels.h index 1775b992cc7..ce6f4b1bfa0 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.h +++ b/cpp/tensorrt_llm/kernels/mlaKernels.h @@ -107,6 +107,10 @@ struct MlaParams // for Helix parallelism: the rotary position offsets [b] int32_t const* helix_position_offsets{nullptr}; + + // for Helix parallelism: whether the current rank is inactive, shape [b] + // (the current query tokens are not appended to this rank's KV cache) + bool const* helix_is_inactive_rank{nullptr}; }; template diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index b6fd27ea744..cbb498fcf88 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -181,8 +181,8 @@ class Runner : public RunnerBase [[maybe_unused]] MlaParams mla_params; if (op.isMLAEnabled()) { - TORCH_CHECK(mla_tensor_params.size() == 1, - "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets."); + TORCH_CHECK(mla_tensor_params.size() == 2, + "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank."); if (is_context && op.mUseSparseAttention) { if (latent_cache.has_value()) @@ -227,10 +227,15 @@ class Runner : public RunnerBase // For generation, helix position is in ropeOp auto& mla_helix_position_offsets = mla_tensor_params[0]; + auto& mla_helix_is_inactive_rank = mla_tensor_params[1]; if (mla_helix_position_offsets.has_value()) { mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr(); } + if (mla_helix_is_inactive_rank.has_value()) + { + mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr(); + } } else { diff --git a/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp b/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp index 7bd1f2f1362..39657c71e75 100644 --- a/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp +++ b/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp @@ -66,6 +66,7 @@ struct MlaRopeGenArgs float const* kv_scale_quant_orig_ptr; float host_bmm1_scale; int32_t const* helix_position_offsets_ptr; + bool const* helix_is_inactive_rank_ptr; }; template @@ -105,6 +106,7 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu mla_params.dequant_scale_kv = args.kv_scale_quant_orig_ptr; mla_params.host_bmm1_scale = args.host_bmm1_scale; mla_params.helix_position_offsets = args.helix_position_offsets_ptr; + mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr; tk::invokeMLARopeGeneration(mla_params, kv_cache_buffer, stream); } @@ -133,8 +135,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim + TLLM_CHECK_WITH_INFO( head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim"); TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1"); - TORCH_CHECK( - mla_tensor_params.size() == 1, "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets."); + TORCH_CHECK(mla_tensor_params.size() == 2, + "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank."); auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device()); auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode)); @@ -153,6 +155,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim + int32_t const num_gen_tokens = num_tokens; int32_t const seq_offset = num_contexts; auto& mla_helix_position_offsets = mla_tensor_params[0]; + auto& mla_helix_is_inactive_rank = mla_tensor_params[1]; int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0); tk::MlaMetaParams mla_meta_params = {static_cast(q_lora_rank), static_cast(kv_lora_rank), @@ -161,6 +164,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim + int32_t const* helix_position_offsets_ptr = mla_helix_position_offsets.has_value() ? mla_helix_position_offsets->data_ptr() : nullptr; + bool const* helix_is_inactive_rank_ptr + = mla_helix_is_inactive_rank.has_value() ? mla_helix_is_inactive_rank->data_ptr() : nullptr; int* cu_q_seqlens_ptr = reinterpret_cast(cu_q_seqlens.data_ptr()); int* cu_kv_seqlens_ptr = reinterpret_cast(cu_kv_seqlens.data_ptr()); @@ -274,7 +279,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim + static_cast(num_heads), mla_meta_params, sequence_lengths_ptr, max_context_q_len, block_ids_per_seq_ptr, cache_type, cu_q_seqlens_ptr, cu_kv_seqlens_ptr, fmha_tile_counter_ptr, mla_bmm1_scale_ptr, mla_bmm2_scale_ptr, quant_q_buffer_ptr, quant_scale_o_ptr, kv_scale_orig_quant_ptr, - kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr}; + kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr, helix_is_inactive_rank_ptr}; auto const input_dtype = fused_q.scalar_type(); if (input_dtype == torch::kFloat16) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index d71c189e794..d754eb701a8 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -189,7 +189,6 @@ def plan( q_pe: Optional[torch.Tensor] = None, mrope_config: Optional[dict] = None, softmax_stats_tensor: Optional[torch.Tensor] = None, - helix_position_offsets: Optional[torch.Tensor] = None, is_spec_decoding_enabled: bool = False, use_spec_decoding: bool = False, is_spec_dec_tree: bool = False, @@ -207,6 +206,8 @@ def plan( sparse_attn_offsets: Optional[torch.Tensor] = None, sparse_attn_indices_block_size: int = 1, sparse_mla_topk: int = 0, + helix_position_offsets: Optional[torch.Tensor] = None, + helix_is_inactive_rank: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -243,7 +244,6 @@ def plan( use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner. mrope_config (dict): The dictionary containing the mRope configuration. softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) - helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU. attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU. chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens. sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU. @@ -252,6 +252,8 @@ def plan( sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU. sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention. sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention. + helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU. + helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU. """ self.layer_idx = layer_idx self.tokens_per_block = tokens_per_block @@ -287,7 +289,6 @@ def plan( 'mrope_position_deltas') if mrope_config is not None else None self.block_ids_per_seq = block_ids_per_seq self.softmax_stats_tensor = softmax_stats_tensor - self.helix_position_offsets = helix_position_offsets self.attention_sinks = attention_sinks self.sparse_kv_indices = sparse_kv_indices self.sparse_kv_offsets = sparse_kv_offsets @@ -295,6 +296,13 @@ def plan( self.sparse_attn_offsets = sparse_attn_offsets self.sparse_attn_indices_block_size = sparse_attn_indices_block_size self.sparse_mla_topk = sparse_mla_topk + self.helix_position_offsets = helix_position_offsets + self.helix_is_inactive_rank = helix_is_inactive_rank + if self.helix_is_inactive_rank is not None and not isinstance( + self.helix_is_inactive_rank, torch.Tensor): + self.helix_is_inactive_rank = torch.tensor( + self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True) + if max_sequence_length > self.rope_params.max_positions: self.rope_params.max_positions = max_sequence_length self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params( @@ -473,7 +481,9 @@ def run( spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask) spec_decoding_tensor_params.append( self.spec_bl_tree_first_sparse_mask_offset_kv) - mla_tensor_params = [self.helix_position_offsets] + mla_tensor_params = [ + self.helix_position_offsets, self.helix_is_inactive_rank + ] thop.attention( q, @@ -633,6 +643,13 @@ class TrtllmAttentionMetadata(AttentionMetadata): spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None + # Whether the current rank is inactive for helix parallelism. + # In helix parallelism, only the active rank appends KV cache for the query token + # and attends to the previously cached tokens as well as the query token. Inactive + # ranks do not append KV cache for the query token and attend to the previously + # cached tokens only. + helix_is_inactive_rank: Optional[torch.Tensor] = None + @property def max_seq_len(self) -> int: """ @@ -849,7 +866,21 @@ def prepare(self) -> None: if self.enable_flash_mla: self.prepare_flash_mla() # number of tokens needed in the kv cache for each sequence after the next pass - kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv + if self.helix_is_inactive_rank is not None and len( + self.helix_is_inactive_rank): + # If helix is inactive, attend to the previously cached tokens only. + assert cached_token_lens is not None, "cached_token_lens should be set for helix" + kv_lens = cached_token_lens.clone() + helix_is_inactive_rank_cpu = torch.tensor( + self.helix_is_inactive_rank, + dtype=torch.bool, + device='cpu', + ) + active_rank = ~helix_is_inactive_rank_cpu + kv_lens[active_rank] += self.seq_lens_kv[active_rank] + else: + kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv + # self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is # the sequence length including the cached tokens and the input tokens. self.kv_lens[:self.num_seqs].copy_( @@ -1537,7 +1568,6 @@ def forward( q_pe=q_pe, mrope_config=mrope_config, softmax_stats_tensor=softmax_stats_tensor, - helix_position_offsets=helix_position_offsets, is_spec_decoding_enabled=metadata.is_spec_decoding_enabled, use_spec_decoding=metadata.use_spec_decoding, is_spec_dec_tree=metadata.is_spec_dec_tree, @@ -1560,6 +1590,8 @@ def forward( sparse_attn_indices_block_size=sparse_attn_indices_block_size, sparse_mla_topk=metadata.sparse_mla_topk if hasattr( metadata, 'sparse_mla_topk') else 0, + helix_position_offsets=helix_position_offsets, + helix_is_inactive_rank=metadata.helix_is_inactive_rank, ) out_dtype = None if out_scale is not None: @@ -1819,6 +1851,7 @@ def mla_rope_generation( mla_bmm2_scale: torch.Tensor, quant_q_buffer: torch.Tensor, helix_position_offsets: Optional[torch.Tensor] = None, + helix_is_inactive_rank: Optional[torch.Tensor] = None, out_scale: Optional[torch.Tensor] = None, ) -> None: """ @@ -1838,7 +1871,19 @@ def mla_rope_generation( assert self.is_mla_enable and self.mla_params is not None assert metadata.kv_cache_manager is not None sink_token_length = 0 - mla_tensor_params = [helix_position_offsets] + + # Ensure helix_is_inactive_rank is on the same device as other tensors. + if helix_is_inactive_rank is not None: + if isinstance(helix_is_inactive_rank, list): + helix_is_inactive_rank = torch.tensor( + helix_is_inactive_rank, + dtype=torch.bool, + device=helix_position_offsets.device) + elif helix_is_inactive_rank.device.type != 'cuda': + helix_is_inactive_rank = helix_is_inactive_rank.to( + helix_position_offsets.device) + + mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank] torch.ops.trtllm.mla_rope_generation( fused_q, diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index f0bb67fbeb2..3e8d1779679 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -1,3 +1,4 @@ +import copy import math import pickle # nosec B403 from abc import ABC, abstractmethod @@ -341,9 +342,24 @@ class MPIDist(Distributed): def __init__(self, mapping: Mapping): super().__init__(mapping) + self.create_cp_comm() + # Repurpose CP ranks to TP for Helix so that the right comms are created. + mapping_with_cp = None + if self.mapping.has_cp_helix(): + logger.info( + f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.") + mapping_with_cp = copy.deepcopy(self.mapping) + self.mapping = self.mapping.repurpose_helix_cp_to_tp() + self.create_tp_comm() self.create_pp_comm() - self.create_cp_comm() + + # Restore the original mapping. + if mapping_with_cp is not None: + logger.info( + f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation." + ) + self.mapping = mapping_with_cp def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024): comm = mpi_comm() diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index a81a5141fa3..419dc8d70d6 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -258,6 +258,13 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, [local_num_heads, local_qk_nope_head_dim, -1]) v_b_proj = v_b_proj.view([local_num_heads, local_v_head_dim, -1]) + if cp_size > 1: + local_cp_heads = local_num_heads // cp_size + k_b_proj = k_b_proj[cp_rank * local_cp_heads:(cp_rank + 1) * + local_cp_heads] + v_b_proj = v_b_proj[cp_rank * local_cp_heads:(cp_rank + 1) * + local_cp_heads] + return k_b_proj, v_b_proj is_lite = self.config.q_lora_rank is None @@ -268,6 +275,8 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, tp_rank = self.model_config.mapping.tp_rank tp_size = self.model_config.mapping.tp_size + cp_rank = self.model_config.mapping.cp_rank + cp_size = self.model_config.mapping.cp_size params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.model.named_modules()) @@ -511,6 +520,7 @@ def __init__( model_config: ModelConfig[PretrainedConfig], layer_idx: Optional[int] = None, aux_stream: Optional[torch.cuda.Stream] = None, + mapping_with_cp: Optional[Mapping] = None, ): config = model_config.pretrained_config predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1 @@ -533,7 +543,8 @@ def __init__( layer_idx=layer_idx, dtype=config.torch_dtype, config=model_config, - aux_stream=aux_stream) + aux_stream=aux_stream, + mapping_with_cp=mapping_with_cp) self.kv_a_proj_with_mqa = DeepseekV3Linear( config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim + @@ -1010,7 +1021,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int, aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - is_separate_draft_engine: bool = False): + is_separate_draft_engine: bool = False, + mapping_with_cp: Optional[Mapping] = None): super().__init__() self.model_config = model_config self.config = model_config.pretrained_config @@ -1038,7 +1050,8 @@ def __init__(self, self.self_attn = DeepseekV3Attention( model_config, layer_idx=layer_idx_for_attention, - aux_stream=aux_stream_dict[AuxStreamType.Attention]) + aux_stream=aux_stream_dict[AuxStreamType.Attention], + mapping_with_cp=mapping_with_cp) self.enable_attention_dp = mapping.enable_attention_dp self.mlp_tp_size = mapping.tp_size @@ -1495,7 +1508,9 @@ def norm_hidden(): class DeepseekV3Model(DecoderModel): - def __init__(self, model_config: ModelConfig[PretrainedConfig]): + def __init__(self, + model_config: ModelConfig[PretrainedConfig], + mapping_with_cp: Optional[Mapping] = None): super().__init__(model_config) config = model_config.pretrained_config self.vocab_size = config.vocab_size @@ -1515,8 +1530,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): ) self.layers = nn.ModuleList([ - DeepseekV3DecoderLayer(model_config, layer_idx, - self.aux_stream_dict) + DeepseekV3DecoderLayer(model_config, + layer_idx, + self.aux_stream_dict, + mapping_with_cp=mapping_with_cp) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(hidden_size=config.hidden_size, @@ -1543,7 +1560,8 @@ def forward( hidden_states = inputs_embeds residual = None - for decoder_layer in self.layers[:self.num_hidden_layers]: + for idx, decoder_layer in enumerate( + self.layers[:self.num_hidden_layers]): hidden_states, residual = decoder_layer( position_ids=position_ids, hidden_states=hidden_states, @@ -1561,6 +1579,23 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, PretrainedConfig]): def __init__(self, model_config: ModelConfig[PretrainedConfig]): + self.mapping_with_cp = None + # Note: Currently the usage of mapping is all over the place making its usage brittle + # in this file. As a temporary WAR, we hold on to an original copy of mapping when CP + # is in action. This shall be passed on to attention which is the only layer that's + # affected by CP. For other layers, CP ranks are repurposed to TP. This shall be undone + # at the end of __init__. + if model_config.mapping.has_cp_helix(): + print( + f"[DeepseekV3ForCausalLM::__init__] Repurposing KVP ranks to TP while keeping other details the same." + ) + self.mapping_with_cp = copy.deepcopy(model_config.mapping) + # Repurpose KVP ranks to TP while keeping other details the same. + model_config._frozen = False + model_config.mapping = model_config.mapping.repurpose_helix_cp_to_tp( + ) + model_config._frozen = True + # Rename some keys of quant_config_dict to support legacy checkpoints if model_config.quant_config_dict is not None: model_config = copy.deepcopy(model_config) @@ -1574,7 +1609,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): model_config.quant_config_dict = quant_config_dict model_config._frozen = True - super().__init__(model=DeepseekV3Model(model_config), + super().__init__(model=DeepseekV3Model( + model_config, mapping_with_cp=self.mapping_with_cp), model_config=model_config) self.model_nextn = 0 @@ -1608,6 +1644,15 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.epilogue.extend(self.draft_model.mtp_layers) self.epilogue.append(self.spec_worker) + # Undo any manipulations done to mapping. + if self.mapping_with_cp is not None: + print( + f"[DeepseekV3ForCausalLM::__init__] Restoring original mapping." + ) + model_config._frozen = False + model_config.mapping = self.mapping_with_cp + model_config._frozen = True + def forward( self, attn_metadata: AttentionMetadata, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 8aa0a46b98e..adc47253939 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -686,6 +686,7 @@ def __init__( dense_bias: Optional[bool] = None, config: Optional[ModelConfig] = None, enable_unit_test: bool = False, + mapping_with_cp: Optional[Mapping] = None, ): """ Initialize the MLA module. @@ -757,7 +758,12 @@ def __init__( # tensor parallel config = config or ModelConfig() - self.mapping = config.mapping + if mapping_with_cp is not None: + logger.warning( + "[MLA::__init__] Overriding mapping with CP detected.") + self.mapping = mapping_with_cp + else: + self.mapping = config.mapping tp_size = self.mapping.tp_size pp_size = self.mapping.pp_size cp_size = self.mapping.cp_size @@ -765,6 +771,9 @@ def __init__( tp_size = 1 if self.mapping.has_cp_ulysses(): raise NotImplementedError("MLA doesn't support CP Ulyssees yet") + if self.mapping.cp_size > 1: + assert self.mapping.has_cp_helix( + ), f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}." mapping = Mapping( world_size=tp_size * pp_size * cp_size, @@ -1054,7 +1063,7 @@ def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, position_ids: Optional[torch.Tensor], attn_metadata: AttentionMetadata, **kwargs): - if self.mapping.cp_size > 1: + if self.mapping.has_cp_helix(): # partial_o: [num_tokens, num_heads_tp * kv_lora_rank] # softmax_stats: [num_tokens, num_heads_tp, 2] softmax_stats = torch.empty((q.shape[0], self.num_heads_tp, 2), @@ -1326,7 +1335,8 @@ def forward_context_default( self.qk_rope_head_dim) k = k.view(-1, self.num_heads_tp * self.qk_head_dim) - helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None + helix_position_offsets = position_ids if self.mapping.has_cp_helix( + ) else None attn_output = self.mha.forward( q, @@ -1706,8 +1716,11 @@ def forward_absorption_generation( device=q.device, ) - # Compute helix_position_offsets for helix parallelism. - helix_position_offsets = position_ids if self.mapping.cp_size > 1 else None + helix_position_offsets, helix_is_inactive_rank = None, None + if self.mapping.has_cp_helix(): + helix_position_offsets = position_ids + helix_is_inactive_rank = attn_metadata.helix_is_inactive_rank + assert helix_position_offsets is not None and helix_is_inactive_rank is not None, "helix_position_offsets and helix_is_inactive_rank must be provided for helix parallelism." rope_stream = self.aux_stream if not has_fp8_kv_cache else None if self.k_b_proj_trans.dtype == torch.bfloat16: @@ -1722,18 +1735,19 @@ def forward_absorption_generation( maybe_execute_in_parallel( lambda: torch.ops.trtllm.bmm_out( q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out), - lambda: self.mqa.mla_rope_generation(fused_q, - q_pe, - latent_cache, - attn_metadata, - cu_q_seqlens, - cu_kv_seqlens, - fmha_scheduler_counter, - mla_bmm1_scale, - mla_bmm2_scale, - quant_q_buffer, - helix_position_offsets= - helix_position_offsets), + lambda: self.mqa.mla_rope_generation( + fused_q, + q_pe, + latent_cache, + attn_metadata, + cu_q_seqlens, + cu_kv_seqlens, + fmha_scheduler_counter, + mla_bmm1_scale, + mla_bmm2_scale, + quant_q_buffer, + helix_position_offsets=helix_position_offsets, + helix_is_inactive_rank=helix_is_inactive_rank), self.ln_events[0], self.ln_events[1], rope_stream, @@ -1751,18 +1765,19 @@ def forward_absorption_generation( q_nope_out, self.k_b_proj_trans_dequant, ), - lambda: self.mqa.mla_rope_generation(fused_q, - q_pe, - latent_cache, - attn_metadata, - cu_q_seqlens, - cu_kv_seqlens, - fmha_scheduler_counter, - mla_bmm1_scale, - mla_bmm2_scale, - quant_q_buffer, - helix_position_offsets= - helix_position_offsets), + lambda: self.mqa.mla_rope_generation( + fused_q, + q_pe, + latent_cache, + attn_metadata, + cu_q_seqlens, + cu_kv_seqlens, + fmha_scheduler_counter, + mla_bmm1_scale, + mla_bmm2_scale, + quant_q_buffer, + helix_position_offsets=helix_position_offsets, + helix_is_inactive_rank=helix_is_inactive_rank), self.ln_events[0], self.ln_events[1], rope_stream, @@ -2084,7 +2099,6 @@ def forward_sparse_mla_kvcache_bf16( else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.") - return output def forward( @@ -2115,7 +2129,7 @@ def forward( output=attn_output, latent_cache_gen=latent_cache_gen) - if self.enable_unit_test and self.mapping.cp_size > 1: + if self.enable_unit_test and self.mapping.has_cp_helix(): # note: for allowing testing Helix parallelism, we ensure that # the output is compatible with o_proj even in the context phase, # thus we cut it to num_heads_tp_cp * v_head_dim diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 17af85e755b..2cbf5635a07 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -313,11 +313,11 @@ def _fetch_and_process_requests( new_requests = self._handle_special_queue_items(new_requests) # Attach Python objects to requests - if py_request_objects and (self.dist.tp_size > 1 - or self.dist.has_pp) and self.dist.rank > 0: + if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp + or self.dist.cp_size + > 1) and self.dist.rank > 0: self._attach_py_objects_to_requests(new_requests, py_request_objects) - self.waiting_queue.extend(new_requests) new_requests = self._get_from_waiting_queue( @@ -693,6 +693,7 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem], input_token_ids=input_ids_this_rank, position_ids=position_ids_this_rank, ) + req.total_input_len_cp = input_len req_with_children.append(req) if req.child_requests: req_with_children.extend(req.child_requests) @@ -707,7 +708,6 @@ def _merge_requests( if cp_type == CpType.STAR: return self._merge_star_attention_requests(new_requests) elif cp_type == CpType.HELIX: - # Take the usual route below. return self._merge_helix_requests( new_requests, tokens_per_block=cp_config['tokens_per_block']) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 01d3f35f876..28314382564 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -488,6 +488,7 @@ def __init__( self.py_orig_prompt_len = self.orig_prompt_len self.py_max_new_tokens = self.max_new_tokens self.py_min_length = self.sampling_config.min_length + self.py_helix_is_inactive_rank = False self.py_batch_idx = None self.py_draft_pages_allocated = 0 self.py_rewind_len = 0 diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 21c90738b09..aef7822eb03 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -567,12 +567,13 @@ def warmup(self, resource_manager: ResourceManager) -> None: # Reset the global cuda graph dummy request to None in warmup. self.cuda_graph_runner.padding_dummy_request = None - # TODO: current warmup_request is not suitable for context parallelism. cp_type = self.mapping.cp_config.get('cp_type', None) if cp_type is not None: - logger.info("[ModelEngine::warmup] Skipping warmup for cp_type: ", - cp_type.name) - return + if cp_type in [CpType.ULYSSES, CpType.STAR]: + logger.info( + "[ModelEngine::warmup] Skipping warmup for cp_type: ", + cp_type.name) + return self._run_torch_compile_warmup(resource_manager) self._run_autotuner_warmup(resource_manager) @@ -1617,6 +1618,7 @@ def _prepare_tp_inputs( # update batch index request.py_batch_idx = request.py_seq_slot + helix_is_inactive_rank = [] if self.mapping.has_cp_helix() else None for request in generation_requests: request_ids.append(request.py_request_id) beam_width = request.sampling_config.beam_width @@ -1649,16 +1651,26 @@ def _prepare_tp_inputs( if beam == first_beam: previous_batch_indices.append(request.py_batch_idx) past_seen_token_num = request.max_beam_num_tokens + position_id = past_seen_token_num if self.mapping.has_cp_helix(): - # Do an allgather among CP ranks to get the complete sequence length seen by all CP ranks. - past_seen_token_nums = self.dist.cp_allgather( - past_seen_token_num) - position_id = sum(past_seen_token_nums) + # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called. + if not self.is_warmup and not request.is_cuda_graph_dummy: + position_id = request.total_input_len_cp + request.py_decoding_iter - 1 + # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix. + if self.mapping.cp_rank == self.mapping.cp_size - 1: + past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1 + else: + # past_seen_token_num doesn't grow on inactive ranks. + past_seen_token_num = request.orig_prompt_len + position_ids.append(position_id) num_cached_tokens_per_seq.append(past_seen_token_num) request.cached_tokens = num_cached_tokens_per_seq[-1] prompt_lengths.append(request.py_prompt_len) + if self.mapping.has_cp_helix(): + helix_is_inactive_rank.append( + request.py_helix_is_inactive_rank) draft_lens.append(0) sequence_lengths.append(1) num_accepted_draft_tokens.append(0) @@ -1988,6 +2000,7 @@ def previous_seq_slots_device(): attn_metadata.request_ids = request_ids attn_metadata.prompt_lens = prompt_lengths + attn_metadata.helix_is_inactive_rank = helix_is_inactive_rank attn_metadata.num_contexts = len(scheduled_requests.context_requests) # Use num_chunked_ctx_requests to record the number of extend context requests, # so that we can update the kv_lens_cuda correctly in _preprocess_inputs. diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index c2f8f8175e2..6e8eee8efdd 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -464,6 +464,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): req, block_ids) for req in generation_batch: + # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix. + if self.mapping.has_cp_helix(): + if self.mapping.cp_rank != self.mapping.cp_size - 1: + req.py_helix_is_inactive_rank = True + # Skip allocating KV cache at decode for inactive helix ranks. + if req.py_helix_is_inactive_rank: + continue self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) diff --git a/tensorrt_llm/bench/benchmark/low_latency.py b/tensorrt_llm/bench/benchmark/low_latency.py index 63313427d4d..afd4125de94 100644 --- a/tensorrt_llm/bench/benchmark/low_latency.py +++ b/tensorrt_llm/bench/benchmark/low_latency.py @@ -121,6 +121,12 @@ default=1, help="pipeline parallelism size", ) +@optgroup.option( + "--cp", + type=int, + default=1, + help="context parallelism size", +) @optgroup.option( "--ep", type=int, diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py index 171b14ada76..25ad2635709 100755 --- a/tensorrt_llm/bench/benchmark/throughput.py +++ b/tensorrt_llm/bench/benchmark/throughput.py @@ -202,6 +202,12 @@ default=1, help="pipeline parallelism size", ) +@optgroup.option( + "--cp", + type=int, + default=1, + help="context parallelism size", +) @optgroup.option( "--ep", type=int, diff --git a/tensorrt_llm/commands/eval.py b/tensorrt_llm/commands/eval.py index 5668ab3452a..6dff29af5af 100644 --- a/tensorrt_llm/commands/eval.py +++ b/tensorrt_llm/commands/eval.py @@ -25,6 +25,7 @@ from ..llmapi import BuildConfig, KvCacheConfig from ..llmapi.llm_utils import update_llm_args_with_extra_options from ..logger import logger, severity_map +from ..mapping import CpType @click.group() @@ -74,6 +75,10 @@ type=int, default=1, help='Pipeline parallelism size.') +@click.option("--cp_size", + type=int, + default=1, + help='Context parallelism size.') @click.option("--ep_size", type=int, default=None, @@ -105,6 +110,10 @@ is_flag=True, default=False, help="Flag for disabling KV cache reuse.") +@click.option("--cp_config", + type=dict, + default=None, + help="Context parallelism configuration as JSON.") @click.pass_context def main(ctx, model: str, tokenizer: Optional[str], log_level: str, backend: str, max_beam_width: int, max_batch_size: int, @@ -112,7 +121,7 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str, ep_size: Optional[int], gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float, trust_remote_code: bool, revision: Optional[str], extra_llm_api_options: Optional[str], - disable_kv_cache_reuse: bool): + disable_kv_cache_reuse: bool, cp_size: int, cp_config: Optional[dict]): logger.set_level(log_level) build_config = BuildConfig(max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, @@ -123,11 +132,20 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str, free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, enable_block_reuse=not disable_kv_cache_reuse) + if cp_config is not None and "cp_type" in cp_config: + cp_config = cp_config.copy() + try: + cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()] + except KeyError: + raise ValueError(f"Invalid cp_type: {cp_config['cp_type']}. " \ + f"Must be one of: {', '.join([t.name for t in CpType])}") llm_args = { "model": model, "tokenizer": tokenizer, "tensor_parallel_size": tp_size, "pipeline_parallel_size": pp_size, + "context_parallel_size": cp_size, + "cp_config": cp_config if cp_config is not None else {}, "moe_expert_parallel_size": ep_size, "gpus_per_node": gpus_per_node, "trust_remote_code": trust_remote_code, diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index e8fd2cef784..a2e4f4f8066 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -32,6 +32,7 @@ from tensorrt_llm.llmapi.mpi_session import find_free_ipc_addr from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory from tensorrt_llm.logger import logger, severity_map +from tensorrt_llm.mapping import CpType from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer from tensorrt_llm.serve.tool_parser import ToolParserFactory from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir @@ -90,6 +91,8 @@ def get_llm_args( max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, + context_parallel_size: int = 1, + cp_config: Optional[dict] = None, moe_expert_parallel_size: Optional[int] = None, gpus_per_node: Optional[int] = None, free_gpu_memory_fraction: float = 0.9, @@ -121,12 +124,22 @@ def get_llm_args( capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, dynamic_batch_config=dynamic_batch_config, ) + if cp_config is not None and "cp_type" in cp_config: + cp_config = cp_config.copy() + try: + cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()] + except KeyError: + raise ValueError(f"Invalid cp_type: {cp_config['cp_type']}. " \ + f"Must be one of: {', '.join([t.name for t in CpType])}") + llm_args = { "model": model, "scheduler_config": scheduler_config, "tokenizer": tokenizer, "tensor_parallel_size": tensor_parallel_size, "pipeline_parallel_size": pipeline_parallel_size, + "context_parallel_size": context_parallel_size, + "cp_config": cp_config if cp_config is not None else {}, "moe_expert_parallel_size": moe_expert_parallel_size, "gpus_per_node": gpus_per_node, "trust_remote_code": trust_remote_code, @@ -291,6 +304,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"], type=int, default=1, help='Pipeline parallelism size.') +@click.option("--cp_size", + type=int, + default=1, + help='Context parallelism size.') @click.option("--ep_size", type=int, default=None, @@ -385,7 +402,7 @@ def serve( model: str, tokenizer: Optional[str], host: str, port: int, log_level: str, backend: str, max_beam_width: int, max_batch_size: int, max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int, - ep_size: Optional[int], cluster_size: Optional[int], + cp_size: int, ep_size: Optional[int], cluster_size: Optional[int], gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float, num_postprocess_workers: int, trust_remote_code: bool, revision: Optional[str], extra_llm_api_options: Optional[str], @@ -419,6 +436,7 @@ def serve( max_seq_len=max_seq_len, tensor_parallel_size=tp_size, pipeline_parallel_size=pp_size, + context_parallel_size=cp_size, moe_expert_parallel_size=ep_size, moe_cluster_parallel_size=cluster_size, gpus_per_node=gpus_per_node, diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 1ef5f413973..b38fd514bc5 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -198,7 +198,7 @@ def extract_ctx_gen_cfgs(type: Literal['ctx', 'gen'], # Compute the number of ranks per instance instance_num_ranks = kwargs.get('tensor_parallel_size', 1) * kwargs.get( - 'pipeline_parallel_size', 1) + 'pipeline_parallel_size', 1) * kwargs.get('context_parallel_size', 1) cfgs = [] for hostname, port in zip(hostnames, ports): diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index 28982a0caf6..84e5361da34 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -485,6 +485,26 @@ def __init__( enable_attention_dp=enable_attention_dp, enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp) + def repurpose_helix_cp_to_tp(self): + # In helix parallelism, CP is relevant only for the attention layer. These ranks are repurposed to TP + # for FFN layers. + assert self.has_cp_helix() + return Mapping( + world_size=self.world_size, + rank=self.rank, + gpus_per_node=self.gpus_per_node, + cp_size=1, + cp_config={}, + tp_size=self.tp_size * self.cp_size, + pp_size=self.pp_size, + pp_partition=self.pp_partition, + moe_cluster_size=self.moe_cluster_size, + moe_tp_size=self.moe_tp_size, + moe_ep_size=self.moe_ep_size, + # attn_tp_size, attn_cp_size shall be set in the constructor of Mapping. + enable_attention_dp=self.enable_attention_dp, + enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp) + # DeviceMesh specific methods @property def tp_group_pg(self): diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml new file mode 100644 index 00000000000..7cfb4888717 --- /dev/null +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml @@ -0,0 +1,37 @@ +hostname: localhost +port: 8000 +model: DeepSeek-V3-Lite/bf16 +free_gpu_memory_fraction: 0.25 +backend: "pytorch" +disable_overlap_scheduler: True +cuda_graph_config: null +context_servers: + num_instances: 1 + enable_chunked_prefill: False + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + tokens_per_block: 32 + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + cache_transceiver_config: + backend: "UCX" + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + context_parallel_size: 2 + enable_chunked_prefill: False + cp_config: + cp_type: HELIX + tokens_per_block: 32 + kv_cache_config: + enable_block_reuse: False + enable_partial_reuse: False + tokens_per_block: 32 + cache_transceiver_config: + backend: "UCX" + urls: + - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index cb339dc0a5a..ef3549dc632 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -267,6 +267,10 @@ def get_test_config(test_desc, example_dir, test_root): ), "llama4_kv_cache_overflow": (8, f"{test_configs_root}/disagg_config_llama4_kv_cache_overflow.yaml"), + "deepseek_v3_lite_bf16_tllm_gen_helix": + (4, + f"{test_configs_root}/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml" + ), } if test_desc not in config_map: @@ -1906,3 +1910,25 @@ def test_llama4_long_context_kv_cache_overflow(disaggregated_test_root, output_tokens=100, env=llm_venv._new_env, cwd=llm_venv.get_working_directory()) + + +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'], + indirect=True) +def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix( + disaggregated_test_root, disaggregated_example_root, llm_venv, + deepseek_v3_model_root): + src_dst_dict = { + deepseek_v3_model_root: + f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "deepseek_v3_lite_bf16_tllm_gen_helix", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory(), + prompt_file="long_prompts.json") diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py index 83909bc1caf..ca75cbb3517 100644 --- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py +++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py @@ -1,6 +1,6 @@ import unittest from dataclasses import dataclass -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock import torch @@ -393,19 +393,6 @@ def test_prepare_tp_inputs_with_helix_parallelism(self) -> None: rank=cp_rank) model_engine.mapping = mapping - # Mock model_engine's dist and its cp_allgather to return different values per CP rank. - mock_dist = MagicMock() - - def mock_cp_allgather(obj): - # Simulate allgather across CP ranks: [past_seen_token_num_rank0, past_seen_token_num_rank1] - if cp_rank == 0: - return [obj, obj + 10] # Rank 0 sees tokens [obj, obj+10] - else: - return [obj - 10, obj] # Rank 1 sees tokens [obj-10, obj] - - mock_dist.cp_allgather.side_effect = mock_cp_allgather - model_engine.dist = mock_dist - # Create scheduled requests with two generation requests. scheduled_requests = ScheduledRequests() scheduled_requests.context_requests = [] @@ -419,6 +406,8 @@ def mock_cp_allgather(obj): req.py_seq_slot = idx req.sampling_config.beam_width = 1 req.py_multimodal_data = {} + req.total_input_len_cp = prompt_lens[idx] * 2 + req.py_decoding_iter = 1 gen_requests.append(req) scheduled_requests.generation_requests = gen_requests @@ -464,15 +453,10 @@ def mock_cp_allgather(obj): self.assertIn('position_ids', result) self.assertIn('attn_metadata', result) - # Check that cp_allgather was called for position calculation. # Also, verify that position_ids are properly calculated. - self.assertTrue(mock_dist.cp_allgather.called) position_ids = result['position_ids'] self.assertIsInstance(position_ids, torch.Tensor) - # For cp_rank=0, the expected position_ids should be: - # req1: past_seen_token_num=19 (prompt_len0-1), allgather=[19, 29], sum=48. - # req2: past_seen_token_num=14 (prompt_len1-1), allgather=[14, 24], sum=38. - expected_positions = [48, 38] + expected_positions = [40, 30] actual_positions = position_ids.squeeze(0).cpu().tolist()[:2] self.assertEqual( actual_positions, expected_positions, diff --git a/tests/unittest/_torch/modules/test_mla_helix.py b/tests/unittest/_torch/modules/test_mla_helix.py index 35d7e8b521d..889cc3c4868 100644 --- a/tests/unittest/_torch/modules/test_mla_helix.py +++ b/tests/unittest/_torch/modules/test_mla_helix.py @@ -99,6 +99,8 @@ def max_position_embeddings(self) -> int: all_scenarios = [ + Scenario(batch=1, ctx_len=64), + Scenario(batch=1, ctx_len=512), Scenario(batch=1, ctx_len=1024), Scenario(batch=1, ctx_len=2048), Scenario(batch=1, ctx_len=4096), @@ -129,14 +131,14 @@ def max_position_embeddings(self) -> int: # limit the number of test scenarios to avoid taking too long test_scenarios = [ - # note: tests with ctx_len=1024 (or less) are currently failing, most likely due to - # bad numerics especially with bf16. We ignore those tests for now. - all_scenarios[2], - all_scenarios[5], - all_scenarios[12], - all_scenarios[15], - all_scenarios[21], - all_scenarios[22], + all_scenarios[0], + all_scenarios[1], + all_scenarios[4], + all_scenarios[7], + all_scenarios[14], + all_scenarios[17], + all_scenarios[23], + all_scenarios[24], ] @@ -501,9 +503,16 @@ def _run_mla_distributed( start = time.time() for step in range(gen_steps): + helix_is_inactive_rank = [] for req_id in range(scenario.batch): kv_cache_manager.impl.add_token(req_id) - cache_add = step if rank == world_size - 1 else 0 + # Assume last rank is active for all gen steps. + if rank == world_size - 1: + helix_is_inactive_rank.append(False) + cache_add = step + else: + helix_is_inactive_rank.append(True) + cache_add = 0 cached_tokens_per_seq = [ctx_len_per_gpu + cache_add for _ in range(scenario.batch)] if step == 0: attn_metadata = get_attention_backend("TRTLLM").Metadata( @@ -519,12 +528,18 @@ def _run_mla_distributed( num_cached_tokens_per_seq=cached_tokens_per_seq, ), enable_context_mla_with_cached_kv=True, + helix_is_inactive_rank=torch.tensor( + helix_is_inactive_rank, dtype=torch.bool, device="cuda" + ), ) else: attn_metadata.kv_cache_params = KVCacheParams( use_cache=True, num_cached_tokens_per_seq=cached_tokens_per_seq, ) + attn_metadata.helix_is_inactive_rank = torch.tensor( + helix_is_inactive_rank, dtype=torch.bool, device="cuda" + ) attn_metadata.prepare() extra_attrs["attention_metadata"] = weakref.ref(attn_metadata) if not use_cuda_graph: