@@ -66,6 +66,7 @@ struct MlaRopeGenArgs
     float const* kv_scale_quant_orig_ptr;
     float host_bmm1_scale;
     int32_t const* helix_position_offsets_ptr;
+    bool const* helix_is_inactive_rank_ptr;
 };

 template <typename T, typename KVCacheBuffer>
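Note on the new member: `MlaRopeGenArgs` is filled by positional brace initialization at the call site (last hunk below), so `helix_is_inactive_rank_ptr` has to remain the last declared field. A minimal sketch, showing only the fields visible in this diff, of how default member initializers could keep any call site that omits the helix pointers well-defined; the `nullptr`/`1.0f` defaults are a suggested hardening, not part of this change:

```cpp
// Sketch only: the struct's other members are omitted, and the defaults below
// are an assumption layered on top of this PR, not code from it.
struct MlaRopeGenArgs
{
    float const* kv_scale_quant_orig_ptr = nullptr;
    float host_bmm1_scale = 1.0f;
    int32_t const* helix_position_offsets_ptr = nullptr; // optional; nullptr when unused
    bool const* helix_is_inactive_rank_ptr = nullptr;    // optional; nullptr when unused
};
```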
@@ -105,6 +106,7 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
     mla_params.dequant_scale_kv = args.kv_scale_quant_orig_ptr;
     mla_params.host_bmm1_scale = args.host_bmm1_scale;
     mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
+    mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;

     tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
 }
@@ -134,7 +136,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim");
     TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1");
     TORCH_CHECK(
-        mla_tensor_params.size() == 1, "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
+        mla_tensor_params.size() == 2, "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");

     auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device());
     auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode));
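The stricter `TORCH_CHECK` makes the two-slot layout part of the op's contract. A caller-side sketch, assuming the binding passes `mla_tensor_params` as a list of optional tensors (which the `has_value()`/`data_ptr` accesses below suggest); `makeMlaTensorParams` is a hypothetical helper, not part of this PR:

```cpp
#include <optional>
#include <utility>
#include <vector>
#include <torch/torch.h>

// Both slots must always be present to satisfy the size()==2 check, but either
// may be std::nullopt; order is fixed (0: offsets, 1: inactive-rank flags).
std::vector<std::optional<torch::Tensor>> makeMlaTensorParams(
    std::optional<torch::Tensor> helix_position_offsets, // int32 per-token offsets, or nullopt
    std::optional<torch::Tensor> helix_is_inactive_rank) // bool flags, or nullopt
{
    return {std::move(helix_position_offsets), std::move(helix_is_inactive_rank)};
}
```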
@@ -153,6 +155,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
     int32_t const num_gen_tokens = num_tokens;
     int32_t const seq_offset = num_contexts;
     auto& mla_helix_position_offsets = mla_tensor_params[0];
+    auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
     int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);

     tk::MlaMetaParams mla_meta_params = {static_cast<int>(q_lora_rank), static_cast<int>(kv_lora_rank),
@@ -161,6 +164,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +

     int32_t const* helix_position_offsets_ptr
         = mla_helix_position_offsets.has_value() ? mla_helix_position_offsets->data_ptr<int32_t>() : nullptr;
+    bool const* helix_is_inactive_rank_ptr
+        = mla_helix_is_inactive_rank.has_value() ? mla_helix_is_inactive_rank->data_ptr<bool>() : nullptr;

     int* cu_q_seqlens_ptr = reinterpret_cast<int*>(cu_q_seqlens.data_ptr());
     int* cu_kv_seqlens_ptr = reinterpret_cast<int*>(cu_kv_seqlens.data_ptr());
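The optional-tensor-to-raw-pointer dance now appears twice; a hypothetical refactor (not in this PR) that factors it into one helper, assuming the list elements are `std::optional<torch::Tensor>`:

```cpp
#include <optional>
#include <torch/torch.h>

// Returns the typed data pointer if the tensor is present, else nullptr, which
// is exactly the convention the kernel arguments use for optional inputs.
template <typename T>
T const* optionalDataPtrOrNull(std::optional<torch::Tensor> const& t)
{
    return t.has_value() ? t->data_ptr<T>() : nullptr;
}

// Equivalent to the two ternaries above:
//   auto* helix_position_offsets_ptr = optionalDataPtrOrNull<int32_t>(mla_helix_position_offsets);
//   auto* helix_is_inactive_rank_ptr = optionalDataPtrOrNull<bool>(mla_helix_is_inactive_rank);
```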
@@ -274,7 +279,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         static_cast<int32_t>(num_heads), mla_meta_params, sequence_lengths_ptr, max_context_q_len,
         block_ids_per_seq_ptr, cache_type, cu_q_seqlens_ptr, cu_kv_seqlens_ptr, fmha_tile_counter_ptr,
         mla_bmm1_scale_ptr, mla_bmm2_scale_ptr, quant_q_buffer_ptr, quant_scale_o_ptr, kv_scale_orig_quant_ptr,
-        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr};
+        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr, helix_is_inactive_rank_ptr};

     auto const input_dtype = fused_q.scalar_type();
     if (input_dtype == torch::kFloat16)
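Positional brace initialization fills `MlaRopeGenArgs` members in declaration order, which is why the new pointer is appended last both in the struct and in this initializer. A hypothetical alternative for making such long positional lists less error-prone (C++20 designated initializers; only the members visible in this diff are shown, and a real call would still need to set the omitted ones, since skipped members are value-initialized):

```cpp
// Hypothetical C++20 rewrite: designators must follow declaration order, so a
// misplaced argument becomes a compile error instead of a silently wrong kernel.
MlaRopeGenArgs args{
    .kv_scale_quant_orig_ptr = kv_scale_quant_orig_ptr,
    .host_bmm1_scale = host_bmm1_scale,
    .helix_position_offsets_ptr = helix_position_offsets_ptr,
    .helix_is_inactive_rank_ptr = helix_is_inactive_rank_ptr,
};
```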