feat: support chat datasets with THD, BSHD + CP and padding fixes #1416
Open
hemildesai wants to merge 27 commits into main from
Conversation
hemildesai (Contributor, Author): /ok to test d8d2063
Force-pushed from 4c885c3 to 50107e3
hemildesai (Contributor, Author): /ok to test 3b9b22d
Revert num_label_tokens guards in MaskedCrossEntropy and scale_grads_and_clip_grad_norm and their corresponding tests. These changes will be addressed separately. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Explain why padding is disabled in the first apply_chat_template call and deferred to _package_tokenized_example where input_ids, labels, and attention_mask are padded consistently. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Load the full base split, shuffle with a fixed seed (42), then apply any slice (e.g. "train[1024:]") so that train/val splits sample from a consistent random order instead of taking contiguous chunks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Add shuffle_seed parameter to _load_openai_messages and ChatDataset. Defaults to 42 for backwards compatibility. Set to None to disable shuffling entirely. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Change default shuffle_seed from 42 to None (no shuffle) so datasets are not shuffled unless explicitly requested. Add shuffle_seed: 42 to the qwen3_moe_30b_te_chat_thd.yaml example config. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
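The shuffle-then-slice behavior from the three commits above can be sketched as follows. This is an illustrative stand-in, not the repo's actual loader: it uses a plain Python list in place of a Hugging Face dataset, and `shuffled_slice` is a hypothetical helper name.

```python
import random
import re

def shuffled_slice(records, split, shuffle_seed=None):
    """Shuffle the full base split with a fixed seed, then apply a slice
    like "train[1024:]", so train/val slices sample from one consistent
    random order instead of contiguous chunks."""
    records = list(records)
    if shuffle_seed is not None:
        random.Random(shuffle_seed).shuffle(records)
    if split is None:
        # pass None through untouched; no slicing applies
        return records
    m = re.match(r".*\[(-?\d*):(-?\d*)\]$", split)
    if m is None:
        return records  # plain split name like "train", no slice suffix
    start = int(m.group(1)) if m.group(1) else None
    stop = int(m.group(2)) if m.group(2) else None
    return records[start:stop]
```

Because both calls shuffle with the same seed before slicing, `train[1024:]` and `train[:1024]` partition the dataset with no overlap; with `shuffle_seed=None` the original order is preserved.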
padding_mask is collater/CP metadata — not a model input. MoE models (gpt_oss, qwen3_moe, deepseek_v3, etc.) already derive it internally from attention_mask when padding_mask is absent, so it is safe and necessary (for HF models like GPT2) to always strip it here. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
GPT2LMHeadModel was the only model in the repo whose forward() did not accept extra keyword arguments. This caused a TypeError when the collater includes padding_mask in the batch dict. Adding **kwargs lets it absorb padding_mask (and any future metadata) like every other model. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Now that GPT2LMHeadModel accepts **kwargs, every model in the repo can handle padding_mask. Removing the pop lets MoE models receive it directly instead of re-deriving it from attention_mask on the GPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
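The `**kwargs` pattern from the two commits above can be shown with a minimal sketch. The two `forward_*` functions are hypothetical stand-ins for model `forward()` signatures, not the repo's `GPT2LMHeadModel`:

```python
def forward_strict(input_ids, attention_mask=None):
    """A forward() that names only its model inputs: extra collater
    metadata in the batch dict makes the call raise a TypeError."""
    return len(input_ids)

def forward_tolerant(input_ids, attention_mask=None, **kwargs):
    """The pattern the commits settle on: **kwargs absorbs padding_mask
    (and any future collater metadata) without erroring."""
    return len(input_ids)

# a batch dict as a collater might produce it, padding_mask included
batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 0], "padding_mask": [0, 0, 1]}
```

Calling `forward_strict(**batch)` raises `TypeError: unexpected keyword argument 'padding_mask'`, while `forward_tolerant(**batch)` simply ignores the extra key, which is why the pop can be removed once every model accepts `**kwargs`.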
When split is None, pass it through to load_dataset as-is instead of overriding it to "train". Skip the slice regex match when split is None. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
format_chat_template now uses padding=False, so the pre-padded guard in _package_tokenized_example is only needed for the format_prompt_response_example path. Update the comment accordingly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
The function is format_prompt_completion, not format_prompt_response_example. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
When apply_chat_template truncates to max_length (seq_length), appending EOS makes the total seq_length+1. After BOS removal in _package_tokenized_example, labels are exactly seq_length with no room for -100 padding — the last label becomes a spurious EOS instead of -100. Skip the EOS append when len(input_ids) >= seq_length, indicating truncation occurred and the sequence was cut mid-conversation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
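The truncation guard described above can be sketched as a small helper. `maybe_append_eos` is an illustrative name, not the repo's actual function:

```python
def maybe_append_eos(input_ids, eos_id, seq_length):
    """Only append EOS when the sequence was NOT truncated to seq_length.
    A truncated sequence was cut mid-conversation, so a trailing EOS would
    be spurious and would leave no room for the -100 label padding."""
    if len(input_ids) >= seq_length:
        return input_ids  # truncation occurred; no EOS belongs here
    return input_ids + [eos_id]
```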
Add _resolve_chat_template to formatting_utils that accepts a file path, JSON file (extracting the "chat_template" key), or literal Jinja string. Wire it into ChatDataset.__init__ so users can pass a path to a template file instead of an inline string. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
- Pass the user's padding value through to apply_chat_template instead of hardcoding padding=False; use the tokenizer's attention_mask to zero out padding positions in the loss mask. - Remove manual EOS append after apply_chat_template; the chat template itself is responsible for EOS. - In ChatDataset, resolve padding="longest" at init by batching all conversations through apply_chat_template(padding="longest") to find the max tokenized length, then use "max_length" per sample. - Remove dead code: _JINJA_CHARS, unreachable attention_mask zeroing block, and _StubTokenizerChatNoGenWithPadding test stub. - Fix _StubTokenizerChatNoGen to append EOS like real chat templates. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Let padding="longest" pass through to the collator for batch-level padding instead of resolving it at dataset init. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
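Batch-level "longest" padding in a collator, as the commit above describes, can be sketched like this. `collate_longest` is a hypothetical helper, not the repo's collater:

```python
def collate_longest(samples, pad_id):
    """Pad every sample in a batch to the longest sequence in that batch,
    so "longest" is resolved per batch at collate time rather than once
    over the whole dataset at init."""
    width = max(len(s["input_ids"]) for s in samples)
    batch = {"input_ids": [], "labels": [], "attention_mask": []}
    for s in samples:
        n = len(s["input_ids"])
        batch["input_ids"].append(s["input_ids"] + [pad_id] * (width - n))
        batch["labels"].append(s["labels"] + [-100] * (width - n))
        batch["attention_mask"].append([1] * n + [0] * (width - n))
    return batch
```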
Remove attention_mask handling from _make_cp_batch_for_te and _shard_thd_chunk_for_te since THD format does not use attention_mask. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
Compute content_length on the original input_ids before the [:-1] next-token prediction shift, then subtract 1. This ensures padded and non-padded inputs produce identical attention masks — previously the shift removed a pad token (padded) vs a real token (non-padded), causing an off-by-one in the attention mask. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
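The off-by-one fix above can be made concrete with a sketch. The function name and return shape are illustrative, but the ordering matters: `content_length` is measured on the original `input_ids`, then reduced by 1 for the shift:

```python
def shift_for_next_token(input_ids, labels, pad_id):
    """Next-token shift with content_length measured BEFORE the [:-1]
    shift, so padded and non-padded inputs yield the same number of
    attended positions."""
    # count real (non-pad) tokens on the original, unshifted sequence
    content_length = next(
        (i for i, t in enumerate(input_ids) if t == pad_id), len(input_ids)
    )
    shifted_inputs = input_ids[:-1]
    shifted_labels = labels[1:]
    ones = content_length - 1  # one real token is consumed by the shift
    attention_mask = [1] * ones + [0] * (len(shifted_inputs) - ones)
    return shifted_inputs, shifted_labels, attention_mask
```

With the old order (counting after the shift), the padded variant would lose a pad token to `[:-1]` while the non-padded one lost a real token, giving masks that disagree by one.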
Remove assertions that the last real input_id is not EOS — this was only true when EOS was manually appended after apply_chat_template. Now apply_chat_template adds EOS via the template itself, so with padding the shift removes a pad token and EOS stays in input_ids. Add stronger invariant checks: - attention_mask=0 positions must have labels=-100 - attention_mask must be contiguous (ones then zeros, right-padded) - padding region input_ids must use pad_token_id Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
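The three invariants listed above translate directly into bare assertions. This is an illustrative checker, not the repo's test code:

```python
def check_packaged_example(input_ids, labels, attention_mask, pad_id):
    """Invariants for a packaged example: masked positions carry -100
    labels, the mask is contiguous (right-padded), and the padding
    region holds pad_token_id."""
    # attention_mask=0 positions must be ignored by the loss
    assert all(l == -100 for l, m in zip(labels, attention_mask) if m == 0)
    # for a 0/1 list, sorting descending leaves a contiguous
    # ones-then-zeros mask unchanged
    assert attention_mask == sorted(attention_mask, reverse=True)
    # the padding region must hold pad_token_id
    assert all(t == pad_id for t, m in zip(input_ids, attention_mask) if m == 0)
```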
hemildesai (Contributor, Author): /ok to test 503f142
With seq_length=4 and truncation=True, the sequence is truncated to just the start of the system message — no assistant tokens survive. Skip the "must have supervised tokens" assertion when truncation is enabled since it may legitimately remove all assistant content. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
When a chat template is used, the template's own turn-ending token (e.g. <|im_end|>) terminates sequences — not eos_token_id (e.g. <|endoftext|>). These are different tokens for Qwen2.5/3. Remove assertions that eos_token_id must appear in supervised labels since it was only true when EOS was manually appended. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com>
hemildesai (Contributor, Author): /ok to test 155c8b2
hemildesai (Contributor, Author): /ok to test bc6e83c
hemildesai (Contributor, Author): /ok to test 6f1a31e
Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: hemildesai <hemild@nvidia.com>
hemildesai (Contributor, Author): /ok to test e8e61d5
Contributor: /ok to test 035bc9d
Summary

- Extend `packed_sequence_thd_collater` to handle non-packed (ChatDataset) data by synthesizing `seq_lens`, `seq_lens_padded`, and `position_ids` metadata, enabling THD format for chat-style datasets.
- Pass the user's `padding` value through to `apply_chat_template` instead of hardcoding `padding=False`. Use the tokenizer's `attention_mask` to zero out padding positions in the loss mask, in both `format_chat_template` and `format_prompt_completion`. For `answer_only_loss`, tokenize the prompt without padding to get its real length, then derive the mask from the length difference against the (possibly padded) full text.
- Add `_resolve_chat_template` to `formatting_utils` so `chat_template` can be a file path (plain text or JSON with a `"chat_template"` key) or a literal Jinja string.
- Remove the manual `eos_token_id` append after `apply_chat_template`.
- Update `_package_tokenized_example` to strip trailing pad tokens before building the attention mask, so inputs that arrive already padded get proper zero-masking.
- Shuffle the base split with a fixed seed before applying any slice (e.g. `train[1024:]`), so that train/val splits sample from a consistent random order instead of taking contiguous chunks.
- Add `padding_mask` to `default_collater` output and propagate `padding_mask` through Context Parallelism buffers (`cp_utils.py`).
- Fix the `collate_fn` `_target_` syntax in `llama3_2_1b_squad.yaml` and add `pad_seq_len_divisible: 512` via `dataloader.collate_fn` in DCP PP2 test scripts so all microbatches have uniform sequence length.
- Add `qwen3_moe_30b_te_chat_thd.yaml` for Qwen3 MoE 30B with ChatDataset + THD collater.

Test plan

- `_resolve_chat_template` (file path, JSON, literal string, None, nonexistent path)
- `TestPackageTokenizedExamplePrePaddedInput` (`test_utils.py`)
- `default_collater` `padding_mask` output
- `format_chat_template` attention_mask zeroing at padding positions
- `format_prompt_completion` attention_mask zeroing at padding positions
- `padding_mask` propagation
- `_package_tokenized_example` content_length branches (`TestContentLengthBranches`)
- `pad_seq_len_divisible` via collate_fn + conftest override registration

🤖 Generated with Claude Code
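The first summary bullet, synthesizing THD metadata for non-packed data, can be sketched as follows. Names mirror the description (`seq_lens`, `seq_lens_padded`, `position_ids`) but the function and its signature are illustrative, not the repo's collater:

```python
def synthesize_thd_metadata(attention_masks, pad_multiple=1):
    """For a non-packed batch, each sample is treated as a single packed
    sequence, so the THD metadata can be derived from the per-sample
    attention masks alone."""
    seq_lens, seq_lens_padded, position_ids = [], [], []
    for mask in attention_masks:
        n = sum(mask)  # real tokens in this sample
        n_padded = -(-n // pad_multiple) * pad_multiple  # round up
        seq_lens.append(n)
        seq_lens_padded.append(n_padded)
        # positions restart at 0 per sample; padding positions are zeroed
        position_ids.append(list(range(n)) + [0] * (n_padded - n))
    return seq_lens, seq_lens_padded, position_ids
```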