
Commit 6d9b911

TRTLLM-7731 KV cache transmission in disagg with CP on gen side
Signed-off-by: Balaram Buddharaju <[email protected]>

Squashed commit history (Signed-off-by trailers, where present, were by Matthias Jouanneaux <[email protected]> unless noted otherwise):

- add ds-lite tllm-gen based disagg test
- initial support for helix parallelism
- fixed mapping tests, added working MLA module test, added disagg test for helix (WIP)
- Helix MLA module test: added more scenarios, removed unnecessary code
- MLA Helix test: restricting number of tests, better output
- test MLA helix: remove OOM test scenario
- test MLA helix: fix scenario max position embeddings
- test Helix MLA: try to fix NaNs
- added all-to-all impl
- fix thop lib
- fix alltoall
- attention MLA: remove kv heads (unused), improve heads naming, fix tests
- test Helix MLA: minor fixes
- test Helix MLA: disable numeric test
- test Helix MLA: add TODOs to MLA module
- test Helix MLA: fix MLA module
- debugging (repeated across a run of consecutive commits)
- fully working MLA test
- attempt to make latent cache work
- debugging numerical issue (repeated across consecutive commits)
- adding additional test for further numerical debugging
- fixing tests & correction
- remove debug output from tests
- fix tests
- further debugging with multiple sequences (repeated across consecutive commits)
- fixed multiple sequences tests
- automated review comments
- debugging of latent cache (repeated across consecutive commits)
- further debugging of pe values
- further debugging of latent cache
- fixed latent cache, remove flaky test
- better reporting (repeated across consecutive commits)
- finalized test scenarios
- better perf measurements, added graph support
- added helix post process kernel
- added unit test, minor fix for helix kernel
- fixing helix kernels
- better tests, minor fixes (repeated across consecutive commits)
- debugging helix test (repeated across consecutive commits)
- fixed helix post process kernel: main kernel had perf issue/flaw
- fixed helix post process test
- added helix full layer test
- fix full layer helix test/bench
- added correct mapping to ds helix
- further improvements for fp8 init
- debugging quantization config
- better debug output
- fixes for fp8
- fix fp8 runs
- attempt to fix fp8 context
- fix context phase: just randomly gen kv cache values; fix scenario sizes
- fix tp size config in helix layer test
- minor changes for test
- get trtllm-serve working with BF16 for gen with cp (v_b_proj weight loading needs to be revisited):
  $ CUDA_VISIBLE_DEVICES=0,1 trtllm-serve /home/scratch.trt_llm_data/llm-models/DeepSeek-V3-Lite/bf16/ --host localhost --port 8002 --cp_size 2 --extra_llm_api_options ./gen_extra-llm-api-config.yaml
  end-to-end test in disagg works:
  $ pytest tests/integration/defs/disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix -s -v
- Switch to contiguous block dist among CP rank
- save changes to _merge_requests()
- undo changes to prepare_inputs()
- Raise exception for blocks fewer than num_cp_ranks
- save intermediate changes
- attempt to fix attention tests
- save changes for minimal test
- save minor dev comments
- added helix inactive rank option to MLA kernels
- pass the right seq_lens_kv; test with seqlen 64 works:
  $ pytest tests/unittest/_torch/modules/test_mla_helix_expt.py -s -v
- is_inactive_helix at request level
- cp_allgather for position_id
- helix: make inactive rank a bool tensor
- undo mapping changes to modeling_deepseek
- Failed attempt to replace model_config.mapping
- fill in helix_is_inactive for each request
- update position_id logic
- better way to package mapping - repurpose comms creation too
- save disagg gen-only benchmark test
- prep for integration test
- improvements to position_id, num_cached_tokens_per_seq and tokens_per_block
- changes to save blocks at prefill
- changes to save blocks at decode
- add changes to read KV from disk
- updates to save and read KV blocks for all layers
- over-allocate at prefill to get cache transmission right
- prune saved KV cache files
- updates to avoid over-allocation on gen side in disagg
- Revert "over-allocate at prefill to get cache transmission right" (reverts commit af7d000)
- save disagg configs for DSV3 - currently goes OOM
- verifying tests on 8 GPUs
- helix: added (working) DS R1 8-GPU integration test
- helix: added large prompt + ds lite config using large prompt
- save intermediate changes for fixes
- fix debug printing
- Mention cache_transceiver_config.max_tokens_in_buffer for disagg servers
- save initial changes to benchmarking script
- added mjoux specific submit script, tighter timeouts, better defaults
- helix slurm: increase timeouts slightly, use deepgemm moe backend for smaller models
- helix slurm: add dataset caching path
- fix padding when input_len is divisible by tokens_per_block
- save changes to test varying prompt len
- fix_kvcache_split (Signed-off-by: Chuang Zhu <[email protected]>)
- avoid fabric memory and print send and recv sizes
- auto-determine transceiver size
- remove verbose print output
- attempt to fix DS R1 run
- helix slurm: fix parameters for DS R1 up to 256K tokens
- minor updates to reduce memory footprint and bring back warmup
- enable cudagraph and add some debug prints
- ugly hack to get results with 512k
- updates to benchmark 1M seqlen
- updates to benchmark 2M seqlen
- updates for passing down moe properly
- minor changes to get nsys profiles
- test helix layer: support for slurm call, support for fp4
- test helix layer: added sbatch script
- add minimal cache transmission test for 1M seqlen
- minor bug fix
- changes to benchmark 4M seqlen
- skip launch/wait of context servers when TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1
- remove hacks; skip profiling; gpu_mem_frac
- test helix layer: fix nvfp4 config to fit high perf mode
- helix single layer: improved timing, added arg parsing, added output parsing
- helix single layer: add dense option
- helix slurm: fix gen_only config, support EP config, add submit script for multiple configs, remove build_wheel by default for array benchmarking
- helix slurm: added parse script for results
- helix single layer: fixed test, added config submit script, improved parsing
- helix single layer: fix segment for sbatch script
- helix: fixed TP-only runs (removed hack to make higher seq len work), improved sbatch scripts
- helix: fix high node count runs, move back to e2e mode, improve parse script
- longer prompt for DSV3 Lite & DSR1 FP4 integration tests:
  disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix
  disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tllm_gen_helix
  disaggregated/test_disaggregated.py::test_disaggregated_deepseek_r1_fp4_tllm_gen_helix
- helix: added initial README for testing/benchmarking
- helix slurm: remove references to internal clusters
- minor updates to README
- minor updates
- helix: improve transpose/split for alltoall
- Revert "helix: improve transpose/split for alltoall" (reverts commit c8b24b9)
- helix: improve alltoall perf
- [https://nvbugs/5495789][feat] Optionally disable server GC and worker GC (#7995) (Signed-off-by: Tailing Yuan <[email protected]>)
- save changes for custom logging
- redo cherry-pick of attention.py
- save more changes for build and pipe-cleaning
- save more changes
- clean up - 1
- clean up - 2
- reuse mla_tensor_params instead of using helix_tensor_params
- undo all_tp_rank_num_tokens
- update test_disaggregated.py
- updates to dsv3RopeOp
- more cleanup
- save fp8 disagg test
- [https://nvbugs/5637012][fix] Fix helix unit tests (Signed-off-by: Balaram Buddharaju <[email protected]>)
- minor updates to attention.py
- updates to test - seqlen 64 works
- get integration test working
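The commit message above mentions filling in `helix_is_inactive` per request and switching to a contiguous KV block distribution among CP ranks. The sketch below illustrates one way such a flag could be derived on the generation side; it is not code from this commit, and all names in it are hypothetical.

```python
# Hypothetical sketch (not from this commit): with a contiguous split of a
# request's KV blocks across CP ranks, only the rank holding the tail of the
# sequence appends the newly generated token's KV and is "active" for it.
def helix_is_inactive(cp_rank: int, cp_size: int, num_blocks: int) -> bool:
    """True if `cp_rank` should NOT append KV for the next decoded token."""
    blocks_per_rank = -(-num_blocks // cp_size)  # ceil division
    owner_of_last_block = min((num_blocks - 1) // blocks_per_rank, cp_size - 1)
    return cp_rank != owner_of_last_block

# 10 blocks over cp_size=2: rank 0 owns blocks 0-4, rank 1 owns 5-9 (the tail),
# so rank 0 is inactive and rank 1 is active for the next token.
assert helix_is_inactive(cp_rank=0, cp_size=2, num_blocks=10)
assert not helix_is_inactive(cp_rank=1, cp_size=2, num_blocks=10)
```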
Parent: 268ea9b · Commit: 6d9b911

22 files changed: +383 additions, -60 deletions

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 5 additions & 6 deletions
@@ -351,7 +351,8 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
     int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
     float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
-    float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets)
+    float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets,
+    bool const* helix_is_inactive_rank)
 {

     // Constants.
@@ -424,7 +425,6 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,

     if (valid_token)
     {
-
         auto const position_id
             = (helix_position_offsets != nullptr ? helix_position_offsets[global_token_idx]
                                                  : kv_cache_lengths[batch_idx] - seq_len + local_token_idx);
@@ -460,10 +460,9 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,

     if (valid_token)
     {
-        if (head_idx == head_num)
+        if (head_idx == head_num && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
         {
             auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
-
             {
                 auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                 auto inBlockIdx = kv_cache.getKVLocalIdx(
@@ -514,7 +513,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     auto local_token_idx = global_token_idx % seq_len;
     bool valid_token = global_token_idx < total_s_len;

-    if (valid_token)
+    if (valid_token && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
     {
         if (head_dim_vec_idx == 0)
         {
@@ -1047,7 +1046,7 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
         params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
         params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
         params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
-        params.host_bmm1_scale, params.helix_position_offsets);
+        params.host_bmm1_scale, params.helix_position_offsets, params.helix_is_inactive_rank);
 }

 template <typename T, typename TCache>
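The kernel change above gates the KV-cache append (and the follow-on writes) on `helix_is_inactive_rank[batch_idx]`. The following is a minimal Python sketch of the assumed semantics, not the kernel's API: an inactive rank leaves its cache untouched and attends only to previously cached tokens.

```python
import torch

def maybe_append_kv(cached_k: torch.Tensor, cached_v: torch.Tensor,
                    new_k: torch.Tensor, new_v: torch.Tensor,
                    is_inactive_rank: bool):
    """Illustrative gating only: mirrors the nullptr / !flag check in the kernel."""
    if is_inactive_rank:
        # Inactive helix rank: do not append; attention sees cached tokens only.
        return cached_k, cached_v
    # Active rank (or helix disabled): append the query token's k/v as usual.
    return torch.cat([cached_k, new_k]), torch.cat([cached_v, new_v])
```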

cpp/tensorrt_llm/kernels/mlaKernels.h

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@ struct MlaParams

     // for Helix parallelism: the rotary position offsets [b]
     int32_t const* helix_position_offsets{nullptr};
+    // for Helix parallelism: whether the current rank is inactive, shape [b]
+    // (the current query tokens are not appended to this rank's KV cache)
+    bool const* helix_is_inactive_rank{nullptr};
 };

 template <typename T, typename KVCacheBuffer>

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 7 additions & 2 deletions
@@ -181,8 +181,8 @@ class Runner : public RunnerBase
     [[maybe_unused]] MlaParams<T> mla_params;
     if (op.isMLAEnabled())
     {
-        TORCH_CHECK(mla_tensor_params.size() == 1,
-            "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
+        TORCH_CHECK(mla_tensor_params.size() == 2,
+            "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");
         if (is_context && op.mUseSparseAttention)
         {
             if (latent_cache.has_value())
@@ -227,10 +227,15 @@ class Runner : public RunnerBase

             // For generation, helix position is in ropeOp
             auto& mla_helix_position_offsets = mla_tensor_params[0];
+            auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
             if (mla_helix_position_offsets.has_value())
             {
                 mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
             }
+            if (mla_helix_is_inactive_rank.has_value())
+            {
+                mla_params.helix_is_inactive_rank = mla_helix_is_inactive_rank->data_ptr<bool>();
+            }
         }
         else
         {
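With this check, any caller of the attention op must now pass a two-entry `mla_tensor_params` list; either entry may be absent, in which case the corresponding pointer stays `nullptr`. A small illustrative sketch of the new calling contract (tensor shapes and values are placeholders):

```python
import torch

# Placeholder tensors for illustration only.
helix_position_offsets = torch.zeros(8, dtype=torch.int32)   # shape [num_tokens]
helix_is_inactive_rank = torch.zeros(2, dtype=torch.bool)    # shape [batch_size]

mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]  # new 2-entry contract
# mla_tensor_params = [helix_position_offsets]  # old 1-entry contract, now rejected by the TORCH_CHECK
```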

cpp/tensorrt_llm/thop/dsv3RopeOp.cpp

Lines changed: 7 additions & 2 deletions
@@ -66,6 +66,7 @@ struct MlaRopeGenArgs
     float const* kv_scale_quant_orig_ptr;
     float host_bmm1_scale;
     int32_t const* helix_position_offsets_ptr;
+    bool const* helix_is_inactive_rank_ptr;
 };

 template <typename T, typename KVCacheBuffer>
@@ -105,6 +106,7 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
     mla_params.dequant_scale_kv = args.kv_scale_quant_orig_ptr;
     mla_params.host_bmm1_scale = args.host_bmm1_scale;
     mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
+    mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;

     tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
 }
@@ -134,7 +136,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim");
     TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1");
     TORCH_CHECK(
-        mla_tensor_params.size() == 1, "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
+        mla_tensor_params.size() == 2, "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");

     auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device());
     auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode));
@@ -153,6 +155,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
     int32_t const num_gen_tokens = num_tokens;
     int32_t const seq_offset = num_contexts;
     auto& mla_helix_position_offsets = mla_tensor_params[0];
+    auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
     int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);

     tk::MlaMetaParams mla_meta_params = {static_cast<int>(q_lora_rank), static_cast<int>(kv_lora_rank),
@@ -161,6 +164,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +

     int32_t const* helix_position_offsets_ptr
         = mla_helix_position_offsets.has_value() ? mla_helix_position_offsets->data_ptr<int32_t>() : nullptr;
+    bool const* helix_is_inactive_rank_ptr
+        = mla_helix_is_inactive_rank.has_value() ? mla_helix_is_inactive_rank->data_ptr<bool>() : nullptr;

     int* cu_q_seqlens_ptr = reinterpret_cast<int*>(cu_q_seqlens.data_ptr());
     int* cu_kv_seqlens_ptr = reinterpret_cast<int*>(cu_kv_seqlens.data_ptr());
@@ -274,7 +279,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         static_cast<int32_t>(num_heads), mla_meta_params, sequence_lengths_ptr, max_context_q_len,
         block_ids_per_seq_ptr, cache_type, cu_q_seqlens_ptr, cu_kv_seqlens_ptr, fmha_tile_counter_ptr,
         mla_bmm1_scale_ptr, mla_bmm2_scale_ptr, quant_q_buffer_ptr, quant_scale_o_ptr, kv_scale_orig_quant_ptr,
-        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr};
+        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr, helix_is_inactive_rank_ptr};

     auto const input_dtype = fused_q.scalar_type();
     if (input_dtype == torch::kFloat16)
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+["Global warming is the long term rise in Earth temperature caused by greenhouse gases from human activity, burning fossil fuels, and deforestation. It leads to melting ice, rising seas, and extreme weather that threaten ecosystems, wildlife, and people. Urgent global action is "]

examples/llm-api/quickstart_advanced.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,7 @@ def add_llm_args(parser):
                         choices=["auto", "TorchSampler", "TRTLLMSampler"])
     parser.add_argument('--tp_size', type=int, default=1)
     parser.add_argument('--pp_size', type=int, default=1)
+    parser.add_argument('--cp_size', type=int, default=1)
     parser.add_argument('--moe_ep_size', type=int, default=-1)
     parser.add_argument('--moe_tp_size', type=int, default=-1)
     parser.add_argument('--moe_cluster_size', type=int, default=-1)
@@ -259,6 +260,7 @@ def setup_llm(args, **kwargs):
         attention_dp_config=attention_dp_config,
         tensor_parallel_size=args.tp_size,
         pipeline_parallel_size=args.pp_size,
+        context_parallel_size=args.cp_size,
         moe_expert_parallel_size=args.moe_ep_size,
         moe_tensor_parallel_size=args.moe_tp_size,
         moe_cluster_parallel_size=args.moe_cluster_size,
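The new `--cp_size` flag is forwarded as `context_parallel_size` in `setup_llm` above. A hedged usage sketch against the public `tensorrt_llm.LLM` API; the checkpoint path is a placeholder and the exact constructor behavior for CP is an assumption based on this diff:

```python
from tensorrt_llm import LLM

# Roughly what quickstart_advanced.py does with "--tp_size 1 --cp_size 2":
llm = LLM(
    model="/path/to/DeepSeek-V3-Lite/bf16",  # placeholder checkpoint path
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    context_parallel_size=2,                 # new: CP on the generation side
)
```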

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 48 additions & 7 deletions
@@ -187,7 +187,6 @@ def plan(
         q_pe: Optional[torch.Tensor] = None,
         mrope_config: Optional[dict] = None,
         softmax_stats_tensor: Optional[torch.Tensor] = None,
-        helix_position_offsets: Optional[torch.Tensor] = None,
         is_spec_decoding_enabled: bool = False,
         use_spec_decoding: bool = False,
         is_spec_dec_tree: bool = False,
@@ -205,6 +204,8 @@ def plan(
         sparse_attn_offsets: Optional[torch.Tensor] = None,
         sparse_attn_indices_block_size: int = 1,
         sparse_mla_topk: int = 0,
+        helix_position_offsets: Optional[torch.Tensor] = None,
+        helix_is_inactive_rank: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         """
@@ -241,7 +242,6 @@ def plan(
             use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner.
             mrope_config (dict): The dictionary containing the mRope configuration.
             softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
-            helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
             attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
             chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens.
             sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
@@ -250,6 +250,8 @@ def plan(
             sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
             sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
             sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
+            helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
+            helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU.
         """
         self.layer_idx = layer_idx
         self.tokens_per_block = tokens_per_block
@@ -285,14 +287,19 @@ def plan(
             'mrope_position_deltas') if mrope_config is not None else None
         self.block_ids_per_seq = block_ids_per_seq
         self.softmax_stats_tensor = softmax_stats_tensor
-        self.helix_position_offsets = helix_position_offsets
         self.attention_sinks = attention_sinks
         self.sparse_kv_indices = sparse_kv_indices
         self.sparse_kv_offsets = sparse_kv_offsets
         self.sparse_attn_indices = sparse_attn_indices
         self.sparse_attn_offsets = sparse_attn_offsets
         self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
         self.sparse_mla_topk = sparse_mla_topk
+        self.helix_position_offsets = helix_position_offsets
+        self.helix_is_inactive_rank = helix_is_inactive_rank
+        if self.helix_is_inactive_rank is not None and not isinstance(self.helix_is_inactive_rank, torch.Tensor):
+            self.helix_is_inactive_rank = torch.tensor(
+                self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
+
         if max_sequence_length > self.rope_params.max_positions:
             self.rope_params.max_positions = max_sequence_length
             self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
@@ -471,7 +478,7 @@ def run(
             spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
             spec_decoding_tensor_params.append(
                 self.spec_bl_tree_first_sparse_mask_offset_kv)
-        mla_tensor_params = [self.helix_position_offsets]
+        mla_tensor_params = [self.helix_position_offsets, self.helix_is_inactive_rank]

         thop.attention(
             q,
@@ -630,6 +637,13 @@ class TrtllmAttentionMetadata(AttentionMetadata):
     spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
     spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None

+    # Whether the current rank is inactive for helix parallelism.
+    # In helix parallelism, only the active rank appends KV cache for the query token
+    # and attends to the previously cached tokens as well as the query token. Inactive
+    # ranks do not append KV cache for the query token and attend to the previously
+    # cached tokens only.
+    helix_is_inactive_rank: Optional[torch.Tensor] = None
+
     @property
     def max_seq_len(self) -> int:
         """
@@ -838,7 +852,23 @@ def prepare(self) -> None:
         if self.enable_flash_mla:
             self.prepare_flash_mla()
         # number of tokens needed in the kv cache for each sequence after the next pass
-        kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
+        if self.helix_is_inactive_rank is not None and len(
+                self.helix_is_inactive_rank):
+            # If helix is inactive, attend to the previously cached tokens only.
+            # This gets further complicated with multiple requests as each request might
+            # have a different active helix rank.
+            assert cached_token_lens is not None, "cached_token_lens should be set for helix"
+            kv_lens = cached_token_lens
+            helix_is_inactive_rank_cpu = torch.tensor(
+                self.helix_is_inactive_rank,
+                dtype=torch.bool,
+                device='cpu',
+            )
+            active_rank = ~helix_is_inactive_rank_cpu
+            kv_lens[active_rank] += self.seq_lens_kv[active_rank]
+        else:
+            kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv

         # self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
         # the sequence length including the cached tokens and the input tokens.
         self.kv_lens[:self.num_seqs].copy_(
@@ -1435,7 +1465,6 @@ def forward(
             q_pe=q_pe,
             mrope_config=mrope_config,
             softmax_stats_tensor=softmax_stats_tensor,
-            helix_position_offsets=helix_position_offsets,
             is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
             use_spec_decoding=metadata.use_spec_decoding,
             is_spec_dec_tree=metadata.is_spec_dec_tree,
@@ -1458,6 +1487,8 @@ def forward(
             sparse_attn_indices_block_size=sparse_attn_indices_block_size,
             sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
                 metadata, 'sparse_mla_topk') else 0,
+            helix_position_offsets=helix_position_offsets,
+            helix_is_inactive_rank=metadata.helix_is_inactive_rank,
         )
         out_dtype = None
         if out_scale is not None:
@@ -1717,6 +1748,7 @@ def mla_rope_generation(
         mla_bmm2_scale: torch.Tensor,
         quant_q_buffer: torch.Tensor,
         helix_position_offsets: Optional[torch.Tensor] = None,
+        helix_is_inactive_rank: Optional[torch.Tensor] = None,
         out_scale: Optional[torch.Tensor] = None,
     ) -> None:
         """
@@ -1736,7 +1768,16 @@ def mla_rope_generation(
         assert self.is_mla_enable and self.mla_params is not None
         assert metadata.kv_cache_manager is not None
         sink_token_length = 0
-        mla_tensor_params = [helix_position_offsets]
+
+        # Ensure helix_is_inactive_rank is on the same device as other tensors
+        if helix_is_inactive_rank is not None:
+            if isinstance(helix_is_inactive_rank, list):
+                helix_is_inactive_rank = torch.tensor(
+                    helix_is_inactive_rank, dtype=torch.bool, device=helix_position_offsets.device)
+            elif helix_is_inactive_rank.device.type != 'cuda':
+                helix_is_inactive_rank = helix_is_inactive_rank.to(helix_position_offsets.device)
+
+        mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]

         torch.ops.trtllm.mla_rope_generation(
             fused_q,
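The `prepare()` change above computes per-request KV lengths differently under helix: inactive requests keep their cache-only length, while active requests also count the incoming query tokens. A self-contained numeric example of that logic (the values are made up):

```python
import torch

cached_token_lens = torch.tensor([96, 96])             # tokens already cached per request
seq_lens_kv = torch.tensor([1, 1])                      # new query tokens this step
helix_is_inactive_rank = torch.tensor([True, False])    # request 0: this rank is inactive

kv_lens = cached_token_lens.clone()
active = ~helix_is_inactive_rank
kv_lens[active] += seq_lens_kv[active]

print(kv_lens.tolist())  # [96, 97] -> only the active request grows its KV length
```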

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 23 additions & 1 deletion
@@ -4,6 +4,7 @@
 from functools import wraps
 from typing import Optional

+import copy
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -341,9 +342,30 @@ class MPIDist(Distributed):

     def __init__(self, mapping: Mapping):
         super().__init__(mapping)
+        self.create_cp_comm()
+        # Repurpose CP ranks to TP for Helix so that the right comms are created.
+        mapping_with_helix = None
+        if self.mapping.cp_size > 1:
+            print(f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
+            mapping_with_helix = copy.deepcopy(self.mapping)
+            mapping_without_helix = Mapping(
+                world_size=self.mapping.world_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                cp_size=1,
+                cp_config={},
+                tp_size=self.mapping.tp_size * self.mapping.cp_size,
+                pp_size=self.mapping.pp_size,
+                moe_ep_size=self.mapping.moe_ep_size,
+                enable_attention_dp=self.mapping.enable_attention_dp)
+            self.mapping = mapping_without_helix
         self.create_tp_comm()
         self.create_pp_comm()
-        self.create_cp_comm()
+
+        # Restore the original mapping.
+        if mapping_with_helix is not None:
+            print(f"[MPIDist::__init__] Restoring original mapping.")
+            self.mapping = mapping_with_helix

     def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = mpi_comm()
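The `MPIDist.__init__` change temporarily replaces the mapping so that CP ranks are folded into a widened TP group while the TP/PP communicators are created, then restores the original Helix mapping. The same pattern expressed as a context manager, as a sketch only; the `Mapping` import path and constructor arguments follow the diff above but are otherwise assumptions:

```python
from contextlib import contextmanager

from tensorrt_llm.mapping import Mapping  # assumed import path

@contextmanager
def cp_folded_into_tp(dist):
    """Temporarily present CP ranks as extra TP ranks while creating comms."""
    original = dist.mapping
    if original.cp_size > 1:
        dist.mapping = Mapping(
            world_size=original.world_size,
            rank=original.rank,
            gpus_per_node=original.gpus_per_node,
            cp_size=1,
            cp_config={},
            tp_size=original.tp_size * original.cp_size,  # CP ranks join the TP group
            pp_size=original.pp_size,
            moe_ep_size=original.moe_ep_size,
            enable_attention_dp=original.enable_attention_dp)
    try:
        yield dist.mapping
    finally:
        dist.mapping = original  # restore the Helix mapping

# Conceptual usage inside __init__:
#     with cp_folded_into_tp(self):
#         self.create_tp_comm()
#         self.create_pp_comm()
```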
