support mtp stage 2: support deepseek mtp=1 #35
base: main
Conversation
Pull request overview
This PR implements stage 1 support for multi-token prediction (MTP) by enabling the model engine to handle multiple decode tokens per step. The changes modify the core engine to track and process a variable number of tokens per sequence rather than assuming exactly one token per decode step.
Key changes:
- Modified the postprocessing pipeline to return multiple tokens per sequence (`dict[int, list[int]]` instead of `dict[int, int]`; see the sketch after this list)
- Enhanced block allocation to reserve space for lookahead slots when in deferred output mode
- Updated position and context length calculations to account for previously accepted tokens in deferred output mode
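A minimal sketch of the return-type change described in the first bullet, using illustrative names and values rather than the repository's actual signatures: postprocessing now maps each sequence id to the list of tokens accepted this step instead of a single token.

```python
# Hedged sketch of the dict[int, int] -> dict[int, list[int]] change; values are illustrative.
old_style: dict[int, int] = {0: 42, 1: 7}                # exactly one token per sequence
new_style: dict[int, list[int]] = {0: [42, 43], 1: [7]}  # one or more tokens per sequence

for seq_id, token_ids in new_style.items():
    # Callers now iterate over the accepted tokens and typically inspect only
    # the last one for end-of-sequence handling.
    last_token = token_ids[-1]
    print(seq_id, token_ids, last_token)
```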
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Summary per file:
| File | Description |
|---|---|
| atom/model_engine/scheduler.py | Updated postprocess to handle multiple tokens per sequence, modified block allocation to reserve lookahead slots, and removed placeholder token logic |
| atom/model_engine/model_runner.py | Changed return types from single token to token lists, added tracking of accepted token counts per sequence |
| atom/model_engine/block_manager.py | Added num_lookahead_slots parameter to allocation methods to reserve space for multiple tokens |
| atom/model_ops/attentions/backends.py | Updated block table preparation to use actual context lengths when in deferred output mode |
| atom/model_ops/attentions/aiter_mla.py | Added logic to calculate actual context lengths and positions based on deferred tokens from previous batch |
| atom/model_ops/attentions/aiter_attention.py | Similar context length calculation updates for standard attention mechanism |
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 7 comments.
```python
)
]
if max_q_len > 1:
    context_lens = np.array(batch.context_lens) - var["mtp_k"]
```
Copilot (AI) · Dec 12, 2025
The computation 'batch.context_lens - var["mtp_k"]' is performed but the result overwrites context_lens which is also used in the else block at line 195. This could be confusing. Consider using a more descriptive variable name like 'adjusted_context_lens' for clarity.
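A minimal, self-contained sketch of the suggested rename (example values, not the repository's code), which keeps the unshifted lengths available for the later else branch:

```python
import numpy as np

# Hedged sketch: give the shifted values their own name instead of overwriting
# context_lens, which a later branch still reads. Values are illustrative.
context_lens = np.array([10, 17, 32])  # per-sequence context lengths
mtp_k = 1                              # number of speculative tokens
adjusted_context_lens = context_lens - mtp_k
print(adjusted_context_lens)           # [ 9 16 31]
```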
atom/model_engine/scheduler.py (Outdated)
```python
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)
```
Copilot (AI) · Dec 12, 2025
The nested ternary expression for num_placeholder is difficult to read. Consider using an if-elif-else structure or extracting this logic into a separate method with a descriptive name for better maintainability.
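One possible shape of the extracted helper, as a hedged sketch that reuses the PR's flag names but is not the actual scheduler code:

```python
# Hedged sketch: one branch per case, so each placeholder count is explicit.
def compute_num_placeholder(is_deferred_out: bool, use_spec: bool, mtp_k: int) -> int:
    if is_deferred_out and use_spec:
        return 2 * mtp_k  # deferred output combined with speculative tokens
    if is_deferred_out:
        return 1          # deferred output only
    if use_spec:
        return mtp_k      # speculative tokens only
    return 0              # neither feature active

assert compute_num_placeholder(True, True, 2) == 4
assert compute_num_placeholder(True, False, 2) == 1
assert compute_num_placeholder(False, True, 2) == 2
assert compute_num_placeholder(False, False, 2) == 0
```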
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 8 comments.
```python
    runner=None,
):
    self.config = atom_config
    self.speculative_config = self.config.speculative_config
```
Copilot (AI) · Dec 22, 2025
The variable mtp_k is renamed but its meaning remains ambiguous. Based on the context, this represents the number of speculative tokens, not k in the traditional MTP sense. Consider renaming to num_speculative_tokens for clarity, or add a comment explaining the naming convention.
Suggested change:
```diff
 self.speculative_config = self.config.speculative_config
+# Number of speculative tokens used in MTP; kept as `mtp_k` for backward compatibility.
```
```python
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

for _ in range(self.mtp_k - 1):
    pass  # TODO: support multiple num_speculative_tokens

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
```
Copilot (AI) · Dec 22, 2025
The TODO comment is misleading - the function already handles mtp_k speculative tokens via the loop. Either implement the remaining logic or update the comment to clarify what specific functionality is missing.
Suggested change:
```diff
-# Generate the remaining draft tokens.
-draft_token_ids_list = [draft_token_ids]
-for _ in range(self.mtp_k - 1):
-    pass  # TODO: support multiple num_speculative_tokens
-# [batch_size, num_speculative_tokens]
-draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-return draft_token_ids
+# Multiple draft tokens (mtp_k > 1) are not yet supported. Fail fast instead of
+# returning an incorrectly shaped tensor or silently misbehaving.
+raise NotImplementedError(
+    f"Multiple speculative tokens (mtp_k={self.mtp_k}) are not supported in EagleSpecDecoder yet."
+)
```
```python
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
None,
```
Copilot (AI) · Dec 22, 2025
The draft_probs parameter is always passed as None but still has validation logic and is used in assertions. Either implement probability-based rejection sampling or remove this unused parameter and its associated validation code.
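If probability-based rejection is not implemented, the greedy path can be made explicit instead of threading `draft_probs=None` through validation. A self-contained, hedged sketch of greedy acceptance (not the repository's sampler API):

```python
import torch

# Hedged sketch: accept the longest prefix of draft tokens that matches the
# target model's argmax tokens at the same positions; no draft probabilities needed.
def greedy_accept(
    draft_token_ids: torch.Tensor,    # [batch, k] draft tokens
    target_argmax_ids: torch.Tensor,  # [batch, k] target argmax at the same positions
) -> list[list[int]]:
    accepted: list[list[int]] = []
    for drafts, targets in zip(draft_token_ids.tolist(), target_argmax_ids.tolist()):
        row: list[int] = []
        for d, t in zip(drafts, targets):
            if d != t:
                break        # first mismatch rejects this and all later drafts
            row.append(d)
        accepted.append(row)
    return accepted

drafts = torch.tensor([[5, 9], [3, 4]])
targets = torch.tensor([[5, 8], [3, 4]])
print(greedy_accept(drafts, targets))  # [[5], [3, 4]]
```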
```python
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)
```
Copilot (AI) · Dec 22, 2025
The multi-line conditional expression for num_placeholder is difficult to read and error-prone. Consider using a traditional if-elif-else structure or extracting this logic into a helper method with clear comments explaining each case.
Suggested change:
```diff
-num_placeholder = (
-    2 * self.mtp_k if is_deferred_out and self.use_spec else
-    1 if is_deferred_out else
-    self.mtp_k if self.use_spec else
-    0
-)
+# Determine how many placeholder tokens to use based on deferred output
+# and speculative decoding configuration.
+if is_deferred_out and self.use_spec:
+    num_placeholder = 2 * self.mtp_k
+elif is_deferred_out:
+    num_placeholder = 1
+elif self.use_spec:
+    num_placeholder = self.mtp_k
+else:
+    num_placeholder = 0
```
```diff
 else:
-    if not seq.ignore_eos and token_id == self.eos_token_id:
+    # Check the last token in the list for EOS
+    if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:
```
Copilot (AI) · Dec 22, 2025
The EOS check now looks for eos_token_id anywhere in the token_ids list, but the comment on line 236 says 'Check the last token in the list for EOS'. This could cause premature sequence termination if EOS appears in draft tokens. The check should verify only the last accepted token: token_ids[-1] == self.eos_token_id.
Suggested change:
```diff
-if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:
+if token_ids and not seq.ignore_eos and token_ids[-1] == self.eos_token_id:
```
| """for decode: input ids are from prev_sampled_token_ids""" | ||
| locations, is_all_alive = self.get_prev_alive_locations(batch) | ||
| num_deferred_tokens = len(locations) | ||
| alive_seq_indices, num_deferred_tokens, is_all_alive = self.get_prev_alive_locations(batch) |
Copilot (AI) · Dec 22, 2025
The variable alive_seq_indices is misleading - it contains indices into the previous batch, not sequence IDs. Consider renaming to prev_batch_indices or deferred_seq_indices to clarify what these indices reference.
Pull request overview
Copilot reviewed 10 out of 11 changed files in this pull request and generated 8 comments.
```diff
 self.speculative_config = self.config.speculative_config
-self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
-self.mtp_k = self.num_speculative_tokens + 1
+self.mtp_k = self.speculative_config.num_speculative_tokens
```
Copilot (AI) · Jan 8, 2026
The variable mtp_k is initialized differently than in the original code. Previously it was num_speculative_tokens + 1, now it's just num_speculative_tokens. Ensure this change is intentional and verify all usages of mtp_k account for this modification, particularly in early exit conditions and loop bounds.
Suggested change:
```diff
-self.mtp_k = self.speculative_config.num_speculative_tokens
+self.mtp_k = self.speculative_config.num_speculative_tokens + 1
```
```python
new_tokens = []
if is_deferred_out:
    seq.token_ids[-1] = token_id
if is_deferred_out or (self.use_spec and self.eos_token_id == seq.token_ids[-1]):
```
Copilot (AI) · Jan 8, 2026
The condition combines two distinct concerns. The second part checking for EOS token during spec decode is unclear. Consider adding a comment explaining why EOS tokens trigger this branch or extracting to a named variable like is_spec_with_eos.
Suggested change:
```diff
-if is_deferred_out or (self.use_spec and self.eos_token_id == seq.token_ids[-1]):
+# During speculative decoding, if the last committed token is EOS, we also
+# need to replace the placeholder tokens with the sampled tokens.
+is_spec_with_eos = self.use_spec and self.eos_token_id == seq.token_ids[-1]
+if is_deferred_out or is_spec_with_eos:
```
```diff
 class tokenIDProcessor:

-    def __init__(self, max_num_batched_tokens: int, device: torch.device):
+    def __init__(self, max_num_batched_tokens: int, device: torch.device, use_spec: bool, num_speculative_tokens):
```
Copilot (AI) · Jan 8, 2026
The parameter num_speculative_tokens lacks a type annotation. Add : int for consistency with other parameters.
Suggested change:
```diff
-def __init__(self, max_num_batched_tokens: int, device: torch.device, use_spec: bool, num_speculative_tokens):
+def __init__(self, max_num_batched_tokens: int, device: torch.device, use_spec: bool, num_speculative_tokens: int):
```
```python
if self.config.speculative_config and get_pp_group().is_last_rank:
    self.drafter = EagleProposer(self.config, self.device, self)
    self.rejection_sampler = RejectionSampler()
num_speculative_tokens = self.drafter.mtp_k if hasattr(self, "drafter") else 0
```
Copilot (AI) · Jan 8, 2026
The initialization order is fragile - self.drafter and self.rejection_sampler are conditionally created, then immediately checked with hasattr(self, "drafter"). Consider extracting the speculative token count logic into the conditional block where drafter is created to make the dependency explicit.
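A simplified, hedged sketch of that restructuring, with dummy classes standing in for EagleProposer and RejectionSampler; the speculative-token count is computed inside the branch that creates the drafter instead of probing with hasattr afterwards:

```python
# Hedged sketch: _Drafter and _RejectionSampler are stand-ins, not the real classes.
class _Drafter:
    mtp_k = 1

class _RejectionSampler:
    pass

class Runner:
    def __init__(self, use_spec: bool):
        if use_spec:
            self.drafter = _Drafter()
            self.rejection_sampler = _RejectionSampler()
            num_speculative_tokens = self.drafter.mtp_k  # dependency is explicit here
        else:
            num_speculative_tokens = 0
        self.num_speculative_tokens = num_speculative_tokens

assert Runner(True).num_speculative_tokens == 1
assert Runner(False).num_speculative_tokens == 0
```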
```python
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
```
Copilot (AI) · Jan 8, 2026
The list comprehension assumes sampled_token_ids keys match the enumeration indices of num_draft_tokens, but sampled_token_ids is a dictionary keyed by sequence IDs. This will cause a KeyError if the sequence IDs don't align with 0-indexed positions. Access sampled_token_ids using the actual sequence IDs from the batch instead.
Suggested change:
```diff
-num_rejected_tokens = [
-    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
-    for i, n in enumerate(num_draft_tokens)
-]
+# Use actual sequence IDs from the batch to index sampled_token_ids
+sequence_ids = [seq.seq_id for seq in batch.sequences]
+num_rejected_tokens: list[int] = []
+for seq_id, n in zip(sequence_ids, num_draft_tokens):
+    if n > 0:
+        sampled = sampled_token_ids.get(seq_id, [])
+        num_rejected_tokens.append(n + 1 - len(sampled))
+    else:
+        num_rejected_tokens.append(0)
```
```python
]
if max_q_len > 1:
    context_lens = np.array(batch.context_lens) - var["mtp_k"]
    positions = [pos for le in context_lens for pos in (le - var["mtp_k"], le)]
```
Copilot (AI) · Jan 8, 2026
The variable name le is ambiguous. Use a more descriptive name like context_len to improve readability.
Suggested change:
```diff
-positions = [pos for le in context_lens for pos in (le - var["mtp_k"], le)]
+positions = [
+    pos
+    for context_len in context_lens
+    for pos in (context_len - var["mtp_k"], context_len)
+]
```
```python
def may_append(self, seq: Sequence):
    block_table = seq.block_table
    last_block = self.blocks[block_table[-1]]
    if len(seq) % self.block_size == 1 or self.block_size == 1:
```
Copilot (AI) · Jan 8, 2026
The calculation (len(seq) + seq.block_size - 1) // seq.block_size is the standard ceiling division pattern but could be more explicit. Consider adding a comment explaining this calculates the required number of blocks, or extract to a variable like total_blocks_needed.
Suggested change:
```diff
 if len(seq) % self.block_size == 1 or self.block_size == 1:
+    # Compute the total number of blocks required to store the current sequence (ceiling division).
```
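For reference, the ceiling-division pattern the reviewer mentions computes how many fixed-size blocks a sequence needs; a small, self-contained illustration (not the repository's code):

```python
# Hedged illustration of ceiling division for block counts.
def total_blocks_needed(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size

assert total_blocks_needed(16, 16) == 1   # exactly one full block
assert total_blocks_needed(17, 16) == 2   # one extra token needs a second block
```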
```python
    self.num_tokens += 1

def __repr__(self):
    return f"Sequence ({self.id=}, {self.token_ids=}, {self.spec_token_ids})"
```
Copilot (AI) · Jan 8, 2026
The __repr__ method uses f-string debug format for id and token_ids (with =) but not for spec_token_ids. For consistency, either use {self.spec_token_ids=} or remove the = from the other fields.
| return f"Sequence ({self.id=}, {self.token_ids=}, {self.spec_token_ids})" | |
| return f"Sequence ({self.id=}, {self.token_ids=}, {self.spec_token_ids=})" |
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist