38 commits
519ccc7
feat: support chat datasets in packed sequence THD collator
hemildesai Feb 28, 2026
a6ec308
test: add missing unit tests for data and CP changes
hemildesai Feb 28, 2026
ccddc23
fix: correct attention_mask for pre-padded input_ids in _package_toke…
hemildesai Feb 28, 2026
dc57a0e
refactor: remove division-by-zero guards from this PR
hemildesai Feb 28, 2026
0bed4e4
docs: add comment explaining padding=False in format_chat_template
hemildesai Feb 28, 2026
8970999
feat: shuffle HF datasets before slicing in chat dataset loader
hemildesai Feb 28, 2026
756863c
feat: make shuffle seed configurable in ChatDataset
hemildesai Feb 28, 2026
dc53219
refactor: default shuffle_seed to None, set explicitly in yaml
hemildesai Feb 28, 2026
6966078
fix: pop padding_mask from batch before passing to model forward
hemildesai Mar 1, 2026
3d30743
fix: accept **kwargs in GPT2LMHeadModel.forward for padding_mask
hemildesai Mar 1, 2026
e561d58
fix: stop stripping padding_mask from batch before model forward
hemildesai Mar 1, 2026
47a0eae
fix: do not default split to "train" in _load_openai_messages
hemildesai Mar 2, 2026
30a8a3f
docs: fix comment about pre-padded input_ids source
hemildesai Mar 2, 2026
3da7af4
docs: fix incorrect function name in comment
hemildesai Mar 2, 2026
7b9af1f
fix: do not append EOS to truncated sequences in format_chat_template
hemildesai Mar 2, 2026
fbef5a2
feat: resolve chat_template from file path in ChatDataset
hemildesai Mar 3, 2026
1548029
fix: pass actual padding to apply_chat_template and support longest
hemildesai Mar 3, 2026
bad015c
fix: remove dataset-level longest padding resolution
hemildesai Mar 4, 2026
f032320
fix: remove attention_mask from THD context parallelism path
hemildesai Mar 4, 2026
5411a42
fix: compute attention_mask content_length before next-token shift
hemildesai Mar 4, 2026
503f142
fix: update functional tests for new padding behavior
hemildesai Mar 4, 2026
43adaf6
feat: Add native Comet ML experiment tracking support
LoganVegnaSHOP Feb 27, 2026
d6b21e5
test: Add unit tests for CometLogger (92% coverage)
LoganVegnaSHOP Feb 27, 2026
7b19d9f
fix: Require comet.project_name instead of defaulting
LoganVegnaSHOP Feb 28, 2026
1a6f66a
feat: add Qwen3 dense model handler for NeMo Automodel
LoganVegnaSHOP Mar 3, 2026
d3c6818
fix: remove **kwargs from Qwen3ForCausalLM to prevent config conflict
LoganVegnaSHOP Mar 3, 2026
497adc8
fix: resolve CI linting and test failures
LoganVegnaSHOP Mar 3, 2026
570531f
fix: skip supervised-token assertion when truncation removes assistant
hemildesai Mar 4, 2026
155c8b2
fix: remove eos_token_id assertions from pad-eos overlap tests
hemildesai Mar 4, 2026
7c217fb
Merge remote-tracking branch 'upstream/hemil/data-fixes' into feat/co…
LoganVegnaSHOP Mar 5, 2026
015fcde
native torch_fp32 norm
Mar 5, 2026
4011429
Merge remote-tracking branch 'upstream/hemil/data-fixes-fp32-norm' in…
LoganVegnaSHOP Mar 5, 2026
9b40b50
fix: attach CP attention-mask hooks for dense (non-TE) context parall…
hemildesai Mar 6, 2026
8674407
fix: use late-bound SDPA lookup for context parallelism compatibility
hemildesai Mar 6, 2026
f55f6bc
feat: add resolve_sdpa_method helper with YAML-configurable SDPA back…
hemildesai Mar 6, 2026
d4e3474
style: apply ruff formatting to changed files
hemildesai Mar 6, 2026
59deabe
fix: skip non-TE attention in MoE apply_cp instead of asserting
hemildesai Mar 6, 2026
88f0443
Merge remote-tracking branch 'upstream/hemil/cp-dense-fixes' into fea…
LoganVegnaSHOP Mar 6, 2026
128 changes: 128 additions & 0 deletions examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml
@@ -0,0 +1,128 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# To run this recipe, please use the following command:
# torchrun --nproc-per-node=8 examples/llm_finetune/finetune.py --config examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml
# Adjust --nproc-per-node to the number of GPUs available on your host machine.
#
# This config uses ChatDataset with the THD collater (without sequence packing).
# The packed_sequence_thd_collater automatically synthesizes the THD metadata
# (seq_lens, seq_lens_padded, position_ids) for non-packed data, enabling
# TE context parallelism without requiring actual sequence packing.


step_scheduler:
  global_batch_size: 16
  local_batch_size: 2
  ckpt_every_steps: 500
  gc_every_steps: 10
  max_steps: 1000
  val_every_steps: 100

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3-30B-A3B-Thinking-2507
  backend:
    _target_: nemo_automodel.components.models.common.BackendConfig
    attn: te
    linear: te
    rms_norm: te
    experts: torch_mm
    dispatcher: deepep
    fake_balanced_gate: false
    enable_hf_state_dict_adapter: true

checkpoint:
  enabled: false
  checkpoint_dir: checkpoints/
  model_save_format: torch_save
  save_consolidated: false

distributed:
  strategy: fsdp2
  tp_size: 1
  cp_size: 2
  pp_size: 1
  ep_size: 4

sequence_parallel: false
activation_checkpointing: true

pipeline:
  pp_schedule: interleaved1f1b
  pp_microbatch_size: 4
  round_virtual_stages_to_pp_multiple: down
  scale_grads_in_schedule: false
  patch_inner_model: false
  patch_causal_lm_model: false
  layers_per_stage: 2

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset
  path_or_dataset_id: allenai/tulu-3-sft-mixture
  split: train
  shuffle_seed: 42
  truncation: true
  seq_length: 1024
  padding: max_length

packed_sequence:
  # No packing — the THD collater synthesizes seq_lens from ChatDataset output.
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater
  shuffle: true

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset
  path_or_dataset_id: allenai/tulu-3-sft-mixture
  split: "train[:128]"
  shuffle_seed: 42
  truncation: true
  seq_length: 1024
  padding: max_length

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  foreach: false

# # Uncomment and configure for W&B logging
# wandb:
#   project: <your_wandb_project>
#   entity: <your_wandb_entity>
#   name: <your_wandb_exp_name>
#   save_dir: <your_wandb_save_dir>
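The header comment above says the packed_sequence_thd_collater synthesizes THD metadata (seq_lens, seq_lens_padded, position_ids) for non-packed batches. A minimal sketch of what that synthesis can look like; the function name and details below are illustrative, not the actual collater implementation:

```python
def synthesize_thd_metadata(batch_input_ids, pad_token_id=0):
    """Illustrative only: derive per-sample THD metadata for a batch of
    unpacked, right-padded token-id lists. Assumes the pad token never
    appears inside the real content."""
    # Content length = tokens before padding; padded length = full row.
    seq_lens = [sum(1 for tok in ids if tok != pad_token_id) for ids in batch_input_ids]
    seq_lens_padded = [len(ids) for ids in batch_input_ids]
    # Each sample restarts its positions at 0, as a THD layout expects.
    position_ids = [list(range(len(ids))) for ids in batch_input_ids]
    return {
        "seq_lens": seq_lens,
        "seq_lens_padded": seq_lens_padded,
        "position_ids": position_ids,
    }

batch = [[11, 12, 13, 0], [21, 22, 0, 0]]  # two samples padded to length 4
meta = synthesize_thd_metadata(batch)
print(meta["seq_lens"])         # [3, 2]
print(meta["seq_lens_padded"])  # [4, 4]
```

With metadata like this available per batch, TE context parallelism can treat each unpacked sample as a trivially "packed" sequence of one, which is why the recipe sets packed_sequence_size: 0.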
4 changes: 4 additions & 0 deletions nemo_automodel/_transformers/registry.py
@@ -92,6 +92,10 @@
         "Qwen2ForCausalLM",
         ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"),
     ),
+    (
+        "Qwen3ForCausalLM",
+        ("nemo_automodel.components.models.qwen3.model", "Qwen3ForCausalLM"),
+    ),
     (
         "Qwen3MoeForCausalLM",
         ("nemo_automodel.components.models.qwen3_moe.model", "Qwen3MoeForCausalLM"),
9 changes: 6 additions & 3 deletions nemo_automodel/components/attention/utils.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
 from typing import Any, Callable
 
 import torch
@@ -48,13 +47,17 @@ def initialize_attn_module_and_func(
         attn_func = attn_module.__call__
         return attn_module, attn_func
     elif attn_impl == "sdpa":
-        attn_func = functools.partial(
-            F.scaled_dot_product_attention,
+        defaults = dict(
             scale=softmax_scale,
             is_causal=attn_mask_type == "causal",
             enable_gqa=num_gqa_groups is not None,
             **kwargs,
         )
+
+        def attn_func(*args, **call_kwargs):
+            merged = {**defaults, **call_kwargs}
+            return F.scaled_dot_product_attention(*args, **merged)
+
         return None, attn_func
     elif attn_impl == "flex":
         attn_module = FlexAttention()
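The hunk above changes two things: SDPA defaults are now merged with call-site kwargs (so a caller can override scale or is_causal per call), and F.scaled_dot_product_attention is resolved at call time instead of being frozen into a functools.partial. The late-binding half is what matters for context parallelism, which may monkey-patch the SDPA entry point after setup. A toy illustration, where a SimpleNamespace stands in for torch.nn.functional:

```python
import functools
import types

# Stand-in for the torch.nn.functional module; not real torch.
F = types.SimpleNamespace()
F.sdpa = lambda q: ("plain", q)

# Early binding: partial captures the current function object forever.
early = functools.partial(F.sdpa)

# Late binding: the attribute is looked up on every call, so a
# monkey-patched implementation is picked up automatically.
def late(*args, **kwargs):
    return F.sdpa(*args, **kwargs)

# A context-parallelism hook later swaps in a CP-aware implementation.
F.sdpa = lambda q: ("cp-aware", q)

print(early("q"))  # ('plain', 'q')    -- stale: the patch is missed
print(late("q"))   # ('cp-aware', 'q') -- the patch takes effect
```

This mirrors why the diff wraps F.scaled_dot_product_attention in a closure: the CP attention-mask hooks from the later commits can replace the SDPA entry point and still be reached by already-constructed attn_func callables.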
54 changes: 46 additions & 8 deletions nemo_automodel/components/datasets/llm/chat_dataset.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import json
+import re
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
@@ -24,6 +25,7 @@
 from nemo_automodel.components.datasets.llm.formatting_utils import (
     _add_pad_token,
     _has_chat_template,
+    _resolve_chat_template,
     format_chat_template,
 )
 
@@ -46,24 +48,60 @@ def _as_iter(val: Union[str, Sequence[str]]) -> Iterator[str]:
         yield x
 
 
+_SPLIT_SLICE_RE = re.compile(r"^(\w+)\[(\d*):(\d*)\]$")
+
+
 def _load_openai_messages(
-    path_or_dataset_id: Union[str, Sequence[str]], split: Optional[str] = None, name: Optional[str] = None
+    path_or_dataset_id: Union[str, Sequence[str]],
+    split: Optional[str] = None,
+    name: Optional[str] = None,
+    shuffle_seed: Optional[int] = None,
 ):
     """Load OpenAI chat messages datasets from HF or local JSON/JSONL files.
 
-    For HF repo IDs, we delegate to datasets.load_dataset.
+    For HF repo IDs, we delegate to datasets.load_dataset. When *split*
+    is provided, the full base split is loaded and shuffled *before* any
+    slice (e.g. ``[1024:]``) is applied so that train/val splits sample
+    from a consistent random order. When *split* is ``None`` it is passed
+    through to ``load_dataset`` as-is (no default override).
 
     For local files, we manually parse JSONL/JSON to avoid pyarrow type
     inference issues (e.g., heterogeneous field types under `tools`).
 
     Args:
         path_or_dataset_id: HF dataset ID or local file path(s).
-        split: Dataset split to load (e.g., "train", "validation").
+        split: Dataset split to load (e.g., "train", "train[1024:]").
         name: Dataset configuration/subset name
+        shuffle_seed: Random seed for shuffling HF datasets before slicing.
+            Set to ``None`` to disable shuffling.
     """
     if isinstance(path_or_dataset_id, str) and _is_hf_repo_id(path_or_dataset_id):
-        return load_dataset(
-            path_or_dataset_id, name=name, split=split, streaming=False, verification_mode=VerificationMode.NO_CHECKS
-        )
+        # Parse split string: "train[1024:]" -> base="train", slice(1024, None)
+        base_split = split
+        sl = None
+        if split is not None:
+            match = _SPLIT_SLICE_RE.match(split)
+            if match:
+                base_split = match.group(1)
+                start = int(match.group(2)) if match.group(2) else None
+                end = int(match.group(3)) if match.group(3) else None
+                sl = slice(start, end)
+
+        dataset = load_dataset(
+            path_or_dataset_id,
+            name=name,
+            split=base_split,
+            streaming=False,
+            verification_mode=VerificationMode.NO_CHECKS,
+        )
+        if shuffle_seed is not None:
+            dataset = dataset.shuffle(seed=shuffle_seed)
+
+        if sl is not None:
+            indices = range(*sl.indices(len(dataset)))
+            dataset = dataset.select(indices)
+
+        return dataset
 
     files = list(_as_iter(path_or_dataset_id))
     if not files:
@@ -137,14 +175,14 @@ def __init__(
         truncation: Union[str, bool] = "do_not_truncate",
         start_of_turn_token: Optional[str] = None,
         chat_template: Optional[str] = None,
+        shuffle_seed: Optional[int] = None,
     ) -> None:
         if tokenizer is None:
             raise ValueError("Tokenizer is required")
 
         # Enforce chat-template availability for tool-calling data
         if chat_template is not None:
-            # Allow overriding the tokenizer's template
-            tokenizer.chat_template = chat_template
+            tokenizer.chat_template = _resolve_chat_template(chat_template)
 
         if not _has_chat_template(tokenizer):
             raise ValueError("ChatDataset requires a tokenizer with chat template support.")
@@ -155,7 +193,7 @@ def __init__(
         self.truncation = truncation
         self.start_of_turn_token = start_of_turn_token
 
-        self.dataset = _load_openai_messages(path_or_dataset_id, split=split, name=name)
+        self.dataset = _load_openai_messages(path_or_dataset_id, split=split, name=name, shuffle_seed=shuffle_seed)
 
         # Ensure pad token presence for downstream padding
         eos_token_id = getattr(self.tokenizer, "eos_token_id", 0)
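The split-slice handling added to _load_openai_messages can be exercised in isolation. The sketch below reuses the same regex as the diff; parse_split is a hypothetical helper name for illustration, not part of the PR:

```python
import re

# Same pattern as _SPLIT_SLICE_RE in the diff: "train[1024:]" ->
# base split "train" plus a slice to apply after shuffling.
SPLIT_SLICE_RE = re.compile(r"^(\w+)\[(\d*):(\d*)\]$")

def parse_split(split):
    """Return (base_split, slice_or_None) for an HF-style split string."""
    m = SPLIT_SLICE_RE.match(split)
    if not m:
        return split, None
    start = int(m.group(2)) if m.group(2) else None
    end = int(m.group(3)) if m.group(3) else None
    return m.group(1), slice(start, end)

print(parse_split("train"))        # ('train', None)
print(parse_split("train[:128]"))  # ('train', slice(None, 128, None))
```

Because the shuffle runs on the base split before the slice is applied, "train[:128]" and "train[128:]" with the same shuffle_seed draw disjoint samples from one consistent random ordering, which is what makes the train/val split in the YAML recipe above sound.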