
feat: support chat datasets with THD, BSHD + CP and padding fixes #1416

Open

hemildesai wants to merge 27 commits into main from hemil/data-fixes
Conversation


@hemildesai hemildesai commented Feb 28, 2026

Summary

  • ChatDataset + THD collater support: Extend packed_sequence_thd_collater to handle non-packed (ChatDataset) data by synthesizing seq_lens, seq_lens_padded, and position_ids metadata, enabling THD format for chat-style datasets.
  • Chat template padding support: Pass the user's padding value through to apply_chat_template instead of hardcoding padding=False. Use the tokenizer's attention_mask to zero out padding positions in the loss mask — in both format_chat_template and format_prompt_completion. For answer_only_loss, tokenize the prompt without padding to get its real length, then derive the mask from the length difference against the (possibly padded) full text.
  • Chat template file path resolution: Add _resolve_chat_template to formatting_utils so chat_template can be a file path (plain text or JSON with "chat_template" key) or a literal Jinja string.
  • Remove manual EOS append: The chat template itself is responsible for EOS; remove the manual eos_token_id append after apply_chat_template.
  • Pre-padded attention_mask fix: Correct _package_tokenized_example to strip trailing pad tokens before building the attention mask, so inputs that arrive already padded get proper zero-masking.
  • HF dataset shuffle-before-slice: Shuffle HF datasets with a configurable seed before applying any split slice (e.g. train[1024:]), so that train/val splits sample from a consistent random order instead of taking contiguous chunks.
  • CP + padding_mask propagation: Add padding_mask to default_collater output and propagate padding_mask through Context Parallelism buffers (cp_utils.py).
  • PP shape mismatch fix (L2_HF_DCP): Switch collate_fn to _target_ syntax in llama3_2_1b_squad.yaml and add pad_seq_len_divisible: 512 via dataloader.collate_fn in the DCP PP2 test scripts so all microbatches have a uniform sequence length.
  • Example config: Add qwen3_moe_30b_te_chat_thd.yaml for Qwen3 MoE 30B with ChatDataset + THD collater.
  • Comprehensive unit tests: Cover all new/changed behaviors across datasets, CP utils, formatting utilities, and content_length branches.
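
The first bullet's core idea can be sketched in a few lines. This is a hypothetical illustration (the helper name and shapes are assumptions, not the actual packed_sequence_thd_collater code): for non-packed chat data, the metadata a THD collater needs can be synthesized from the padded attention masks alone.

```python
def synthesize_thd_metadata(attention_masks):
    """Hypothetical sketch: derive the per-sample metadata a THD collater
    needs (seq_lens, seq_lens_padded, position_ids) from padded 0/1
    attention masks, for non-packed data with one conversation per row."""
    padded_len = len(attention_masks[0])
    seq_lens = [sum(m) for m in attention_masks]          # real tokens per sample
    seq_lens_padded = [padded_len] * len(attention_masks)  # every row padded to the same length
    # Position ids restart at 0 for every sample, since each non-packed
    # row is treated as a "pack" containing a single sequence.
    position_ids = [p for _ in attention_masks for p in range(padded_len)]
    return seq_lens, seq_lens_padded, position_ids
```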

Test plan

  • Tests for _resolve_chat_template (file path, JSON, literal string, None, nonexistent path)
  • Tests for pre-padded attention_mask fix (TestPackageTokenizedExamplePrePaddedInput)
  • Tests for non-packed THD collation (test_utils.py)
  • Test for default_collater padding_mask output
  • Tests for format_chat_template attention_mask zeroing at padding positions
  • Tests for format_prompt_completion attention_mask zeroing at padding positions
  • Tests for CP utils padding_mask propagation
  • Tests for HF dataset shuffle-before-slice behavior
  • Tests for _package_tokenized_example content_length branches (TestContentLengthBranches)
  • Functional test assertion fixes for max_length padding boundary behavior
  • DCP PP2 tests with pad_seq_len_divisible via collate_fn + conftest override registration
  • Full unit test suite passes in affected files

🤖 Generated with Claude Code


copy-pr-bot bot commented Feb 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai hemildesai changed the title feat: support chat datasets in packed sequence THD collator feat: support chat datasets with THD, BSHD + CP and padding fixes Feb 28, 2026
@hemildesai (Contributor Author)

/ok to test d8d2063

@hemildesai (Contributor Author)

/ok to test 3b9b22d

hemildesai and others added 18 commits March 3, 2026 22:58
Revert num_label_tokens guards in MaskedCrossEntropy and
scale_grads_and_clip_grad_norm and their corresponding tests.
These changes will be addressed separately.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Explain why padding is disabled in the first apply_chat_template call
and deferred to _package_tokenized_example where input_ids, labels,
and attention_mask are padded consistently.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Load the full base split, shuffle with a fixed seed (42), then apply
any slice (e.g. "train[1024:]") so that train/val splits sample from
a consistent random order instead of taking contiguous chunks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
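
The shuffle-before-slice behavior above can be sketched as follows. This is a simplified stand-in (the function name, regex, and list-based data are assumptions; the real code goes through HF load_dataset), but it shows the invariant: with a fixed seed, "train[:N]" and "train[N:]" partition one consistent random order instead of taking contiguous chunks.

```python
import random
import re


def load_split_shuffled(rows, split, shuffle_seed=None):
    """Sketch of shuffle-before-slice (names hypothetical): shuffle the
    full base split with a fixed seed FIRST, then apply any
    "train[1024:]"-style slice."""
    if shuffle_seed is not None:
        rows = rows[:]  # don't mutate the caller's data
        random.Random(shuffle_seed).shuffle(rows)
    if split is not None:
        m = re.match(r"\w+\[(\d*):(\d*)\]", split)
        if m:
            start = int(m.group(1)) if m.group(1) else None
            stop = int(m.group(2)) if m.group(2) else None
            rows = rows[start:stop]
    return rows
```

Because both calls shuffle with the same seed before slicing, complementary slices are guaranteed to be disjoint and to cover the dataset exactly once.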
Add shuffle_seed parameter to _load_openai_messages and ChatDataset.
Defaults to 42 for backwards compatibility. Set to None to disable
shuffling entirely.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Change default shuffle_seed from 42 to None (no shuffle) so datasets
are not shuffled unless explicitly requested. Add shuffle_seed: 42
to the qwen3_moe_30b_te_chat_thd.yaml example config.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
padding_mask is collater/CP metadata — not a model input. MoE models
(gpt_oss, qwen3_moe, deepseek_v3, etc.) already derive it internally
from attention_mask when padding_mask is absent, so it is safe and
necessary (for HF models like GPT2) to always strip it here.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
GPT2LMHeadModel was the only model in the repo whose forward() did not
accept extra keyword arguments.  This caused a TypeError when the
collater includes padding_mask in the batch dict.  Adding **kwargs lets
it absorb padding_mask (and any future metadata) like every other model.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Now that GPT2LMHeadModel accepts **kwargs, every model in the repo can
handle padding_mask.  Removing the pop lets MoE models receive it
directly instead of re-deriving it from attention_mask on the GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
When split is None, pass it through to load_dataset as-is instead of
overriding it to "train". Skip the slice regex match when split is None.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
format_chat_template now uses padding=False, so the pre-padded guard
in _package_tokenized_example is only needed for the
format_prompt_response_example path. Update the comment accordingly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
The function is format_prompt_completion, not
format_prompt_response_example.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
When apply_chat_template truncates to max_length (seq_length), appending
EOS makes the total seq_length+1.  After BOS removal in
_package_tokenized_example, labels are exactly seq_length with no room
for -100 padding — the last label becomes a spurious EOS instead of -100.

Skip the EOS append when len(input_ids) >= seq_length, indicating
truncation occurred and the sequence was cut mid-conversation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
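
The truncation-aware guard described in this commit (later superseded when the manual EOS append was removed entirely) amounts to a length check before appending. A minimal sketch, with an assumed helper name:

```python
def maybe_append_eos(input_ids, eos_token_id, seq_length):
    """Sketch of the truncation-aware EOS append: if the tokenizer already
    filled seq_length, truncation cut the conversation mid-turn, and
    appending EOS would overflow the sequence to seq_length + 1."""
    if len(input_ids) >= seq_length:
        return input_ids  # truncated: no room for EOS after the BOS-removal shift
    return input_ids + [eos_token_id]
```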
Add _resolve_chat_template to formatting_utils that accepts a file path,
JSON file (extracting the "chat_template" key), or literal Jinja string.
Wire it into ChatDataset.__init__ so users can pass a path to a template
file instead of an inline string.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
- Pass the user's padding value through to apply_chat_template instead
  of hardcoding padding=False; use the tokenizer's attention_mask to
  zero out padding positions in the loss mask.
- Remove manual EOS append after apply_chat_template; the chat template
  itself is responsible for EOS.
- In ChatDataset, resolve padding="longest" at init by batching all
  conversations through apply_chat_template(padding="longest") to find
  the max tokenized length, then use "max_length" per sample.
- Remove dead code: _JINJA_CHARS, unreachable attention_mask zeroing
  block, and _StubTokenizerChatNoGenWithPadding test stub.
- Fix _StubTokenizerChatNoGen to append EOS like real chat templates.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
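
The answer_only_loss masking described in the PR summary combines two signals: the real prompt length (from tokenizing the prompt without padding) and the tokenizer's attention_mask (to zero out pads). A minimal sketch with an assumed helper name:

```python
def answer_only_loss_mask(full_ids, prompt_len, attention_mask):
    """Sketch of answer_only_loss masking: supervise only positions past
    the unpadded prompt length, and never supervise padding positions
    (attention_mask == 0), even when full_ids arrive padded."""
    loss_mask = []
    for i, m in enumerate(attention_mask):
        supervised = i >= prompt_len and m == 1  # answer tokens only, never padding
        loss_mask.append(1 if supervised else 0)
    return loss_mask
```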
Let padding="longest" pass through to the collator for batch-level
padding instead of resolving it at dataset init.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Remove attention_mask handling from _make_cp_batch_for_te and
_shard_thd_chunk_for_te since THD format does not use attention_mask.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Compute content_length on the original input_ids before the [:-1]
next-token prediction shift, then subtract 1.  This ensures padded
and non-padded inputs produce identical attention masks — previously
the shift removed a pad token (padded) vs a real token (non-padded),
causing an off-by-one in the attention mask.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
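
The off-by-one can be made concrete with a toy version of the fix (function name and list-based types are illustrative, not the real _package_tokenized_example): count real tokens on the original sequence, subtract one for the [:-1] shift, and build the mask from that count.

```python
def attention_mask_after_shift(input_ids, pad_token_id):
    """Sketch of the fix: compute content_length on the ORIGINAL input_ids
    (before the [:-1] next-token shift), then subtract 1, so padded and
    non-padded inputs produce identical attention masks."""
    n_real = sum(1 for t in input_ids if t != pad_token_id)
    content_length = n_real - 1      # the shift always costs one content position
    shifted = input_ids[:-1]         # next-token prediction inputs
    return [1 if i < content_length else 0 for i in range(len(shifted))]
```

With the old order of operations, the padded case dropped a pad token while the non-padded case dropped a real token, so the two masks disagreed by one; here both agree on every real position.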
Remove assertions that the last real input_id is not EOS — this was
only true when EOS was manually appended after apply_chat_template.
Now apply_chat_template adds EOS via the template itself, so with
padding the shift removes a pad token and EOS stays in input_ids.

Add stronger invariant checks:
- attention_mask=0 positions must have labels=-100
- attention_mask must be contiguous (ones then zeros, right-padded)
- padding region input_ids must use pad_token_id

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
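
The three invariants this commit adds can be expressed as a small check helper (name illustrative, mirroring the bullet list above, not the actual test code):

```python
def check_packaging_invariants(input_ids, labels, attention_mask, pad_token_id):
    """Sketch of the stronger invariant checks: mask zeros imply ignored
    labels, the mask is contiguous ones-then-zeros (right-padded), and
    the padding region uses pad_token_id."""
    for m, lab in zip(attention_mask, labels):
        assert m == 1 or lab == -100, "attention_mask=0 positions must have labels=-100"
    assert attention_mask == sorted(attention_mask, reverse=True), \
        "attention_mask must be contiguous ones then zeros"
    n_real = sum(attention_mask)
    assert all(t == pad_token_id for t in input_ids[n_real:]), \
        "padding region input_ids must use pad_token_id"
```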
@hemildesai (Contributor Author)

/ok to test 503f142

hemildesai and others added 2 commits March 4, 2026 12:23
With seq_length=4 and truncation=True, the sequence is truncated to
just the start of the system message — no assistant tokens survive.
Skip the "must have supervised tokens" assertion when truncation is
enabled since it may legitimately remove all assistant content.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
When a chat template is used, the template's own turn-ending token
(e.g. <|im_end|>) terminates sequences — not eos_token_id
(e.g. <|endoftext|>). These are different tokens for Qwen2.5/3.
Remove assertions that eos_token_id must appear in supervised labels
since it was only true when EOS was manually appended.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author)

/ok to test 155c8b2

@hemildesai (Contributor Author)

/ok to test bc6e83c

@hemildesai (Contributor Author)

/ok to test 6f1a31e

hemildesai and others added 3 commits March 6, 2026 09:31
Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai (Contributor Author)

/ok to test e8e61d5


akoumpa commented Mar 7, 2026

/ok to test 035bc9d

