diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index b9b1310cbe9..8664c53fffa 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -406,20 +406,20 @@ def forward( - KV cache: (ABCD) + EFGH (H's KV cache is invalid) - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) Draft model: - MPT1: + MTP1: # For generation request, `mtp_num_modules` of tokens will be used as input. - input tokens: FGX - input hidden states: H_E, H_F, H_G - KV cache: (BCDE) + FGX - output hidden states: h_F, h_G, h_X - output next draft token: N - MPT2: + MTP2: - input tokens: GXN - input hidden states: H_F, H_G, h_X - KV cache: (CDEF) + GXN - output hidden states: h_G, h_X, h_N - output next draft token: O - MPT3: + MTP3: - input tokens: XNO - input hidden states: H_G, h_X, h_N - KV cache: (DEFG) + XNO