Commit 2b1c9dd

formatting

1 parent b096b89 commit 2b1c9dd

12 files changed: +71 -56 lines changed

cpp/tensorrt_llm/kernels/mlaKernels.cu
Lines changed: 2 additions & 0 deletions

@@ -425,6 +425,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,

     if (valid_token)
     {
+
         auto const position_id
             = (helix_position_offsets != nullptr ? helix_position_offsets[global_token_idx]
                                                  : kv_cache_lengths[batch_idx] - seq_len + local_token_idx);
@@ -463,6 +464,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     if (head_idx == head_num && (helix_is_inactive_rank == nullptr || !helix_is_inactive_rank[batch_idx]))
     {
         auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
+
         {
             auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
             auto inBlockIdx = kv_cache.getKVLocalIdx(

cpp/tensorrt_llm/kernels/mlaKernels.h
Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ struct MlaParams

     // for Helix parallelism: the rotary position offsets [b]
     int32_t const* helix_position_offsets{nullptr};
+
     // for Helix parallelism: whether the current rank is inactive, shape [b]
     // (the current query tokens are not appended to this rank's KV cache)
     bool const* helix_is_inactive_rank{nullptr};

cpp/tensorrt_llm/thop/dsv3RopeOp.cpp
Lines changed: 2 additions & 2 deletions

@@ -135,8 +135,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
     TLLM_CHECK_WITH_INFO(
         head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim");
     TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1");
-    TORCH_CHECK(
-        mla_tensor_params.size() == 2, "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");
+    TORCH_CHECK(mla_tensor_params.size() == 2,
+        "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");

     auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device());
     auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode));

tensorrt_llm/_torch/attention_backend/trtllm.py
Lines changed: 13 additions & 9 deletions

@@ -296,7 +296,8 @@ def plan(
         self.sparse_mla_topk = sparse_mla_topk
         self.helix_position_offsets = helix_position_offsets
         self.helix_is_inactive_rank = helix_is_inactive_rank
-        if self.helix_is_inactive_rank is not None and not isinstance(self.helix_is_inactive_rank, torch.Tensor):
+        if self.helix_is_inactive_rank is not None and not isinstance(
+                self.helix_is_inactive_rank, torch.Tensor):
             self.helix_is_inactive_rank = torch.tensor(
                 self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)

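For context on the pin_memory=True conversion above: keeping the per-request flags in pinned (page-locked) host memory is what lets a later host-to-device copy run asynchronously. A minimal standalone sketch of that pattern follows; the helper name and example values are hypothetical, not part of this commit.

import torch

def to_pinned_bool_tensor(flags):
    # Convert a Python list of per-request flags into a bool tensor, pinned
    # when CUDA is available so a later .to('cuda', non_blocking=True) copy
    # can overlap with other work on the stream. Illustrative sketch only.
    if flags is None or isinstance(flags, torch.Tensor):
        return flags
    return torch.tensor(flags, dtype=torch.bool,
                        pin_memory=torch.cuda.is_available())

# Hypothetical usage with three requests, the second on an inactive rank:
helix_is_inactive_rank = to_pinned_bool_tensor([False, True, False])
if torch.cuda.is_available():
    on_gpu = helix_is_inactive_rank.to("cuda", non_blocking=True)
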
@@ -478,7 +479,9 @@ def run(
         spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
         spec_decoding_tensor_params.append(
             self.spec_bl_tree_first_sparse_mask_offset_kv)
-        mla_tensor_params = [self.helix_position_offsets, self.helix_is_inactive_rank]
+        mla_tensor_params = [
+            self.helix_position_offsets, self.helix_is_inactive_rank
+        ]

         thop.attention(
             q,
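The two-element list built above mirrors the TORCH_CHECK added in cpp/tensorrt_llm/thop/dsv3RopeOp.cpp earlier in this commit: the op expects exactly helix_position_offsets followed by helix_is_inactive_rank, in that order. The small sketch below only illustrates that convention; the placeholder shapes and the assumption that entries may be None outside Helix runs are illustrative, not verified against the op.

import torch

batch_size = 4  # hypothetical
helix_position_offsets = torch.zeros(batch_size, dtype=torch.int32)  # shape [b]
helix_is_inactive_rank = torch.zeros(batch_size, dtype=torch.bool)   # shape [b]

# Index 0 -> position offsets, index 1 -> inactive-rank flags.
mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]
assert len(mla_tensor_params) == 2
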
@@ -855,10 +858,8 @@ def prepare(self) -> None:
         if self.helix_is_inactive_rank is not None and len(
                 self.helix_is_inactive_rank):
             # If helix is inactive, attend to the previously cached tokens only.
-            # This gets further complicated with multiple requests as each request might
-            # have a different active helix rank.
             assert cached_token_lens is not None, "cached_token_lens should be set for helix"
-            kv_lens = cached_token_lens
+            kv_lens = cached_token_lens.clone()
             helix_is_inactive_rank_cpu = torch.tensor(
                 self.helix_is_inactive_rank,
                 dtype=torch.bool,
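The switch from kv_lens = cached_token_lens to .clone() above matters because kv_lens is subsequently adjusted per request while cached_token_lens is shared metadata; without the copy, those adjustments would leak back into the shared tensor. A minimal illustration with hypothetical lengths and a placeholder adjustment:

import torch

# Hypothetical per-request metadata (not the commit's values).
cached_token_lens = torch.tensor([128, 256, 64], dtype=torch.int32)
helix_is_inactive_rank = torch.tensor([False, True, False])

kv_lens = cached_token_lens.clone()   # independent buffer, safe to edit in place
kv_lens[helix_is_inactive_rank] -= 1  # placeholder for the per-rank adjustment

assert kv_lens.data_ptr() != cached_token_lens.data_ptr()
assert cached_token_lens[1].item() == 256  # shared metadata left untouched
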
@@ -1768,15 +1769,18 @@ def mla_rope_generation(
         assert self.is_mla_enable and self.mla_params is not None
         assert metadata.kv_cache_manager is not None
         sink_token_length = 0
-
+
         # Ensure helix_is_inactive_rank is on the same device as other tensors
         if helix_is_inactive_rank is not None:
             if isinstance(helix_is_inactive_rank, list):
                 helix_is_inactive_rank = torch.tensor(
-                    helix_is_inactive_rank, dtype=torch.bool, device=helix_position_offsets.device)
+                    helix_is_inactive_rank,
+                    dtype=torch.bool,
+                    device=helix_position_offsets.device)
             elif helix_is_inactive_rank.device.type != 'cuda':
-                helix_is_inactive_rank = helix_is_inactive_rank.to(helix_position_offsets.device)
-
+                helix_is_inactive_rank = helix_is_inactive_rank.to(
+                    helix_position_offsets.device)
+
         mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]

         torch.ops.trtllm.mla_rope_generation(

tensorrt_llm/_torch/distributed/communicator.py
Lines changed: 6 additions & 3 deletions

@@ -1,10 +1,10 @@
+import copy
 import math
 import pickle  # nosec B403
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Optional

-import copy
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -346,7 +346,8 @@ def __init__(self, mapping: Mapping):
         # Repurpose CP ranks to TP for Helix so that the right comms are created.
         mapping_with_helix = None
         if self.mapping.cp_size > 1:
-            print(f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
+            logger.info(
+                f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
             mapping_with_helix = copy.deepcopy(self.mapping)
             mapping_without_helix = Mapping(
                 world_size=self.mapping.world_size,
@@ -364,7 +365,9 @@ def __init__(self, mapping: Mapping):

         # Restore the original mapping.
         if mapping_with_helix is not None:
-            print(f"[MPIDist::__init__] Restoring original mapping.")
+            logger.info(
+                f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
+            )
             self.mapping = mapping_with_helix

     def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
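Both __init__ hunks above follow a save-modify-restore pattern: keep an untouched deepcopy of the mapping, temporarily fold the CP ranks into TP so the communicators are created against the flattened layout, then put the original mapping back. A standalone sketch of that pattern, with a plain dict standing in for Mapping and illustrative field names (not the MPIDist implementation):

import copy

class HelixCommsSketch:
    # Minimal stand-in for the save/modify/restore pattern used in
    # MPIDist.__init__ above; a plain dict replaces Mapping.

    def __init__(self, mapping):
        self.mapping = mapping
        saved = None
        if self.mapping.get("cp_size", 1) > 1:
            # Keep an untouched copy, then fold CP ranks into TP so the
            # communicators are created against the flattened layout.
            saved = copy.deepcopy(self.mapping)
            self.mapping = {
                **self.mapping,
                "tp_size": self.mapping["tp_size"] * self.mapping["cp_size"],
                "cp_size": 1,
            }
        self._build_comms(self.mapping)
        if saved is not None:
            # Restore the original mapping once the comms exist.
            self.mapping = saved

    def _build_comms(self, mapping):
        # Placeholder for communicator creation.
        pass

comm = HelixCommsSketch({"tp_size": 2, "cp_size": 4})
assert comm.mapping["cp_size"] == 4  # original mapping restored
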

tensorrt_llm/_torch/models/modeling_deepseekv3.py
Lines changed: 0 additions & 2 deletions

@@ -545,8 +545,6 @@ def __init__(
             config=model_config,
             aux_stream=aux_stream,
             mapping_with_cp=mapping_with_cp)
-        # @B: Does this layer need to know about mapping_with_cp?
-        # Likely no because no use of mapping.
         self.kv_a_proj_with_mqa = DeepseekV3Linear(
             config.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim +

tensorrt_llm/_torch/modules/attention.py
Lines changed: 30 additions & 30 deletions

@@ -750,7 +750,8 @@ def __init__(
         # tensor parallel
         config = config or ModelConfig()
        if mapping_with_cp is not None:
-            print("[MLA::__init__] OVERRIDING MAPPING WITH CP DETECTED.")
+            logger.warning(
+                "[MLA::__init__] Overriding mapping with CP detected.")
             self.mapping = mapping_with_cp
         else:
             self.mapping = config.mapping
@@ -762,7 +763,8 @@ def __init__(
         if self.mapping.has_cp_ulysses():
             raise NotImplementedError("MLA doesn't support CP Ulyssees yet")
         if self.mapping.cp_size > 1:
-            assert self.mapping.cp_config['cp_type'] == CpType.HELIX, f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."
+            assert self.mapping.cp_config[
+                'cp_type'] == CpType.HELIX, f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}."

         mapping = Mapping(
             world_size=tp_size * pp_size * cp_size,
@@ -1727,20 +1729,19 @@ def forward_absorption_generation(
             maybe_execute_in_parallel(
                 lambda: torch.ops.trtllm.bmm_out(
                     q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out),
-                lambda: self.mqa.mla_rope_generation(fused_q,
-                                                     q_pe,
-                                                     latent_cache,
-                                                     attn_metadata,
-                                                     cu_q_seqlens,
-                                                     cu_kv_seqlens,
-                                                     fmha_scheduler_counter,
-                                                     mla_bmm1_scale,
-                                                     mla_bmm2_scale,
-                                                     quant_q_buffer,
-                                                     helix_position_offsets=
-                                                     helix_position_offsets,
-                                                     helix_is_inactive_rank=
-                                                     helix_is_inactive_rank),
+                lambda: self.mqa.mla_rope_generation(
+                    fused_q,
+                    q_pe,
+                    latent_cache,
+                    attn_metadata,
+                    cu_q_seqlens,
+                    cu_kv_seqlens,
+                    fmha_scheduler_counter,
+                    mla_bmm1_scale,
+                    mla_bmm2_scale,
+                    quant_q_buffer,
+                    helix_position_offsets=helix_position_offsets,
+                    helix_is_inactive_rank=helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
@@ -1758,20 +1759,19 @@ def forward_absorption_generation(
                     q_nope_out,
                     self.k_b_proj_trans_dequant,
                 ),
-                lambda: self.mqa.mla_rope_generation(fused_q,
-                                                     q_pe,
-                                                     latent_cache,
-                                                     attn_metadata,
-                                                     cu_q_seqlens,
-                                                     cu_kv_seqlens,
-                                                     fmha_scheduler_counter,
-                                                     mla_bmm1_scale,
-                                                     mla_bmm2_scale,
-                                                     quant_q_buffer,
-                                                     helix_position_offsets=
-                                                     helix_position_offsets,
-                                                     helix_is_inactive_rank=
-                                                     helix_is_inactive_rank),
+                lambda: self.mqa.mla_rope_generation(
+                    fused_q,
+                    q_pe,
+                    latent_cache,
+                    attn_metadata,
+                    cu_q_seqlens,
+                    cu_kv_seqlens,
+                    fmha_scheduler_counter,
+                    mla_bmm1_scale,
+                    mla_bmm2_scale,
+                    quant_q_buffer,
+                    helix_position_offsets=helix_position_offsets,
+                    helix_is_inactive_rank=helix_is_inactive_rank),
                 self.ln_events[0],
                 self.ln_events[1],
                 rope_stream,
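In both reformatted call sites the RoPE work is wrapped in a lambda so maybe_execute_in_parallel can launch it alongside the BMM, synchronized through the two ln_events and the rope_stream. The helper's exact signature is internal to TensorRT-LLM, so the sketch below only illustrates the general lambda-on-a-side-stream pattern with standard PyTorch stream and event APIs; run_in_parallel and the example tensors are hypothetical.

import torch

def run_in_parallel(fn_main, fn_side, side_stream, done_event):
    # Hedged sketch of the lambda-on-a-side-stream pattern; the real
    # maybe_execute_in_parallel helper's signature and event handling may differ.
    current = torch.cuda.current_stream()
    side_stream.wait_stream(current)      # side work starts after queued work
    with torch.cuda.stream(side_stream):
        fn_side()                         # e.g. the mla_rope_generation lambda
        done_event.record(side_stream)
    fn_main()                             # e.g. the bmm_out lambda
    current.wait_event(done_event)        # rejoin before consuming side results

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    run_in_parallel(lambda: a.matmul(a), lambda: b.mul_(2.0), stream, event)
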

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
Lines changed: 1 addition & 2 deletions

@@ -710,8 +710,7 @@ def _merge_requests(
         elif cp_type == CpType.HELIX:
             return self._merge_helix_requests(
                 new_requests,
-                tokens_per_block=32)
-                # tokens_per_block=cp_config['tokens_per_block'])
+                tokens_per_block=cp_config['tokens_per_block'])
         else:
             raise NotImplementedError(
                 f'Unsupported cp type {cp_type.name}.')
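This hunk replaces the hard-coded tokens_per_block=32 with the value carried in cp_config, which the serve.py change later in this commit populates from the KV-cache configuration when cp_type is HELIX. A hedged sketch of that flow; the KvCacheConfig stand-in and build_cp_config helper below are illustrative only.

from dataclasses import dataclass

@dataclass
class KvCacheConfig:
    # Stand-in for the real KV-cache config; only the field used here.
    tokens_per_block: int = 32

def build_cp_config(cp_type: str, kv_cache_config: KvCacheConfig) -> dict:
    # Mirrors the serve.py hunk below: forward the block size for HELIX so
    # the executor no longer needs a hard-coded value.
    cp_config = {"cp_type": cp_type}
    if cp_type == "HELIX":
        cp_config["tokens_per_block"] = kv_cache_config.tokens_per_block
    return cp_config

cp_config = build_cp_config("HELIX", KvCacheConfig(tokens_per_block=64))
assert cp_config["tokens_per_block"] == 64
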

tensorrt_llm/_torch/pyexecutor/model_engine.py
Lines changed: 5 additions & 2 deletions

@@ -9,7 +9,6 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from .llm_request import LlmRequest

 import torch
 import torch._dynamo.config
@@ -566,6 +565,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         cp_type = self.mapping.cp_config.get('cp_type', None)
         if cp_type is not None:
             if cp_type in [CpType.ULYSSES, CpType.STAR]:
+                logger.info(
+                    "[ModelEngine::warmup] Skipping warmup for cp_type: ",
+                    cp_type.name)
                 return

         self._run_torch_compile_warmup(resource_manager)
@@ -1620,7 +1622,8 @@ def _prepare_tp_inputs(
             request.cached_tokens = num_cached_tokens_per_seq[-1]
             prompt_lengths.append(request.py_prompt_len)
             if self.mapping.has_cp_helix():
-                helix_is_inactive_rank.append(request.py_helix_is_inactive_rank)
+                helix_is_inactive_rank.append(
+                    request.py_helix_is_inactive_rank)
             draft_lens.append(0)
             sequence_lengths.append(1)
             num_accepted_draft_tokens.append(0)

tensorrt_llm/commands/serve.py
Lines changed: 5 additions & 3 deletions

@@ -2,7 +2,6 @@
 import gc
 import json
 import os
-import gc
 import signal  # Added import
 import subprocess  # nosec B404
 import sys
@@ -131,6 +130,9 @@ def get_llm_args(
     except KeyError:
         raise ValueError(f"Invalid cp_type: {cp_config['cp_type']}. " \
                          f"Must be one of: {', '.join([t.name for t in CpType])}")
+    if cp_config["cp_type"] == CpType.HELIX:
+        cp_config['tokens_per_block'] = kv_cache_config.tokens_per_block
+
     llm_args = {
         "model": model,
         "scheduler_config": scheduler_config,
@@ -386,8 +388,8 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
 def serve(
     model: str, tokenizer: Optional[str], host: str, port: int,
     log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
-    max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int, cp_size: int,
-    ep_size: Optional[int], cluster_size: Optional[int],
+    max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
+    cp_size: int, ep_size: Optional[int], cluster_size: Optional[int],
     gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
     num_postprocess_workers: int, trust_remote_code: bool,
     extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
