Commit 6e5fe5d (parent 079abdb)

updates to dsv3RopeOp
4 files changed: +18 −25 lines

cpp/tensorrt_llm/thop/dsv3RopeOp.cpp

Lines changed: 7 additions & 2 deletions

@@ -66,6 +66,7 @@ struct MlaRopeGenArgs
     float const* kv_scale_quant_orig_ptr;
     float host_bmm1_scale;
     int32_t const* helix_position_offsets_ptr;
+    bool const* helix_is_inactive_rank_ptr;
 };

 template <typename T, typename KVCacheBuffer>
@@ -105,6 +106,7 @@ void invokeMLARopeGenerationHelper(T const* latent_cache_ptr, T* q_pe_ptr, T* fu
     mla_params.dequant_scale_kv = args.kv_scale_quant_orig_ptr;
     mla_params.host_bmm1_scale = args.host_bmm1_scale;
     mla_params.helix_position_offsets = args.helix_position_offsets_ptr;
+    mla_params.helix_is_inactive_rank = args.helix_is_inactive_rank_ptr;

     tk::invokeMLARopeGeneration<T>(mla_params, kv_cache_buffer, stream);
 }
@@ -134,7 +136,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         head_size == kv_lora_rank + qk_rope_head_dim, "head_size must = kv_lora_rank + qk_rope_head_dim");
     TLLM_CHECK_WITH_INFO(num_kv_heads == 1, "num_kv_heads must = 1");
     TORCH_CHECK(
-        mla_tensor_params.size() == 1, "Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
+        mla_tensor_params.size() == 2, "Expecting 2 tensors for custom MLA tensor params: helix_position_offsets and helix_is_inactive_rank.");

     auto stream = at::cuda::getCurrentCUDAStream(fused_q.get_device());
     auto const kv_cache_quant_mode = tc::QuantMode(uint32_t(quant_mode));
@@ -153,6 +155,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
     int32_t const num_gen_tokens = num_tokens;
     int32_t const seq_offset = num_contexts;
     auto& mla_helix_position_offsets = mla_tensor_params[0];
+    auto& mla_helix_is_inactive_rank = mla_tensor_params[1];
     int32_t const layer_num = host_kv_cache_pool_mapping.value().size(0);

     tk::MlaMetaParams mla_meta_params = {static_cast<int>(q_lora_rank), static_cast<int>(kv_lora_rank),
@@ -161,6 +164,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +

     int32_t const* helix_position_offsets_ptr
         = mla_helix_position_offsets.has_value() ? mla_helix_position_offsets->data_ptr<int32_t>() : nullptr;
+    bool const* helix_is_inactive_rank_ptr
+        = mla_helix_is_inactive_rank.has_value() ? mla_helix_is_inactive_rank->data_ptr<bool>() : nullptr;

     int* cu_q_seqlens_ptr = reinterpret_cast<int*>(cu_q_seqlens.data_ptr());
     int* cu_kv_seqlens_ptr = reinterpret_cast<int*>(cu_kv_seqlens.data_ptr());
@@ -274,7 +279,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
         static_cast<int32_t>(num_heads), mla_meta_params, sequence_lengths_ptr, max_context_q_len,
         block_ids_per_seq_ptr, cache_type, cu_q_seqlens_ptr, cu_kv_seqlens_ptr, fmha_tile_counter_ptr,
         mla_bmm1_scale_ptr, mla_bmm2_scale_ptr, quant_q_buffer_ptr, quant_scale_o_ptr, kv_scale_orig_quant_ptr,
-        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr};
+        kv_scale_quant_orig_ptr, host_bmm1_scale, helix_position_offsets_ptr, helix_is_inactive_rank_ptr};

     auto const input_dtype = fused_q.scalar_type();
     if (input_dtype == torch::kFloat16)
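
With the updated TORCH_CHECK, callers must now pass exactly two optional tensors in mla_tensor_params; a missing entry is mapped to a nullptr on the C++ side. A minimal caller-side sketch under those assumptions follows; the tensor shapes, the device, and the op handle are illustrative placeholders and are not taken from this commit.

import torch

num_tokens = 8  # illustrative size only

# Hypothetical construction of the two entries expected by the updated check.
# Either entry may be None; the C++ side then receives a nullptr for it.
helix_position_offsets = torch.zeros(num_tokens, dtype=torch.int32, device="cuda")
helix_is_inactive_rank = torch.zeros(num_tokens, dtype=torch.bool, device="cuda")

mla_tensor_params = [helix_position_offsets, helix_is_inactive_rank]

# Placeholder invocation; the registered op name is not shown in this diff.
# torch.ops.trtllm.mla_rope_generation(fused_q, ..., mla_tensor_params)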

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 0 additions & 2 deletions

@@ -70,8 +70,6 @@
 from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
                              register_auto_model)

-# from ..utils import use_torch_printoptions
-

 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 6 additions & 11 deletions

@@ -14,7 +14,6 @@
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
-from ..utils import use_torch_printoptions
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)

@@ -694,16 +693,12 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                     input_ids_this_rank = input_ids_this_rank[:-padding_len]
                     position_ids_this_rank = position_ids_this_rank[:-padding_len]

-                with use_torch_printoptions(sci_mode=False,
-                                            threshold=16,
-                                            edgeitems=2,
-                                            linewidth=120):
-                    print(
-                        f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: input_ids_this_rank: {torch.tensor(input_ids_this_rank)}"
-                    )
-                    print(
-                        f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: position_ids_this_rank: {torch.tensor(position_ids_this_rank)}"
-                    )
+                print(
+                    f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: input_ids_this_rank: {torch.tensor(input_ids_this_rank)}"
+                )
+                print(
+                    f"[ExecutorRequestQueue::_merge_helix_requests][{curr_cp_rank}]: position_ids_this_rank: {torch.tensor(position_ids_this_rank)}"
+                )
                 # TODO: Figure how to pass down position_ids_this_rank to LLMRequest.
                 req = executor_request_to_llm_request(
                     req_id=req_item.id,

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 10 deletions

@@ -42,7 +42,6 @@
 from ..speculative.drafter import Drafter
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
-from ..utils import use_torch_printoptions
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
@@ -1898,15 +1897,11 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
             beam_width = req.sampling_config.beam_width

-            with use_torch_printoptions(sci_mode=False,
-                                        threshold=16,
-                                        edgeitems=2,
-                                        linewidth=120):
-                for beam in range(0, beam_width):
-                    print(
-                        f"[PyExecutor::_prepare_disagg_gen_transmission_complete]: Adding new token {torch.tensor(first_gen_tokens[beam])} for beam {beam}."
-                    )
-                    req.add_new_token(first_gen_tokens[beam], beam)
+            for beam in range(0, beam_width):
+                print(
+                    f"[PyExecutor::_prepare_disagg_gen_transmission_complete]: Adding new token {torch.tensor(first_gen_tokens[beam])} for beam {beam}."
+                )
+                req.add_new_token(first_gen_tokens[beam], beam)

     @nvtx_range("_recv_disagg_gen_cache")
     def _recv_disagg_gen_cache(self, new_gen_reqs):
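
Both Python files above drop the use_torch_printoptions wrapper and its import from ..utils, leaving the debug prints unwrapped. For context, here is a minimal sketch of what such a helper could look like, assuming it simply wrapped torch.set_printoptions and restored the default profile on exit; the removed helper itself is not shown in this diff, so this reconstruction is an assumption.

from contextlib import contextmanager

import torch


@contextmanager
def use_torch_printoptions(**kwargs):
    # Assumed behaviour: apply the requested tensor print options
    # (e.g. sci_mode=False, threshold=16, edgeitems=2, linewidth=120)
    # for the duration of the block, then restore torch's defaults.
    torch.set_printoptions(**kwargs)
    try:
        yield
    finally:
        torch.set_printoptions(profile="default")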
