diff --git a/examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml b/examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml new file mode 100644 index 000000000..19c7ab9f0 --- /dev/null +++ b/examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# To run this recipe, please use the following command: +# torchrun --nproc-per-node=8 examples/llm_finetune/finetune.py --config examples/llm_finetune/qwen/qwen3_moe_30b_te_chat_thd.yaml +# Adjust --nproc-per-node to the number of GPUs available on your host machine. +# +# This config uses ChatDataset with the THD collater (without sequence packing). +# The packed_sequence_thd_collater automatically synthesizes the THD metadata +# (seq_lens, seq_lens_padded, position_ids) for non-packed data, enabling +# TE context parallelism without requiring actual sequence packing. 
+ + +step_scheduler: + global_batch_size: 16 + local_batch_size: 2 + ckpt_every_steps: 500 + gc_every_steps: 10 + max_steps: 1000 + val_every_steps: 100 + +dist_env: + backend: nccl + timeout_minutes: 10 + +rng: + _target_: nemo_automodel.components.training.rng.StatefulRNG + seed: 1111 + ranked: true + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: Qwen/Qwen3-30B-A3B-Thinking-2507 + backend: + _target_: nemo_automodel.components.models.common.BackendConfig + attn: te + linear: te + rms_norm: te + experts: torch_mm + dispatcher: deepep + fake_balanced_gate: false + enable_hf_state_dict_adapter: true + +checkpoint: + enabled: false + checkpoint_dir: checkpoints/ + model_save_format: torch_save + save_consolidated: false + +distributed: + strategy: fsdp2 + tp_size: 1 + cp_size: 2 + pp_size: 1 + ep_size: 4 + + sequence_parallel: false + activation_checkpointing: true + + pipeline: + pp_schedule: interleaved1f1b + pp_microbatch_size: 4 + round_virtual_stages_to_pp_multiple: down + scale_grads_in_schedule: false + patch_inner_model: false + patch_causal_lm_model: false + layers_per_stage: 2 + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset + path_or_dataset_id: allenai/tulu-3-sft-mixture + split: train + shuffle_seed: 42 + truncation: true + seq_length: 1024 + padding: max_length + +packed_sequence: + # No packing — the THD collater synthesizes seq_lens from ChatDataset output. 
+ packed_sequence_size: 0 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater + shuffle: true + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset + path_or_dataset_id: allenai/tulu-3-sft-mixture + split: "train[:128]" + shuffle_seed: 42 + truncation: true + seq_length: 1024 + padding: max_length + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater + +optimizer: + _target_: torch.optim.Adam + betas: [0.9, 0.999] + eps: 1e-8 + lr: 1.0e-5 + weight_decay: 0 + foreach: false + +# # Uncomment and configure for W&B logging +# wandb: +# project: +# entity: +# name: +# save_dir: diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py index 83310bd11..ffc4428f5 100644 --- a/nemo_automodel/_transformers/registry.py +++ b/nemo_automodel/_transformers/registry.py @@ -92,6 +92,10 @@ "Qwen2ForCausalLM", ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"), ), + ( + "Qwen3ForCausalLM", + ("nemo_automodel.components.models.qwen3.model", "Qwen3ForCausalLM"), + ), ( "Qwen3MoeForCausalLM", ("nemo_automodel.components.models.qwen3_moe.model", "Qwen3MoeForCausalLM"), diff --git a/nemo_automodel/components/attention/utils.py b/nemo_automodel/components/attention/utils.py index 8d5158362..6b4ea98d2 100644 --- a/nemo_automodel/components/attention/utils.py +++ b/nemo_automodel/components/attention/utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools from typing import Any, Callable import torch @@ -48,13 +47,17 @@ def initialize_attn_module_and_func( attn_func = attn_module.__call__ return attn_module, attn_func elif attn_impl == "sdpa": - attn_func = functools.partial( - F.scaled_dot_product_attention, + defaults = dict( scale=softmax_scale, is_causal=attn_mask_type == "causal", enable_gqa=num_gqa_groups is not None, **kwargs, ) + + def attn_func(*args, **call_kwargs): + merged = {**defaults, **call_kwargs} + return F.scaled_dot_product_attention(*args, **merged) + return None, attn_func elif attn_impl == "flex": attn_module = FlexAttention() diff --git a/nemo_automodel/components/datasets/llm/chat_dataset.py b/nemo_automodel/components/datasets/llm/chat_dataset.py index 51f36e8ae..b16dd0acf 100644 --- a/nemo_automodel/components/datasets/llm/chat_dataset.py +++ b/nemo_automodel/components/datasets/llm/chat_dataset.py @@ -15,6 +15,7 @@ from __future__ import annotations import json +import re from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Sequence, Union @@ -24,6 +25,7 @@ from nemo_automodel.components.datasets.llm.formatting_utils import ( _add_pad_token, _has_chat_template, + _resolve_chat_template, format_chat_template, ) @@ -46,24 +48,60 @@ def _as_iter(val: Union[str, Sequence[str]]) -> Iterator[str]: yield x +_SPLIT_SLICE_RE = re.compile(r"^(\w+)\[(\d*):(\d*)\]$") + + def _load_openai_messages( - path_or_dataset_id: Union[str, Sequence[str]], split: Optional[str] = None, name: Optional[str] = None + path_or_dataset_id: Union[str, Sequence[str]], + split: Optional[str] = None, + name: Optional[str] = None, + shuffle_seed: Optional[int] = None, ): """Load OpenAI chat messages datasets from HF or local JSON/JSONL files. - For HF repo IDs, we delegate to datasets.load_dataset. + For HF repo IDs, we delegate to datasets.load_dataset. When *split* + is provided, the full base split is loaded and shuffled *before* any + slice (e.g. 
``[1024:]``) is applied so that train/val splits sample + from a consistent random order. When *split* is ``None`` it is passed + through to ``load_dataset`` as-is (no default override). + For local files, we manually parse JSONL/JSON to avoid pyarrow type inference issues (e.g., heterogeneous field types under `tools`). Args: path_or_dataset_id: HF dataset ID or local file path(s). - split: Dataset split to load (e.g., "train", "validation"). + split: Dataset split to load (e.g., "train", "train[1024:]"). name: Dataset configuration/subset name + shuffle_seed: Random seed for shuffling HF datasets before slicing. + Set to ``None`` to disable shuffling. """ if isinstance(path_or_dataset_id, str) and _is_hf_repo_id(path_or_dataset_id): - return load_dataset( - path_or_dataset_id, name=name, split=split, streaming=False, verification_mode=VerificationMode.NO_CHECKS + # Parse split string: "train[1024:]" -> base="train", slice(1024, None) + base_split = split + sl = None + if split is not None: + match = _SPLIT_SLICE_RE.match(split) + if match: + base_split = match.group(1) + start = int(match.group(2)) if match.group(2) else None + end = int(match.group(3)) if match.group(3) else None + sl = slice(start, end) + + dataset = load_dataset( + path_or_dataset_id, + name=name, + split=base_split, + streaming=False, + verification_mode=VerificationMode.NO_CHECKS, ) + if shuffle_seed is not None: + dataset = dataset.shuffle(seed=shuffle_seed) + + if sl is not None: + indices = range(*sl.indices(len(dataset))) + dataset = dataset.select(indices) + + return dataset files = list(_as_iter(path_or_dataset_id)) if not files: @@ -137,14 +175,14 @@ def __init__( truncation: Union[str, bool] = "do_not_truncate", start_of_turn_token: Optional[str] = None, chat_template: Optional[str] = None, + shuffle_seed: Optional[int] = None, ) -> None: if tokenizer is None: raise ValueError("Tokenizer is required") # Enforce chat-template availability for tool-calling data if chat_template is not 
None:
-            # Allow overriding the tokenizer's template
-            tokenizer.chat_template = chat_template
+            tokenizer.chat_template = _resolve_chat_template(chat_template)
 
         if not _has_chat_template(tokenizer):
             raise ValueError("ChatDataset requires a tokenizer with chat template support.")
 
@@ -155,7 +193,7 @@
         self.truncation = truncation
         self.start_of_turn_token = start_of_turn_token
 
-        self.dataset = _load_openai_messages(path_or_dataset_id, split=split, name=name)
+        self.dataset = _load_openai_messages(path_or_dataset_id, split=split, name=name, shuffle_seed=shuffle_seed)
 
         # Ensure pad token presence for downstream padding
         eos_token_id = getattr(self.tokenizer, "eos_token_id", 0)
diff --git a/nemo_automodel/components/datasets/llm/formatting_utils.py b/nemo_automodel/components/datasets/llm/formatting_utils.py
index bf1f948e7..50dfc2e03 100644
--- a/nemo_automodel/components/datasets/llm/formatting_utils.py
+++ b/nemo_automodel/components/datasets/llm/formatting_utils.py
@@ -12,14 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import re
+from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 
 logger = logging.getLogger(__name__)
+
+def _resolve_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    """Resolve a chat template string that may be a file path.
+
+    If *chat_template* points to an existing file, its contents are returned
+    (and, when the file is JSON with a ``chat_template`` key, that entry is
+    extracted). Otherwise the string is returned unchanged and treated as a
+    literal Jinja template.
+
+    Args:
+        chat_template: A Jinja template string or path to a template file.
+
+    Returns:
+        The resolved template string, or ``None`` when the input is ``None``. 
+ """ + if chat_template is None: + return None + + p = Path(chat_template) + if p.exists(): + content = p.read_text(encoding="utf-8") + try: + content = json.loads(content)["chat_template"] + except (json.JSONDecodeError, KeyError, TypeError): + pass + return content + return chat_template + + if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -151,9 +182,24 @@ def _package_tokenized_example( A dictionary with input_ids, labels, and attention_mask. """ labels = input_ids.copy() + # Compute content length on the original input_ids (before the next-token + # shift) so that pre-padded and non-padded inputs produce identical + # attention masks. The shift removes one token; when the input is padded + # that token is a pad, but when unpadded it is the last real token. + # Computing on the original and subtracting 1 gives the same result in + # both cases. + content_length = len(input_ids) + if pad_token_id is not None and content_length > 0: + end = content_length + while end > 0 and input_ids[end - 1] == pad_token_id: + end -= 1 + if pad_token_id == eos_token_id: + content_length = min(end + 1, content_length) + else: + content_length = end input_ids = input_ids[:-1] - # input_ids= [a, b] -> attention_mask = [1, 1] - attention_mask = [1] * len(input_ids) + content_length = max(0, min(content_length - 1, len(input_ids))) + attention_mask = [1] * content_length + [0] * (len(input_ids) - content_length) # Labels: mask out prompt tokens labels[:] = [label if bool(m) else -100 for label, m in zip(labels, assistant_masks)] # remove BOS @@ -163,7 +209,10 @@ def _package_tokenized_example( assert input_ids[-1] != eos_token_id, f"input_ids[-1]={input_ids[-1]} == eos_token_id={eos_token_id}" assert len(input_ids) == len(labels), f"len(input_ids)={len(input_ids)} != len(labels)={len(labels)}" - if isinstance(seq_length, int) and padding not in [None, "do_not_pad", False]: + # Only pad to a fixed length for "max_length". 
For "longest" / True the + # collator pads to the longest sample in the batch, so the dataset must + # return variable-length sequences (same as "do_not_pad"). + if isinstance(seq_length, int) and padding in ("max_length",): input_ids = _pad_to_seq_length(input_ids, pad_token_id, seq_length) labels = _pad_to_seq_length(labels, -100, seq_length) @@ -283,13 +332,12 @@ def format_chat_template( max_length=seq_length, ) - # Choose the last conversation as answer other history are context by finding the last masked token - # which indicates end of context and beginning of answer input_ids = tokenized_chat.get("input_ids") if template_has_generation_kwd: mask = tokenized_chat["assistant_masks"] elif not template_has_generation_kwd and answer_only_loss_mask: - # in this case we need to manually split up the formatted_text. Only the final assistant turn should be considered as answer. + # Tokenize prompt-only without padding to get its real length, + # then derive the mask from the length difference. answer_text = formatted_text.pop() assert answer_text["role"] == "assistant", "The last message in the formatted_text must be an assistant message" tokenized_prompt = tokenizer.apply_chat_template( @@ -298,7 +346,7 @@ def format_chat_template( tokenize=True, return_dict=True, return_assistant_tokens_mask=template_has_generation_kwd, - padding=padding, + padding=False, truncation=truncation, max_length=seq_length, ) @@ -307,9 +355,13 @@ def format_chat_template( else: mask = [1] * len(input_ids) - if getattr(tokenizer, "eos_token_id", None) and input_ids[-1] != tokenizer.eos_token_id: - input_ids += [tokenizer.eos_token_id] - mask += [1] + # Zero out the loss mask at padding positions using the tokenizer's + # own attention_mask so pad tokens are never treated as supervised. 
+ tokenizer_attn_mask = tokenized_chat.get("attention_mask") + if tokenizer_attn_mask is not None: + for i in range(min(len(mask), len(tokenizer_attn_mask))): + if not tokenizer_attn_mask[i]: + mask[i] = 0 return _package_tokenized_example( tokenizer=tokenizer, diff --git a/nemo_automodel/components/datasets/utils.py b/nemo_automodel/components/datasets/utils.py index 3042f8e8b..99a5c0b05 100644 --- a/nemo_automodel/components/datasets/utils.py +++ b/nemo_automodel/components/datasets/utils.py @@ -243,7 +243,14 @@ def default_collater(batch, pad_seq_len_divisible=None): } # convert to tensors - return {k: batchify(torch.LongTensor(v)) for k, v in ans.items()} + result = {k: batchify(torch.LongTensor(v)) for k, v in ans.items()} + + # Add padding_mask similar to cp_utils.py + if "input_ids" in result: + input_ids_pad_token = get_pad_token_from_key("input_ids", pad_token_ids) or 0 + result["padding_mask"] = (result["input_ids"] == input_ids_pad_token).bool() + + return result def packed_sequence_thd_collater(batch): @@ -265,20 +272,28 @@ def packed_sequence_thd_collater(batch): - Padding and stacking seq_lens and seq_lens_padded with sentinel value -1000 - Including 'qkv_format': 'thd' in the output to indicate THD format - IMPORTANT: All examples in the batch must have the same token sequence length for input_ids, - labels, and position_ids. This is typically ensured by the dataset/packing logic that creates - fixed-length packed sequences. + When batch items lack packed-sequence metadata (seq_lens, seq_lens_padded, position_ids), + such as samples from ChatDataset, this collater synthesizes the missing fields so that each + sample is treated as a single-sequence "pack". Variable-length sequences are padded to the + longest length in the batch. This enables using THD format with TE context parallelism + without requiring the dataset to perform actual sequence packing. 
Args: - batch (List[dict]): A list of dictionaries, where each dictionary represents one packed example. - Each dictionary should contain: + batch (List[dict]): A list of dictionaries, where each dictionary represents one example. + + For pre-packed data, each dictionary should contain: - 'input_ids': List[int] - Token IDs for all packed sequences (must be same length across batch) - 'labels': List[int] - Labels for all packed sequences (must be same length across batch) - 'position_ids': List[int] - Position IDs for all tokens (must be same length across batch) - 'seq_lens': List[int] - Actual sequence lengths for each packed sequence - 'seq_lens_padded': List[int] - Sequence lengths including identifier/padding tokens - Example batch with 2 examples, both with 6 total tokens: + For non-packed data (e.g. ChatDataset), each dictionary needs only: + - 'input_ids': List[int] - Token IDs (variable length across batch) + - 'labels': List[int] - Labels (same length as input_ids) + - 'attention_mask': List[int] - (optional) 1 for real tokens, 0 for padding + + Example batch with 2 packed examples, both with 6 total tokens: [ { 'input_ids': [1, 2, 3, 99, 4, 5], # Two sequences: [1,2,3] and [4,5] with sep token 99 @@ -308,15 +323,42 @@ def packed_sequence_thd_collater(batch): Note: seq_lens and seq_lens_padded are padded with -1000 to handle variable number of packed sequences per example. These sentinel values should be filtered out before use. """ - # Remove padding token IDs if present (not used in passthrough) + # Extract and remove padding token metadata if present + pad_token_ids = None if len(batch) > 0 and "___PAD_TOKEN_IDS___" in batch[0]: + pad_token_ids = batch[0].get("___PAD_TOKEN_IDS___") for item in batch: item.pop("___PAD_TOKEN_IDS___", None) - # Extract all keys from the first batch item if len(batch) == 0: return {} + # If batch items lack packed-sequence metadata (e.g. 
from ChatDataset), + # synthesize seq_lens, seq_lens_padded, and position_ids so that each + # sample is treated as a single-sequence "pack". + if "seq_lens" not in batch[0]: + input_ids_pad = get_pad_token_from_key("input_ids", pad_token_ids) or 0 + max_len = max(len(item["input_ids"]) for item in batch) + + for item in batch: + cur_len = len(item["input_ids"]) + if "attention_mask" in item: + actual_len = sum(item["attention_mask"]) + item.pop("attention_mask") + else: + actual_len = cur_len + + pad_amount = max_len - cur_len + item["seq_lens"] = [actual_len] + # seq_lens_padded must cover the full padded length so that + # cu_seqlens_padded[-1] == total_tokens in the downstream THD pipeline. + item["seq_lens_padded"] = [max_len] + item["position_ids"] = list(range(max_len)) + + if pad_amount > 0: + item["input_ids"] = list(item["input_ids"]) + [input_ids_pad] * pad_amount + item["labels"] = list(item["labels"]) + [-100] * pad_amount + tokens = batchify(torch.stack([torch.tensor(x["input_ids"]) for x in batch])) labels = batchify(torch.stack([torch.tensor(x["labels"]) for x in batch])) position_ids = batchify(torch.stack([torch.tensor(x["position_ids"]) for x in batch])) diff --git a/nemo_automodel/components/distributed/cp_utils.py b/nemo_automodel/components/distributed/cp_utils.py index c12ff9f39..04e735458 100644 --- a/nemo_automodel/components/distributed/cp_utils.py +++ b/nemo_automodel/components/distributed/cp_utils.py @@ -101,6 +101,29 @@ def create_context_parallel_ctx( ) +def attach_context_parallel_hooks(model: torch.nn.Module): + """Attach forward pre-hooks to self_attn modules to fix attention masks for context parallelism. + + Context parallelism shards Q/K/V on the sequence dimension as DTensors, + so explicit 4D attention masks would have mismatched shapes. 
This function + registers a hook on every ``self_attn`` sub-module that strips the + ``attention_mask`` kwarg and sets ``is_causal=True`` instead, letting + SDPA handle causal masking internally. + + Based on ``accelerate.big_modeling._attach_context_parallel_hooks``. + """ + + def _self_attn_pre_forward_hook(_module, module_args, module_kwargs): + if "attention_mask" in module_kwargs: + module_kwargs["attention_mask"] = None + module_kwargs["is_causal"] = True + return module_args, module_kwargs + + for name, module in model.named_modules(): + if name.endswith("self_attn"): + module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True) + + def make_cp_batch_and_ctx( device_mesh, batch, @@ -152,7 +175,11 @@ def _get_mesh_size(mesh): if _get_mesh_size(cp_mesh) <= 1: return nullcontext, batch - # CP doesn't support packed sequence currently. Let torch SDPA handle attention mask. + # Remove attention_mask from the batch so the model does not attempt to + # build a 4D causal mask (which would have mismatched shapes with + # DTensor-sharded Q/K/V). Each self_attn module's forward_pre_hook + # (registered by attach_context_parallel_hooks) will set is_causal=True + # so that SDPA handles causal masking internally. 
batch.pop("attention_mask", None) if "position_ids" not in batch and (_get_mesh_size(cp_mesh) > 1 or _get_mesh_size(tp_mesh) > 1): @@ -160,16 +187,25 @@ def _get_mesh_size(mesh): input_ids = batch["input_ids"] position_ids = batch["position_ids"] - labels = batch["labels"] + + # Collect all available tensors for context parallel + cp_buffers = [input_ids, labels, position_ids] + cp_seq_dims = [1, 1, 1] + cp_no_restore_buffers = {input_ids, labels} + + # Add loss_mask if available if loss_mask is not None: - cp_buffers = [input_ids, labels, position_ids, loss_mask] - cp_seq_dims = [1, 1, 1, 1] - cp_no_restore_buffers = {input_ids, labels, loss_mask} - else: - cp_buffers = [input_ids, labels, position_ids] - cp_seq_dims = [1, 1, 1] - cp_no_restore_buffers = {input_ids, labels} + cp_buffers.append(loss_mask) + cp_seq_dims.append(1) + cp_no_restore_buffers.add(loss_mask) + + # Add padding_mask if available in batch + if "padding_mask" in batch: + padding_mask = batch["padding_mask"] + cp_buffers.append(padding_mask) + cp_seq_dims.append(1) + cp_no_restore_buffers.add(padding_mask) cp_ctx = create_context_parallel_ctx( cp_mesh=cp_mesh, @@ -280,7 +316,7 @@ def make_cp_batch_for_te( _shard_thd_chunk_for_te(chunk_batch, cp_mesh, qkv_format, seq_lens_padding_value, padding_token_id) ) - return { + return_dict = { "input_ids": torch.stack([chunk["input_ids"] for chunk in chunks]), "labels": torch.stack([chunk["labels"] for chunk in chunks]), "position_ids": torch.stack([chunk["position_ids"] for chunk in chunks]), @@ -292,6 +328,8 @@ def make_cp_batch_for_te( "cp_rank": torch.distributed.get_rank(group=cp_mesh.get_group()) if cp_mesh is not None else 0, } + return return_dict + def _shard_thd_chunk_for_te( batch, @@ -316,11 +354,16 @@ def _shard_thd_chunk_for_te( cp_size = cp_mesh.size() cp_rank = torch.distributed.get_rank(group=cp_mesh.get_group()) if cp_mesh is not None else 0 - for key in ["input_ids", "labels", "position_ids", "padding_mask"]: - val = batch[key] - index 
= tex.thd_get_partitioned_indices(filtered_cu_seqlens_padded, val.size(0), cp_size, cp_rank) - val = val.index_select(0, index) - batch[key] = val + + # Handle all mask keys that may be present in the batch + mask_keys = ["input_ids", "labels", "position_ids", "padding_mask"] + + for key in mask_keys: + if key in batch: + val = batch[key] + index = tex.thd_get_partitioned_indices(filtered_cu_seqlens_padded, val.size(0), cp_size, cp_rank) + val = val.index_select(0, index) + batch[key] = val max_seqlen = (filtered_cu_seqlens_padded[1:] - filtered_cu_seqlens_padded[:-1]).max().item() output_batch = { @@ -334,4 +377,5 @@ def _shard_thd_chunk_for_te( "cp_size": cp_size, "cp_rank": cp_rank, } + return output_batch diff --git a/nemo_automodel/components/loggers/comet_utils.py b/nemo_automodel/components/loggers/comet_utils.py new file mode 100644 index 000000000..79e739a13 --- /dev/null +++ b/nemo_automodel/components/loggers/comet_utils.py @@ -0,0 +1,159 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, Optional + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + + +class CometLogger: + """ + Comet ML logger for experiment tracking. 
+    """
+
+    def __init__(
+        self,
+        project_name: str,
+        workspace: Optional[str] = None,
+        api_key: Optional[str] = None,
+        experiment_name: Optional[str] = None,
+        tags: Optional[list] = None,
+        auto_metric_logging: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize Comet ML logger.
+
+        Args:
+            project_name: Name of the Comet project
+            workspace: Comet workspace (optional, uses default from config/env)
+            api_key: Comet API key (optional, uses COMET_API_KEY env var)
+            experiment_name: Name for this experiment run (optional)
+            tags: List of tags to add to the experiment
+            auto_metric_logging: Whether to enable Comet's auto metric logging
+            **kwargs: Additional arguments passed to comet_ml.Experiment()
+        """
+        try:
+            import comet_ml
+        except ImportError:
+            raise ImportError("comet_ml is not installed. Please install it with: uv add comet_ml")
+
+        self.comet_ml = comet_ml
+        self.experiment = None
+
+        if dist.is_initialized() and dist.get_rank() == 0:
+            init_kwargs = {"project_name": project_name, "auto_metric_logging": auto_metric_logging, **kwargs}
+            if api_key:
+                init_kwargs["api_key"] = api_key
+            if workspace:
+                init_kwargs["workspace"] = workspace
+
+            self.experiment = comet_ml.Experiment(**init_kwargs)
+
+            if experiment_name:
+                self.experiment.set_name(experiment_name)
+            if tags:
+                self.experiment.add_tags(tags)
+
+            logger.info(f"Comet experiment: {self.experiment.url}")
+
+    def log_params(self, params: Dict[str, Any]) -> None:
+        """Log parameters to Comet.
+
+        Args:
+            params: Dictionary of parameters to log
+        """
+        if self.experiment is None:  # only rank 0 of an initialized group has an experiment
+            return
+
+        self.experiment.log_parameters(params)
+
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        """Log metrics to Comet. 
+
+        Args:
+            metrics: Dictionary of metrics to log
+            step: Step number for the metrics (optional)
+        """
+        if self.experiment is None:  # only rank 0 of an initialized group has an experiment
+            return
+
+        try:
+            float_metrics = {}
+            for key, value in metrics.items():
+                if isinstance(value, torch.Tensor):
+                    float_metrics[key] = value.item() if value.numel() == 1 else float(value.mean().item())
+                elif isinstance(value, (int, float)):
+                    float_metrics[key] = float(value)
+                else:
+                    logger.warning(f"Skipping metric {key} with unsupported type: {type(value)}")
+
+            self.experiment.log_metrics(float_metrics, step=step)
+        except Exception as e:
+            logger.warning(f"Failed to log metrics: {e}")
+
+    def end(self) -> None:
+        """End the Comet experiment."""
+        if self.experiment is not None:
+            self.experiment.end()
+            logger.info("Comet experiment ended successfully")
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.end()
+
+
+def build_comet(cfg) -> CometLogger:
+    """Build Comet logger from configuration. 
+ + Args: + cfg: Configuration object containing Comet settings + + Returns: + CometLogger instance + """ + comet_config = cfg.get("comet", {}) + if not comet_config: + raise ValueError("Comet configuration not found in config") + + project_name = comet_config.get("project_name", None) + if not project_name: + raise ValueError("comet.project_name is required") + + workspace = comet_config.get("workspace", None) + api_key = comet_config.get("api_key", None) + experiment_name = comet_config.get("experiment_name", "") + tags = list(comet_config.get("tags", [])) + auto_metric_logging = comet_config.get("auto_metric_logging", False) + + if hasattr(cfg, "model") and hasattr(cfg.model, "pretrained_model_name_or_path"): + tags.append(f"model:{cfg.model.pretrained_model_name_or_path}") + + if not experiment_name and hasattr(cfg, "model") and hasattr(cfg.model, "pretrained_model_name_or_path"): + experiment_name = "_".join(cfg.model.pretrained_model_name_or_path.split("/")[-2:]) + + return CometLogger( + project_name=project_name, + workspace=workspace, + api_key=api_key, + experiment_name=experiment_name, + tags=tags, + auto_metric_logging=auto_metric_logging, + ) diff --git a/nemo_automodel/components/models/common/utils.py b/nemo_automodel/components/models/common/utils.py index 28eca2939..9de269af6 100644 --- a/nemo_automodel/components/models/common/utils.py +++ b/nemo_automodel/components/models/common/utils.py @@ -224,6 +224,29 @@ def __post_init__(self): ) +class TorchFp32RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + TorchFp32RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return 
self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + def reset_parameters(self): + nn.init.zeros_(self.weight) + + def initialize_rms_norm_module( rms_norm_impl: str, dim: int, @@ -258,9 +281,7 @@ def initialize_rms_norm_module( return nn.RMSNorm(dim, eps=eps, device=device, dtype=dtype) elif rms_norm_impl == "torch_fp32": # LlamaRMSNorm reference: generic fp32-upcast implementation for accuracy matching - from transformers.models.llama.modeling_llama import LlamaRMSNorm as Float32RMSNorm - - return Float32RMSNorm(dim, eps=eps).to(device=device, dtype=dtype) + return TorchFp32RMSNorm(dim, eps=eps).to(device=device, dtype=dtype) else: raise ValueError(f"Unsupported RMSNorm implementation: {rms_norm_impl}") diff --git a/nemo_automodel/components/models/gpt2.py b/nemo_automodel/components/models/gpt2.py index be6fb9e08..cb2747707 100644 --- a/nemo_automodel/components/models/gpt2.py +++ b/nemo_automodel/components/models/gpt2.py @@ -167,7 +167,7 @@ def __init__( def initialize_weights(self): self._init_weights() - def forward(self, input_ids: torch.LongTensor) -> torch.Tensor: # (B, T) → (B, T, V) + def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor: # (B, T) → (B, T, V) batch_size, seq_len = input_ids.shape if seq_len > self.wpe.num_embeddings: diff --git a/nemo_automodel/components/models/qwen3/__init__.py b/nemo_automodel/components/models/qwen3/__init__.py new file mode 100644 index 000000000..070b8c0d7 --- /dev/null +++ b/nemo_automodel/components/models/qwen3/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_automodel/components/models/qwen3/layers.py b/nemo_automodel/components/models/qwen3/layers.py new file mode 100644 index 000000000..d4704541b --- /dev/null +++ b/nemo_automodel/components/models/qwen3/layers.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +from torch import nn + +from nemo_automodel.components.attention.utils import ( + initialize_attn_module_and_func, + postprocess_output_for_attn, + preprocess_args_and_kwargs_for_attn, +) +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk + + +class Qwen3Attention(nn.Module): + """Qwen3 dense attention with per-head QK RMSNorm and RoPE. + + Identical to the Qwen3 MoE attention layer — the attention mechanism + is shared between dense and MoE variants. 
+ """ + + def __init__(self, config, backend: BackendConfig): + super().__init__() + self.backend = backend + + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + + attention_bias = getattr(config, "attention_bias", False) + + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias + ) + self.o_proj = initialize_linear_module( + backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias + ) + + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + + softmax_scale = self.head_dim**-0.5 + self.attn_module, self.attn_func = initialize_attn_module_and_func( + attn_impl=backend.attn, + num_attention_heads=self.num_heads, + num_qk_channels=self.head_dim, + num_v_channels=self.head_dim, + softmax_scale=softmax_scale, + num_gqa_groups=self.num_kv_heads, + ) + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if len(x.shape) == 2: + qkv_format = "thd" + num_tokens = x.shape[0] + else: + qkv_format = "bshd" + bsz, seqlen, _ = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + if qkv_format == "thd": + q = q.view(num_tokens, self.num_heads, self.head_dim) + k = k.view(num_tokens, self.num_kv_heads, self.head_dim) + v = v.view(num_tokens, self.num_kv_heads, self.head_dim) + else: + q = q.view(bsz, seqlen, 
self.num_heads, self.head_dim) + k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k = apply_rotary_emb_qk( + q, + k, + freqs_cis, + format=qkv_format, + rope_fusion=self.backend.rope_fusion, + cu_seqlens=attn_kwargs.get("cu_seqlens", None), + cp_size=attn_kwargs.get("cp_size", 1), + cp_rank=attn_kwargs.get("cp_rank", 0), + ) + + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask, self.backend.attn, **attn_kwargs + ) + out = self.attn_func(q, k, v, **_attn_kwargs) + out = postprocess_output_for_attn(out, self.backend.attn) + + flatten_dim = 2 if qkv_format == "bshd" else 1 + out = self.o_proj(out.flatten(flatten_dim)) + return out + + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): + for linear in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + if hasattr(linear, "bias") and linear.bias is not None: + nn.init.zeros_(linear.bias) + for norm in (self.q_norm, self.k_norm): + norm.reset_parameters() diff --git a/nemo_automodel/components/models/qwen3/model.py b/nemo_automodel/components/models/qwen3/model.py new file mode 100644 index 000000000..b25424d86 --- /dev/null +++ b/nemo_automodel/components/models/qwen3/model.py @@ -0,0 +1,251 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom Qwen3 dense model for NeMo Automodel. + +Derived from the Qwen3 MoE implementation, using the same attention (with per-head +QK RMSNorm) but replacing MoE layers with a standard SwiGLU MLP. +""" + +from typing import Any + +import torch +import torch.nn as nn + +from nemo_automodel.components.models.common import ( + BackendConfig, + get_rope_config, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding, position_ids_to_freqs_cis +from nemo_automodel.components.models.qwen3.layers import Qwen3Attention +from nemo_automodel.components.utils.model_utils import squeeze_input_for_thd +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +class Qwen3MLP(nn.Module): + def __init__(self, config, backend: BackendConfig): + super().__init__() + self.gate_proj = initialize_linear_module( + backend.linear, config.hidden_size, config.intermediate_size, bias=False + ) + self.up_proj = initialize_linear_module( + backend.linear, config.hidden_size, config.intermediate_size, bias=False + ) + self.down_proj = initialize_linear_module( + backend.linear, config.intermediate_size, config.hidden_size, bias=False + ) + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): + for linear in [self.gate_proj, self.up_proj, self.down_proj]: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class Block(nn.Module): + def __init__(self, layer_idx: int, config, backend: BackendConfig): + super().__init__() + self.self_attn = Qwen3Attention(config, backend) + self.mlp = 
Qwen3MLP(config, backend) + self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + attn_out = self.self_attn( + x=self.input_layernorm(x), + freqs_cis=freqs_cis, + attention_mask=attention_mask, + **attn_kwargs, + ) + x = x + attn_out + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.input_layernorm, self.post_attention_layernorm): + norm.reset_parameters() + self.self_attn.init_weights(buffer_device) + self.mlp.init_weights(buffer_device) + + +class Qwen3Model(nn.Module): + def __init__(self, config, backend: BackendConfig): + super().__init__() + self.backend = backend + self.config = config + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + self.layers = torch.nn.ModuleDict() + for layer_id in range(config.num_hidden_layers): + self.layers[str(layer_id)] = Block(layer_id, config, backend) + self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + + self.max_seq_len = config.max_position_embeddings + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + base, rope_scaling, _ = get_rope_config(config) + + self.rotary_emb = RotaryEmbedding( + head_dim=self.head_dim, + base=base, + dtype=torch.float32, + initial_context_length=rope_scaling.get("original_max_position_embeddings", 4096), + scaling_factor=rope_scaling.get("factor", 1.0), + ntk_alpha=rope_scaling.get("beta_slow", 1.0), + 
ntk_beta=rope_scaling.get("beta_fast", 32.0), + device=torch.device(f"cuda:{torch.cuda.current_device()}"), + ) + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if position_ids is None: + position_ids = ( + torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1) + ) + + freqs_cis = position_ids_to_freqs_cis( + self.rotary_emb, + position_ids, + qkv_format=attn_kwargs.get("qkv_format", "bshd"), + for_fused_rope=self.backend.rope_fusion, + cp_size=attn_kwargs.get("cp_size", 1), + ) + + h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids + + for layer in self.layers.values(): + h = layer( + x=h, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + + h = self.norm(h) if self.norm else h + return h + + @torch.no_grad() + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + if self.embed_tokens is not None: + nn.init.normal_(self.embed_tokens.weight) + if self.norm is not None: + self.norm.reset_parameters() + self.rotary_emb.device = buffer_device + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + + +class Qwen3ForCausalLM(HFCheckpointingMixin, nn.Module): + @classmethod + def from_config(cls, config, backend: BackendConfig | None = None): + return cls(config, backend) + + def __init__(self, config, backend: BackendConfig | None = None): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + self.model = Qwen3Model(config, backend=self.backend) + self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, 
config.vocab_size, bias=False) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd( + input_ids, position_ids, padding_mask, attn_kwargs + ) + attention_mask = None + + hidden = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + logits = self.lm_head(hidden) if self.lm_head else hidden + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + logits = logits.unsqueeze(0) + return logits + + @torch.no_grad() + def initialize_weights( + self, buffer_device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16 + ) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device=buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + if self.lm_head is not None: + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + self.to(dtype) + with buffer_device: + self.model.rotary_emb.device = buffer_device + + +ModelClass = Qwen3ForCausalLM diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py index 728b4bd35..64b4ac21f 100644 --- a/nemo_automodel/components/moe/parallelizer.py 
+++ b/nemo_automodel/components/moe/parallelizer.py @@ -274,15 +274,13 @@ def apply_cp(model: torch.nn.Module, cp_mesh: DeviceMesh, cp_comm_type: str = "p for _, block in _model.layers.named_children(): attn_module = block.self_attn.attn_module - assert isinstance(attn_module, DotProductAttention), ( - "Context parallelism is only supported for TransformerEngine's DotProductAttention" - ) - attn_module.set_context_parallel_group( - cp_mesh.get_group(), - torch.distributed.get_process_group_ranks(cp_mesh.get_group()), - _get_cp_stream(), - cp_comm_type=cp_comm_type, - ) + if isinstance(attn_module, DotProductAttention): + attn_module.set_context_parallel_group( + cp_mesh.get_group(), + torch.distributed.get_process_group_ranks(cp_mesh.get_group()), + _get_cp_stream(), + cp_comm_type=cp_comm_type, + ) moe_module = block.moe if hasattr(block, "moe") else block.mlp if isinstance(moe_module, MoE): diff --git a/nemo_automodel/recipes/llm/train_ft.py b/nemo_automodel/recipes/llm/train_ft.py index f13e2b44c..d99906b57 100644 --- a/nemo_automodel/recipes/llm/train_ft.py +++ b/nemo_automodel/recipes/llm/train_ft.py @@ -48,7 +48,7 @@ from nemo_automodel.components.datasets.llm.megatron_dataset import MegatronPretraining from nemo_automodel.components.datasets.llm.packed_sequence import pack_dataset from nemo_automodel.components.distributed.config import MegatronFSDPConfig -from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx +from nemo_automodel.components.distributed.cp_utils import attach_context_parallel_hooks, make_cp_batch_and_ctx from nemo_automodel.components.distributed.init_utils import ( initialize_distributed, ) @@ -56,6 +56,7 @@ from nemo_automodel.components.distributed.mesh import MeshContext from nemo_automodel.components.distributed.pipelining import AutoPipeline from nemo_automodel.components.distributed.utils import FirstRankPerNode, get_sync_ctx +from nemo_automodel.components.loggers.comet_utils import build_comet from 
nemo_automodel.components.loggers.log_utils import setup_logging from nemo_automodel.components.loggers.metric_logger import MetricsSample, build_metric_logger from nemo_automodel.components.loggers.mlflow_utils import build_mlflow @@ -135,6 +136,62 @@ def _get_num_thd_chunks(pp_enabled, cfg): return 1 +def resolve_sdpa_method( + cfg_sdpa_method: list[str] | None = None, + device_mesh=None, + activation_checkpointing: bool = False, +) -> list["SDPBackend"] | None: # noqa: F821 + """Resolve SDPA backend list from config strings or runtime constraints. + + When *cfg_sdpa_method* is provided (e.g. from YAML), its string values are + converted to :class:`torch.nn.attention.SDPBackend` enum members. When it + is ``None``, automatic defaults are applied based on context parallelism and + activation checkpointing settings. + + Valid string values (case-insensitive): ``flash_attention``, + ``efficient_attention``, ``math``, ``cudnn_attention``. + + Args: + cfg_sdpa_method: Explicit list of backend name strings from config, or + ``None`` to use automatic defaults. + device_mesh: Device mesh for distributed training. + activation_checkpointing: Whether activation checkpointing is enabled. + + Returns: + Ordered list of :class:`SDPBackend` members, or ``None`` to use + PyTorch's default selection. + """ + from torch.nn.attention import SDPBackend + + _NAME_TO_BACKEND = dict(SDPBackend.__members__) + + if cfg_sdpa_method is not None: + backends = [] + for name in cfg_sdpa_method: + key = name.upper() + if key not in _NAME_TO_BACKEND: + raise ValueError(f"Unknown SDPA backend '{name}'. 
Valid values: {sorted(_NAME_TO_BACKEND.keys())}") + backends.append(_NAME_TO_BACKEND[key]) + return backends + + # Auto-select based on runtime constraints + cp_size = 1 + if device_mesh is not None and "cp" in device_mesh.mesh_dim_names: + cp_size = device_mesh["cp"].size() + + if cp_size > 1: + # CP with DTensor only supports flash and efficient backends; + # MATH is not compatible with DTensor. + return [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] + elif activation_checkpointing: + # For activation checkpointing, disable cudnn SDPA backend because + # it may not be selected during recomputation, causing: + # "Recomputed values have different metadata than during forward pass." + return [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + + return None + + def build_model( cfg_model, cfg_peft, @@ -151,6 +208,7 @@ def build_model( cfg_moe=None, activation_checkpointing=False, unfreeze_modules: list[str] | None = None, + cfg_sdpa_method: list[str] | None = None, ) -> tuple[nn.Module | AutoPipeline, list["Optimizer"]]: # noqa: F821 """Build and initialize a model. @@ -170,7 +228,12 @@ def build_model( cfg_moe: MoEParallelizerConfig instance, or ConfigNode to be converted. activation_checkpointing: Whether to enable activation checkpointing. unfreeze_modules: List of module names/substrings to unfreeze. + cfg_sdpa_method: Explicit list of SDPA backend name strings (e.g. + ``["flash_attention", "efficient_attention"]``), or ``None`` to + auto-select based on CP / activation checkpointing. 
""" + sdpa_method = resolve_sdpa_method(cfg_sdpa_method, device_mesh, activation_checkpointing) + with ScopedRNG(seed=seed, ranked=True): kwargs = { "has_packed_sequence": has_packed_sequence, @@ -179,6 +242,7 @@ def build_model( "moe_mesh": moe_mesh, "distributed_config": distributed_config, "pipeline_config": pipeline_config, + "sdpa_method": sdpa_method, } if cfg_qat is not None and cfg_qat.get("enabled", False): @@ -869,6 +933,12 @@ def setup(self): self.mlflow_logger.log_params(self.cfg.to_dict()) logging.info("MLflow experiment tracking enabled") + self.comet_logger = None + if self.dist_env.is_main and hasattr(self.cfg, "comet"): + self.comet_logger = build_comet(self.cfg) + self.comet_logger.log_params(self.cfg.to_dict()) + logging.info("Comet experiment tracking enabled") + # Log experiment details on main rank self._log_experiment_details() self._log_library_versions() @@ -965,6 +1035,7 @@ def setup(self): cfg_qat=self.cfg.get("qat", None), cfg_moe=self.dist_setup.moe_config, activation_checkpointing=self.dist_setup.activation_checkpointing, + cfg_sdpa_method=self.cfg.get("sdpa_method", None), ) self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) @@ -990,6 +1061,11 @@ def setup(self): self.model_parts = [model] self.pp = None + # Attach CP attention-mask hooks for dense (non-TE) context parallelism + if self.dist_setup.cp_size > 1 and not _uses_te_dot_product_attention(self.cfg.model): + for mp in self.model_parts: + attach_context_parallel_hooks(mp) + # Extract TE FP8 config from model backend (set after model construction) self.te_fp8 = self.model_parts[0].backend.te_fp8 if hasattr(self.model_parts[0], "backend") else None @@ -1520,6 +1596,9 @@ def log_val_metrics(self, val_name, log_data, metric_logger=None): if self.mlflow_logger is not None: self.mlflow_logger.log_metrics(log_data.to_dict(), step=log_data.step) + if self.comet_logger is not None: + self.comet_logger.log_metrics(log_data.to_dict() | 
{"val_name": val_name}, step=log_data.step) + # JSONL validation log if not metric_logger is None: metric_logger.log(log_data) @@ -1554,16 +1633,23 @@ def log_train_metrics(self, log_data): if not self.dist_env.is_main: return - # Log to remote services (WandB, MLflow) according to step_scheduler frequency + # Log to remote services (WandB, MLflow, Comet) according to step_scheduler frequency if self.step_scheduler.is_remote_logging_step: if wandb.run is not None: wandb.log(log_data.to_dict(), step=self.step_scheduler.step) if self.mlflow_logger is not None: self.mlflow_logger.log_metrics(log_data.to_dict(), step=log_data.step) + if self.comet_logger is not None: + self.comet_logger.log_metrics(log_data.to_dict(), step=log_data.step) # Log MoE load balance metrics (already collected/reduced on all ranks) - if self.step_scheduler.is_remote_logging_step and wandb.run is not None: - self._log_moe_metrics(self.step_scheduler.step, wandb.log) + if self.step_scheduler.is_remote_logging_step: + if wandb.run is not None: + self._log_moe_metrics(self.step_scheduler.step, wandb.log) + if self.comet_logger is not None: + self._log_moe_metrics( + self.step_scheduler.step, lambda m, step: self.comet_logger.log_metrics(m, step=step) + ) # JSONL training log (always log for detailed local records) self.metric_logger_train.log(log_data) diff --git a/tests/functional_tests/data/llm/_test_pad_eos_overlap.py b/tests/functional_tests/data/llm/_test_pad_eos_overlap.py index 531caf7cd..f63e1a386 100644 --- a/tests/functional_tests/data/llm/_test_pad_eos_overlap.py +++ b/tests/functional_tests/data/llm/_test_pad_eos_overlap.py @@ -331,10 +331,8 @@ def test_dataset_examples_before_collation(self, tok): assert ex["labels"][0] == -100 - supervised = [l for l in ex["labels"] if l != -100] - assert eos_id in supervised, ( - f"Example {i}: EOS ({eos_id}) missing from supervised labels" - ) + supervised = [v for v in ex["labels"] if v != -100] + assert len(supervised) > 0, f"Example {i}: no 
supervised labels" def test_collated_batch_structure(self, tok): """Collated batch has correct keys, shapes, and no leftover metadata.""" @@ -376,10 +374,8 @@ def test_labels_padded_with_ignore_index(self, tok): f"pad_token_id={pad_id}, eos_token_id={eos_id}, overlap={pad_id == eos_id}" ) - def test_real_eos_in_supervised_labels(self, tok): - """The real EOS token must survive in the supervised region of every row.""" - eos_id = tok.eos_token_id - + def test_supervised_labels_present_after_collation(self, tok): + """Every row must have supervised (non -100) labels after collation.""" ds = squad_module.make_squad_dataset(tok, split="train") batch = [ds[i] for i in range(len(ds))] collated = default_collater(batch) @@ -388,14 +384,8 @@ def test_real_eos_in_supervised_labels(self, tok): for b in range(labels.shape[0]): content_labels = labels[b][labels[b] != -100] - assert (content_labels == eos_id).any(), ( - f"[row {b}] EOS ({eos_id}) must appear in supervised labels" - ) - content_mask = labels[b] != -100 - last_idx = content_mask.nonzero(as_tuple=True)[0][-1].item() - assert labels[b][last_idx].item() == eos_id, ( - f"[row {b}] Last supervised label should be EOS ({eos_id}), " - f"got {labels[b][last_idx].item()}" + assert len(content_labels) > 0, ( + f"[row {b}] must have supervised labels after collation" ) def test_attention_mask_right_padded(self, tok): diff --git a/tests/functional_tests/hf_transformer/test_formatting_utils_options.py b/tests/functional_tests/hf_transformer/test_formatting_utils_options.py index 9627309b4..596fe42a1 100644 --- a/tests/functional_tests/hf_transformer/test_formatting_utils_options.py +++ b/tests/functional_tests/hf_transformer/test_formatting_utils_options.py @@ -17,8 +17,8 @@ import os import pytest -from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer +from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer from nemo_automodel.components.datasets.llm.formatting_utils import ( 
_add_pad_token, format_chat_template, @@ -30,7 +30,7 @@ "seq_length,padding,truncation", [ (None, "do_not_pad", None), - (4, "max_length", True), + (128, "max_length", True), ], ) def test_format_prompt_completion_options(seq_length, padding, truncation): @@ -68,49 +68,48 @@ def test_format_prompt_completion_options(seq_length, padding, truncation): answer_only_loss_mask=True, ) + input_ids = out["input_ids"] + labels = out["labels"] + attention_mask = out["attention_mask"] + # Basic structure assert set(["input_ids", "labels", "attention_mask"]).issubset(out.keys()) - assert len(out["input_ids"]) == len(out["labels"]) == len(out["attention_mask"]) > 0 + assert len(input_ids) == len(labels) == len(attention_mask) > 0 - # seq_length enforcement (either by HF padding or our packager) + # seq_length enforcement if isinstance(seq_length, int) and padding != "do_not_pad": - assert len(out["input_ids"]) == seq_length - assert len(out["labels"]) == seq_length - # Trailing padding label must be masked - assert out["labels"][-1] == -100, (out, pad_token_id) - - # EOS should be present in labels (supervised area) but not as last input_id - if getattr(tok, "eos_token_id", None) is not None and not truncation == True: - assert tok.eos_token_id in out["labels"], "EOS must appear in labels" - # find last non-pad input position and ensure it's not EOS - last_non_pad = len(out["input_ids"]) - 1 - while last_non_pad >= 0 and out["input_ids"][last_non_pad] == pad_token_id: - last_non_pad -= 1 - assert last_non_pad >= 0 - assert out["input_ids"][last_non_pad] != tok.eos_token_id + assert len(input_ids) == seq_length + assert len(labels) == seq_length + assert labels[-1] == -100, "Trailing padding label must be masked" + + # EOS should be present in supervised labels when not truncated + if getattr(tok, "eos_token_id", None) is not None and truncation is not True: + assert tok.eos_token_id in labels, "EOS must appear in labels" # There should be masked (prompt) and supervised (answer) 
tokens - assert any(l == -100 for l in out["labels"]) # masked prompt - if not truncation == True: - assert any(l != -100 for l in out["labels"]) # supervised answer + assert any(v== -100 for v in labels), "Must have masked prompt tokens" + if truncation is not True: + assert any(v!= -100 for v in labels), "Must have supervised answer tokens" - # Attention mask should have zeros only in padded tail (if any) - if isinstance(seq_length, int): - # From the end, once we see a non-zero, no zeros should appear (right padding) - seen_nonzero = False - for v in reversed(out["attention_mask"]): - if v != 0: - seen_nonzero = True - else: - if seen_nonzero: - pytest.fail("Zero attention_mask value before non-padded tokens (padding not only in tail). ") + # Where attention_mask=0, labels must be -100 + for i in range(len(labels)): + if attention_mask[i] == 0: + assert labels[i] == -100, f"Position {i}: attention_mask=0 but labels={labels[i]} (expected -100)" + + # Attention mask must be contiguous: ones then zeros (right padding) + saw_zero = False + for i, v in enumerate(attention_mask): + if v == 0: + saw_zero = True + elif saw_zero: + pytest.fail(f"attention_mask has 1 at position {i} after a 0 (not right-padded)") @pytest.mark.parametrize( "seq_length,padding,truncation", [ (None, "do_not_pad", None), - (4, "max_length", True), + (128, "max_length", True), ], ) def test_format_chat_template_options(seq_length, padding, truncation): @@ -122,7 +121,7 @@ def test_format_chat_template_options(seq_length, padding, truncation): tok = NeMoAutoTokenizer.from_pretrained(TOKENIZER_DIR) # Only applicable when tokenizer DOES define a chat template if not getattr(tok, "chat_template", None): - pytest.skip(f"Tokenizer qwen3_4b_instruct_2407 has no chat_template; skipping chat-template tests.") + pytest.skip("Tokenizer qwen3_4b_instruct_2407 has no chat_template; skipping chat-template tests.") eos_token_id = getattr(tok, "eos_token_id", 0) pad_token_id = _add_pad_token(tok) or 
eos_token_id @@ -146,37 +145,39 @@ def test_format_chat_template_options(seq_length, padding, truncation): truncation=truncation, ) + input_ids = out["input_ids"] + labels = out["labels"] + attention_mask = out["attention_mask"] + # Basic structure assert set(["input_ids", "labels", "attention_mask"]).issubset(out.keys()) - assert len(out["input_ids"]) == len(out["labels"]) == len(out["attention_mask"]) > 0 + assert len(input_ids) == len(labels) == len(attention_mask) > 0 # seq_length enforcement if isinstance(seq_length, int): - assert len(out["input_ids"]) == seq_length - assert len(out["labels"]) == seq_length - if truncation == False: - assert out["labels"][-1] == -100 - - # For chat templates, EOS should not be the last input id (unless it's all pad) - if getattr(tok, "eos_token_id", None) is not None: - last_non_pad = len(out["input_ids"]) - 1 - while last_non_pad >= 0 and out["input_ids"][last_non_pad] == pad_token_id: - last_non_pad -= 1 - if last_non_pad >= 0: - assert out["input_ids"][last_non_pad] != tok.eos_token_id + assert len(input_ids) == seq_length + assert len(labels) == seq_length # There must be at least some supervised tokens in labels - assert any(l != -100 for l in out["labels"]) # assistant tokens - - # Attention mask padded tail zeros, if padded - if isinstance(seq_length, int) and truncation == False: - # From the end, once we see a non-zero, no zeros should appear (right padding) - seen_nonzero = False - for v in reversed(out["attention_mask"]): - if v != 0: - seen_nonzero = True - else: - if seen_nonzero: - pytest.fail("Zero attention_mask value before non-padded tokens (padding not only in tail).") - - + assert any(v != -100 for v in labels), "Must have supervised assistant tokens" + + # Where attention_mask=0, labels must be -100 + for i in range(len(labels)): + if attention_mask[i] == 0: + assert labels[i] == -100, f"Position {i}: attention_mask=0 but labels={labels[i]} (expected -100)" + + # Attention mask must be contiguous: ones 
then zeros (right padding) + saw_zero = False + for i, v in enumerate(attention_mask): + if v == 0: + saw_zero = True + elif saw_zero: + pytest.fail(f"attention_mask has 1 at position {i} after a 0 (not right-padded)") + + # Padded tail: all padding positions must have pad_token_id in input_ids + if isinstance(seq_length, int) and padding == "max_length": + content_end = sum(attention_mask) + for i in range(content_end, len(input_ids)): + assert input_ids[i] == pad_token_id, ( + f"Position {i}: expected pad_token_id={pad_token_id} in padding region, got {input_ids[i]}" + ) diff --git a/tests/unit_tests/attention/test_attention_utils.py b/tests/unit_tests/attention/test_attention_utils.py index 85ab7a629..23ba3f407 100644 --- a/tests/unit_tests/attention/test_attention_utils.py +++ b/tests/unit_tests/attention/test_attention_utils.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest import torch import torch.nn as nn +import torch.nn.functional as F from nemo_automodel.components.attention.utils import ( initialize_attn_module_and_func, @@ -105,6 +108,53 @@ def test_unsupported_attention_implementation(self): softmax_scale=0.125, ) + def test_sdpa_late_binding_picks_up_monkey_patch(self): + """Test that SDPA attn_func uses late-bound lookup of F.scaled_dot_product_attention. + + Context Parallelism monkey-patches F.scaled_dot_product_attention at runtime. + The returned attn_func must resolve the function at call time (not init time) + so that CP's patched version is used. 
+ """ + _, attn_func = initialize_attn_module_and_func( + attn_impl="sdpa", + num_attention_heads=8, + num_qk_channels=64, + num_v_channels=64, + softmax_scale=0.125, + attn_mask_type="causal", + num_gqa_groups=4, + ) + + original_sdpa = F.scaled_dot_product_attention + sentinel = object() + wrapper = mock.MagicMock(return_value=sentinel) + + # Simulate CP monkey-patching F.scaled_dot_product_attention + F.scaled_dot_product_attention = wrapper + try: + q = torch.randn(1, 1, 4, 8) + k = torch.randn(1, 1, 4, 8) + v = torch.randn(1, 1, 4, 8) + result = attn_func(q, k, v) + + assert result is sentinel, "attn_func should call the patched function" + wrapper.assert_called_once() + args, kwargs = wrapper.call_args + assert torch.equal(args[0], q) + assert torch.equal(args[1], k) + assert torch.equal(args[2], v) + finally: + F.scaled_dot_product_attention = original_sdpa + + # After restoring, verify original is called again + original_wrapper = mock.MagicMock(wraps=original_sdpa) + F.scaled_dot_product_attention = original_wrapper + try: + attn_func(q, k, v) + original_wrapper.assert_called_once() + finally: + F.scaled_dot_product_attention = original_sdpa + class TestPreprocessArgsAndKwargsForAttn: """Tests for preprocess_args_and_kwargs_for_attn function.""" @@ -175,7 +225,13 @@ def test_te_with_cu_seqlens_q_and_kv(self): v_gpu = self.v.to(device) q_out, k_out, v_out, attn_kwargs = preprocess_args_and_kwargs_for_attn( - q_gpu, k_gpu, v_gpu, attention_mask=None, attn_impl="te", cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv + q_gpu, + k_gpu, + v_gpu, + attention_mask=None, + attn_impl="te", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) assert "cu_seqlens_q" in attn_kwargs diff --git a/tests/unit_tests/datasets/llm/test_chat_dataset.py b/tests/unit_tests/datasets/llm/test_chat_dataset.py index fb7679d08..5645068d4 100644 --- a/tests/unit_tests/datasets/llm/test_chat_dataset.py +++ b/tests/unit_tests/datasets/llm/test_chat_dataset.py @@ -19,6 +19,7 
@@ import pytest import nemo_automodel.components.datasets.llm.chat_dataset as tcd +from nemo_automodel.components.datasets.llm.formatting_utils import _resolve_chat_template def test_is_hf_repo_id_and_as_iter_and_normalize(): @@ -73,13 +74,61 @@ def test_load_openai_messages_local_and_errors(tmp_path, monkeypatch): with pytest.raises(RuntimeError): tcd._load_openai_messages([]) - # HF branch: force as repo-id and ensure delegated call is returned + # HF branch: force as repo-id and ensure delegated call is returned. + # Default shuffle_seed is None so no .shuffle() call is made. monkeypatch.setattr(tcd, "_is_hf_repo_id", lambda v: True) sentinel = object() monkeypatch.setattr(tcd, "load_dataset", lambda *a, **k: sentinel) assert tcd._load_openai_messages("org/name", split="train") is sentinel +def test_load_openai_messages_hf_shuffle_and_slice(monkeypatch): + """Verify that HF datasets are shuffled before slicing.""" + monkeypatch.setattr(tcd, "_is_hf_repo_id", lambda v: True) + + call_log = {} + + class _FakeDataset: + def __init__(self, items): + self._items = items + + def __len__(self): + return len(self._items) + + def shuffle(self, seed=None): + call_log["shuffle_seed"] = seed + return self + + def select(self, indices): + call_log["select_indices"] = list(indices) + return _FakeDataset([self._items[i] for i in indices]) + + fake_ds = _FakeDataset(list(range(100))) + monkeypatch.setattr(tcd, "load_dataset", lambda *a, **k: fake_ds) + + # Default (shuffle_seed=None) — no shuffling + result = tcd._load_openai_messages("org/name", split="train") + assert "shuffle_seed" not in call_log + assert result is fake_ds + + # With shuffle seed — shuffle then return + call_log.clear() + result = tcd._load_openai_messages("org/name", split="train", shuffle_seed=42) + assert call_log["shuffle_seed"] == 42 + assert "select_indices" not in call_log + + # Split with slice — shuffle then select + call_log.clear() + result = tcd._load_openai_messages("org/name", 
split="train[10:20]", shuffle_seed=42) + assert call_log["shuffle_seed"] == 42 + assert call_log["select_indices"] == list(range(10, 20)) + + # Custom seed + call_log.clear() + tcd._load_openai_messages("org/name", split="train", shuffle_seed=123) + assert call_log["shuffle_seed"] == 123 + + def test_tool_calling_chat_dataset_happy_path_and_edge_cases(monkeypatch): # Stub tokenizer class Tok: @@ -148,6 +197,41 @@ def fake_format(tokenizer, normalized, eos_id, pad_id, **kwargs): _ = ds_bad[0] +def test_resolve_chat_template_none(): + assert _resolve_chat_template(None) is None + + +def test_resolve_chat_template_plain_text_file(tmp_path): + template = "{% for msg in messages %}{{ msg.content }}{% endfor %}" + f = tmp_path / "template.jinja" + f.write_text(template, encoding="utf-8") + assert _resolve_chat_template(str(f)) == template + + +def test_resolve_chat_template_json_file(tmp_path): + template = "{% for msg in messages %}{{ msg.role }}: {{ msg.content }}{% endfor %}" + f = tmp_path / "tokenizer_config.json" + f.write_text(json.dumps({"chat_template": template, "other_key": 123}), encoding="utf-8") + assert _resolve_chat_template(str(f)) == template + + +def test_resolve_chat_template_json_file_without_key(tmp_path): + data = {"model_type": "llama", "vocab_size": 32000} + f = tmp_path / "config.json" + raw = json.dumps(data) + f.write_text(raw, encoding="utf-8") + assert _resolve_chat_template(str(f)) == raw + + +def test_resolve_chat_template_literal_string(): + template = "{% for msg in messages %}{{ msg.content }}{% endfor %}" + assert _resolve_chat_template(template) == template + + +def test_resolve_chat_template_nonexistent_path(): + assert _resolve_chat_template("/no/such/file/template.jinja") == "/no/such/file/template.jinja" + + def test_tool_calling_chat_dataset_errors(monkeypatch): # No tokenizer with pytest.raises(ValueError): @@ -161,5 +245,3 @@ class Tok: monkeypatch.setattr(tcd, "_has_chat_template", lambda _tok: False) with 
pytest.raises(ValueError): tcd.ChatDataset("ignored", Tok()) - - diff --git a/tests/unit_tests/datasets/llm/test_tokenizer_apply_functions.py b/tests/unit_tests/datasets/llm/test_tokenizer_apply_functions.py index 27c1a77af..5f3c329cf 100644 --- a/tests/unit_tests/datasets/llm/test_tokenizer_apply_functions.py +++ b/tests/unit_tests/datasets/llm/test_tokenizer_apply_functions.py @@ -91,17 +91,17 @@ def apply_chat_template(self, messages, **kwargs): # type: ignore[override] # Separate prompt messages (system, user) from assistant messages prompt_messages = [m for m in messages if m["role"] != "assistant"] assistant_messages = [m for m in messages if m["role"] == "assistant"] - + # Build ids: [SOT] + prompt tokens + [SOT] + assistant tokens + [EOS] ids: List[int] = [self._start_of_turn_token_id] - + # Add all prompt tokens (system + user) prompt_token_count = 0 for msg in prompt_messages: tokens = msg["content"].split() ids.extend(self._id_for_token(tok) for tok in tokens) prompt_token_count += len(tokens) - + # Add second SOT and assistant tokens ids.append(self._start_of_turn_token_id) assistant_token_count = 0 @@ -109,15 +109,15 @@ def apply_chat_template(self, messages, **kwargs): # type: ignore[override] tokens = msg["content"].split() ids.extend(self._id_for_token(tok) for tok in tokens) assistant_token_count += len(tokens) - + ids.append(self.eos_token_id) - + # Handle return_dict parameter if kwargs.get("return_dict", False): result = {"input_ids": ids} # Handle return_assistant_tokens_mask parameter if kwargs.get("return_assistant_tokens_mask", False): - # Create mask: first SOT and prompt tokens are 0 (masked), + # Create mask: first SOT and prompt tokens are 0 (masked), # second SOT and assistant tokens are 1 (not masked) mask = [0] * (1 + prompt_token_count) # first SOT + prompt tokens mask += [1] * (1 + assistant_token_count + 1) # second SOT + assistant tokens + EOS @@ -263,7 +263,7 @@ def apply_chat_template(self, messages, **kwargs): # type: 
ignore[override] assistant_started = True if assistant_started: ids.extend(self._id_for_token(tok) for tok in str(msg["content"]).split()) - # Intentionally DO NOT append EOS here; function under test will handle it. + ids.append(self.eos_token_id) if kwargs.get("return_dict", False): return {"input_ids": ids} return ids @@ -303,10 +303,9 @@ def test_apply_chat_template_manual_mask_without_generation_kwd(): # Sanity: there must be supervised tokens (assistant section) assert expected_ignored < len(out["labels"]) # Number of supervised tokens (exclude -100) should equal number of assistant tokens. - # Note: labels include the final EOS as supervised; subtract 1 to compare to assistant count. assistant_tokens = sum(len(str(m["content"]).split()) for m in messages if m["role"] == "assistant") num_supervised = sum(1 for v in out["labels"] if v != -100) - assert num_supervised - 1 == assistant_tokens + assert num_supervised == assistant_tokens def test_apply_chat_template_manual_mask_raises_when_last_not_assistant(): @@ -510,6 +509,105 @@ def test_eos_not_in_padding_labels(self): assert 2 not in pad_region, "EOS token id must not appear as label padding" +class TestPackageTokenizedExamplePrePaddedInput: + """Tests for _package_tokenized_example when input_ids arrive already padded. + + This happens when a tokenizer's apply_chat_template is called with + padding="max_length" — the returned input_ids already contain trailing + pad tokens. _package_tokenized_example must detect these and set + attention_mask=0 at those positions. 
+ """ + + def test_attention_mask_zeros_for_pre_padded_distinct_pad(self): + """Pre-padded input with pad_token_id != eos_token_id.""" + tok = _StubTokForPackage(pad_token_id=0) + eos = tok.eos_token_id # 2 + # Simulate tokenizer output already padded to length 8: + # [BOS=1, A=10, B=11, EOS=2, PAD=0, PAD=0, PAD=0, PAD=0] + input_ids = [1, 10, 11, eos, 0, 0, 0, 0] + assistant_masks = [0, 0, 1, 1, 0, 0, 0, 0] + out = _package_tokenized_example( + tokenizer=tok, + input_ids=input_ids, + assistant_masks=assistant_masks, + eos_token_id=eos, + pad_token_id=0, + seq_length=None, + padding="do_not_pad", + ) + # Content length is computed on the original input (4 real tokens), + # then reduced by 1 for the next-token shift → 3 ones. + assert out["attention_mask"] == [1, 1, 1, 0, 0, 0, 0], ( + f"Expected zeros at pre-padded positions, got {out['attention_mask']}" + ) + + def test_attention_mask_zeros_for_pre_padded_pad_equals_eos(self): + """Pre-padded input where pad_token_id == eos_token_id. + + The real trailing EOS should keep attention_mask=1, but subsequent + pad tokens (same id) should get attention_mask=0. + """ + tok = _StubTokForPackage(pad_token_id=2) + eos = tok.eos_token_id # 2 + # [BOS=1, A=10, B=11, EOS=2, PAD=2, PAD=2, PAD=2] + input_ids = [1, 10, 11, eos, 2, 2, 2] + assistant_masks = [0, 0, 1, 1, 0, 0, 0] + out = _package_tokenized_example( + tokenizer=tok, + input_ids=input_ids, + assistant_masks=assistant_masks, + eos_token_id=eos, + pad_token_id=2, + seq_length=None, + padding="do_not_pad", + ) + # Content length is computed on the original input (4 real tokens + # including one trailing EOS), then reduced by 1 for the shift → 3 ones. 
+ assert out["attention_mask"] == [1, 1, 1, 0, 0, 0], ( + f"Expected one trailing EOS kept + zeros for pad, got {out['attention_mask']}" + ) + + def test_attention_mask_no_padding_present(self): + """No pre-padding — attention_mask should be all ones (existing behavior).""" + tok = _StubTokForPackage(pad_token_id=0) + eos = tok.eos_token_id # 2 + input_ids = [1, 10, 11, eos] + assistant_masks = [0, 0, 1, 1] + out = _package_tokenized_example( + tokenizer=tok, + input_ids=input_ids, + assistant_masks=assistant_masks, + eos_token_id=eos, + pad_token_id=0, + seq_length=None, + padding="do_not_pad", + ) + # After [:-1]: input_ids = [1, 10, 11], no pad tokens + assert out["attention_mask"] == [1, 1, 1] + + def test_pre_padded_then_further_padded_by_seq_length(self): + """Input already partially padded, then _pad_to_seq_length extends further.""" + tok = _StubTokForPackage(pad_token_id=0) + eos = tok.eos_token_id # 2 + # Pre-padded to 6, but seq_length=10 + input_ids = [1, 10, 11, eos, 0, 0] + assistant_masks = [0, 0, 1, 1, 0, 0] + out = _package_tokenized_example( + tokenizer=tok, + input_ids=input_ids, + assistant_masks=assistant_masks, + eos_token_id=eos, + pad_token_id=0, + seq_length=10, + padding="max_length", + ) + # Content length computed on original (4 real), minus 1 for shift → 3 ones. + # _pad_to_seq_length extends to 10. 
+ assert len(out["attention_mask"]) == 10 + assert out["attention_mask"][:3] == [1, 1, 1] + assert all(v == 0 for v in out["attention_mask"][3:]) + + class _StubTokPadEosPlain(_StubTokenizerPlain): """Plain tokenizer (no chat template) where pad_token_id == eos_token_id.""" @@ -598,6 +696,78 @@ def test_prompt_masked_answer_supervised(self): assert len(supervised) > 0, "Must have supervised (answer) tokens" +class _StubTokenizerChatTruncating(_StubTokenizerChat): + """Chat tokenizer that respects max_length truncation like HF tokenizers.""" + + def apply_chat_template(self, messages, **kwargs): + result = super().apply_chat_template(messages, **kwargs) + max_length = kwargs.get("max_length") + if max_length is not None: + if kwargs.get("return_dict", False): + ids = result["input_ids"][:max_length] + result["input_ids"] = ids + if "assistant_masks" in result: + result["assistant_masks"] = result["assistant_masks"][:max_length] + else: + result = result[:max_length] + return result + + +class TestFormatChatTemplateNoEosAfterTruncation: + """EOS must NOT be appended when the sequence was truncated to seq_length. + + When apply_chat_template returns seq_length tokens (i.e. the sequence was + truncated), appending EOS makes the total seq_length+1 which after + BOS-removal in _package_tokenized_example produces exactly seq_length + labels with no room for -100 padding. The last label becomes the + spurious EOS instead of -100. 
+ """ + + def _messages(self): + # Long enough content to exceed any small seq_length + return [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "a b c d e f g h i j k l m n o p q r s t"}, + {"role": "assistant", "content": "x y z w v u"}, + ] + + def test_no_eos_appended_when_truncated_generation_kwd(self): + tok = _StubTokenizerChatTruncating() + seq_length = 10 # Force truncation + out = format_chat_template( + tok, + [m.copy() for m in self._messages()], + eos_token_id=tok.eos_token_id, + pad_token_id=tok.eos_token_id, + seq_length=seq_length, + padding="max_length", + truncation=True, + ) + # All labels must be exactly seq_length + assert len(out["labels"]) == seq_length + # The last label must be -100 (padding), NOT eos_token_id + assert out["labels"][-1] == -100, ( + f"Last label should be -100 (padding) after truncation, got {out['labels'][-1]}" + ) + + def test_eos_still_appended_when_not_truncated(self): + tok = _StubTokenizerChatTruncating() + seq_length = 100 # Large enough — no truncation + out = format_chat_template( + tok, + [m.copy() for m in self._messages()], + eos_token_id=tok.eos_token_id, + pad_token_id=tok.eos_token_id, + seq_length=seq_length, + padding="max_length", + truncation=True, + ) + assert len(out["labels"]) == seq_length + # EOS should be in the supervised region (not truncated, so EOS was appended) + supervised = [v for v in out["labels"] if v != -100] + assert tok.eos_token_id in supervised + + class TestFormatChatTemplatePadEos: """Tests for format_chat_template when pad_token_id == eos_token_id.""" diff --git a/tests/unit_tests/datasets/test_utils.py b/tests/unit_tests/datasets/test_utils.py index 04d479b20..5c395c90d 100644 --- a/tests/unit_tests/datasets/test_utils.py +++ b/tests/unit_tests/datasets/test_utils.py @@ -163,12 +163,12 @@ def test_default_collater_shapes() -> None: ] collated = sftp.default_collater(raw_batch) - # Keys preserved - assert set(collated) == {"input_ids", "attention_mask", 
"labels", "loss_mask"} + # Keys preserved (padding_mask is added by the collater) + assert set(collated) == {"input_ids", "attention_mask", "labels", "loss_mask", "padding_mask"} # Batch dimension added assert collated["input_ids"].shape[0] == 2 - # Same seq length for all keys + # Same seq length for all tensor keys (excluding padding_mask which is bool) lens = {v.shape[1] for v in collated.values()} assert len(lens) == 1 lens.pop() @@ -183,6 +183,10 @@ def test_default_collater_shapes() -> None: assert torch.equal(collated["input_ids"], input_ids) assert torch.equal(collated["labels"], labels) assert torch.equal(collated["loss_mask"], loss_mask) + + # padding_mask should be True where input_ids == pad_token (0) + expected_padding_mask = torch.tensor([[False, False], [False, True]]) + assert torch.equal(collated["padding_mask"], expected_padding_mask) # (torch.Tensor([[1,1,2],[1,2,3]]) == torch.Tensor([[1,1,2],[1,2,3]])).all().item() # assert collated["input_ids"][1, 1:].eq(0).all(), collated # assert collated["attention_mask"][1, 1:].eq(0).all() @@ -633,3 +637,121 @@ def test_complex_batch_with_varying_num_sequences(self): # Check qkv_format is present assert result["qkv_format"] == "thd" + + def test_non_packed_single_example_with_attention_mask(self): + """Test that non-packed data (no seq_lens) synthesizes THD fields from attention_mask.""" + batch = [ + { + "input_ids": [1, 2, 3, 0, 0], + "labels": [10, 20, 30, -100, -100], + "attention_mask": [1, 1, 1, 0, 0], + } + ] + + result = sftp.packed_sequence_thd_collater(batch) + + assert result["qkv_format"] == "thd" + assert result["input_ids"].shape == (1, 5) + assert result["labels"].shape == (1, 5) + assert result["position_ids"].shape == (1, 5) + assert torch.equal(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]])) + assert torch.equal(result["seq_lens"], torch.tensor([[3]])) + assert torch.equal(result["seq_lens_padded"], torch.tensor([[5]])) + assert "attention_mask" not in result + + def 
test_non_packed_variable_length_sequences(self): + """Test that non-packed variable-length sequences are padded to max batch length.""" + batch = [ + { + "input_ids": [1, 2, 3], + "labels": [10, 20, 30], + "attention_mask": [1, 1, 1], + }, + { + "input_ids": [4, 5, 6, 7, 8], + "labels": [40, 50, 60, 70, 80], + "attention_mask": [1, 1, 1, 1, 1], + }, + ] + + result = sftp.packed_sequence_thd_collater(batch) + + assert result["input_ids"].shape == (2, 5) + assert result["labels"].shape == (2, 5) + assert result["position_ids"].shape == (2, 5) + # First item padded from 3 to 5 + assert torch.equal(result["input_ids"][0], torch.tensor([1, 2, 3, 0, 0])) + assert torch.equal(result["labels"][0], torch.tensor([10, 20, 30, -100, -100])) + assert torch.equal(result["position_ids"][0], torch.tensor([0, 1, 2, 3, 4])) + # Second item unchanged + assert torch.equal(result["input_ids"][1], torch.tensor([4, 5, 6, 7, 8])) + # seq_lens reflects actual lengths, seq_lens_padded reflects padded length + assert torch.equal(result["seq_lens"], torch.tensor([[3], [5]])) + assert torch.equal(result["seq_lens_padded"], torch.tensor([[5], [5]])) + + def test_non_packed_without_attention_mask(self): + """Test non-packed data without attention_mask uses input_ids length as seq_len.""" + batch = [ + { + "input_ids": [1, 2, 3, 4], + "labels": [10, 20, 30, 40], + } + ] + + result = sftp.packed_sequence_thd_collater(batch) + + assert result["qkv_format"] == "thd" + assert torch.equal(result["seq_lens"], torch.tensor([[4]])) + assert torch.equal(result["seq_lens_padded"], torch.tensor([[4]])) + assert torch.equal(result["position_ids"], torch.tensor([[0, 1, 2, 3]])) + + def test_non_packed_with_pad_token_ids_metadata(self): + """Test non-packed data uses ___PAD_TOKEN_IDS___ for correct pad token.""" + batch = [ + { + "input_ids": [1, 2], + "labels": [10, 20], + "attention_mask": [1, 1], + "___PAD_TOKEN_IDS___": {"input_ids": 99, "labels": -100}, + }, + { + "input_ids": [3, 4, 5, 6], + "labels": 
[30, 40, 50, 60], + "attention_mask": [1, 1, 1, 1], + "___PAD_TOKEN_IDS___": {"input_ids": 99, "labels": -100}, + }, + ] + + result = sftp.packed_sequence_thd_collater(batch) + + # First item padded with pad_token_id=99 + assert result["input_ids"][0, 2] == 99 + assert result["input_ids"][0, 3] == 99 + assert result["labels"][0, 2] == -100 + assert result["labels"][0, 3] == -100 + + def test_non_packed_cu_seqlens_padded_covers_total_tokens(self): + """Test that seq_lens_padded sums to the tensor length per item for correct cu_seqlens.""" + batch = [ + { + "input_ids": [1, 2, 3], + "labels": [10, 20, 30], + "attention_mask": [1, 1, 1], + }, + { + "input_ids": [4, 5, 6, 7, 8, 9], + "labels": [40, 50, 60, 70, 80, 90], + "attention_mask": [1, 1, 1, 1, 1, 1], + }, + ] + + result = sftp.packed_sequence_thd_collater(batch) + + max_len = 6 # max(3, 6) + batch_size = 2 + # Each item's seq_lens_padded should equal max_len + assert result["seq_lens_padded"][0, 0].item() == max_len + assert result["seq_lens_padded"][1, 0].item() == max_len + # Sum across batch should equal total_tokens + total_tokens = batch_size * max_len + assert result["seq_lens_padded"].sum().item() == total_tokens diff --git a/tests/unit_tests/distributed/test_cp_utils.py b/tests/unit_tests/distributed/test_cp_utils.py index b69f808ea..6f66f1fa3 100644 --- a/tests/unit_tests/distributed/test_cp_utils.py +++ b/tests/unit_tests/distributed/test_cp_utils.py @@ -55,6 +55,7 @@ def __init__(self, cp_size: int, tp_size: int): self["tp"] = _DummySubMesh(tp_size) self.mesh_dim_names = ["cp", "tp"] + def test_build_position_ids_adds_missing(): """If ``position_ids`` is absent it should be generated correctly.""" batch: dict[str, Any] = {"input_ids": torch.arange(6).view(1, -1)} @@ -81,6 +82,7 @@ def test_build_position_ids_does_not_override_existing(): _cu._build_position_ids(batch, torch.device("cpu")) assert torch.equal(batch["position_ids"], original_pos), "position_ids should not be modified" + def 
test_make_cp_batch_and_ctx_no_mesh(): """When *no* device mesh is provided the call should be a no-op.""" input_ids = torch.tensor([[1, 2, 3]]) @@ -112,6 +114,7 @@ def test_make_cp_batch_and_ctx_with_cp(monkeypatch): def _fake_create_ctx(**kwargs): # noqa: D401 """Return a sentinel object so we can verify it was passed through.""" return dummy_cp_ctx + monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx) def _fake_get_train_ctx(enable_loss_parallel, enable_compiled_autograd, cp_ctx): # noqa: D401 @@ -142,6 +145,148 @@ def _fake_get_train_ctx(enable_loss_parallel, enable_compiled_autograd, cp_ctx): assert new_batch is batch +def test_make_cp_batch_and_ctx_includes_padding_mask(monkeypatch): + """Verify that padding_mask is included in CP buffers when present in batch.""" + + captured_kwargs = {} + + def _fake_create_ctx(**kwargs): + captured_kwargs.update(kwargs) + return object() + + monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx) + + def _fake_get_train_ctx(enable_loss_parallel, enable_compiled_autograd, cp_ctx): + return "dummy_train_ctx" + + monkeypatch.setattr(_cu, "get_train_context", _fake_get_train_ctx) + + device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1) + padding_mask = torch.tensor([[True, False, True]]) + batch = { + "input_ids": torch.tensor([[10, 20, 30]]), + "labels": torch.tensor([[10, 20, 30]]), + "padding_mask": padding_mask, + } + + _cu.make_cp_batch_and_ctx(device_mesh, batch, loss_mask=None) + + # padding_mask should be in cp_buffers + assert any( + t is padding_mask for t in captured_kwargs["cp_buffers"] + ), "padding_mask must be included in cp_buffers" + assert padding_mask in captured_kwargs["cp_no_restore_buffers"] + + +def test_make_cp_batch_and_ctx_pops_attention_mask_when_cp_enabled(monkeypatch): + """When CP is enabled, attention_mask should be removed from the batch.""" + + dummy_cp_ctx = object() + monkeypatch.setattr(_cu, "create_context_parallel_ctx", lambda **kwargs: dummy_cp_ctx) 
+ monkeypatch.setattr( + _cu, + "get_train_context", + lambda enable_loss_parallel, enable_compiled_autograd, cp_ctx: "dummy_train_ctx", + ) + + device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1) + batch = { + "input_ids": torch.tensor([[1, 2, 3]]), + "labels": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.ones(1, 3, dtype=torch.long), + } + + _ctx, new_batch = _cu.make_cp_batch_and_ctx(device_mesh, batch) + + assert "attention_mask" not in new_batch, "attention_mask should be removed when CP > 1" + + +# ============================================================================ +# Tests for attach_context_parallel_hooks +# ============================================================================ + + +class _FakeSelfAttn(torch.nn.Module): + """Minimal module that records the kwargs it receives.""" + + def forward(self, hidden_states, **kwargs): + self.last_kwargs = kwargs + return hidden_states + + +class _FakeTransformerBlock(torch.nn.Module): + """A toy model with a ``self_attn`` sub-module to test hook attachment.""" + + def __init__(self): + super().__init__() + self.self_attn = _FakeSelfAttn() + + +class _FakeModel(torch.nn.Module): + """Two-layer model with ``self_attn`` sub-modules.""" + + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_FakeTransformerBlock(), _FakeTransformerBlock()]) + + +def test_attach_context_parallel_hooks_registers_on_self_attn(): + """Hooks should be registered on every module whose name ends with 'self_attn'.""" + model = _FakeModel() + + # Count hooks before + hooks_before = { + name: len(mod._forward_pre_hooks) for name, mod in model.named_modules() if name.endswith("self_attn") + } + + _cu.attach_context_parallel_hooks(model) + + for name, mod in model.named_modules(): + if name.endswith("self_attn"): + assert len(mod._forward_pre_hooks) == hooks_before[name] + 1 + + +def test_attach_context_parallel_hooks_strips_attention_mask(): + """The hook should replace attention_mask with None 
and set is_causal=True.""" + model = _FakeModel() + _cu.attach_context_parallel_hooks(model) + + dummy_input = torch.randn(1, 4, 8) + attn_mask = torch.ones(1, 1, 4, 4) + + model.layers[0].self_attn(dummy_input, attention_mask=attn_mask) + + kwargs = model.layers[0].self_attn.last_kwargs + assert kwargs["attention_mask"] is None, "attention_mask should be set to None by the hook" + assert kwargs["is_causal"] is True, "is_causal should be set to True by the hook" + + +def test_attach_context_parallel_hooks_no_mask_passthrough(): + """When no attention_mask kwarg is passed, the hook should be a no-op.""" + model = _FakeModel() + _cu.attach_context_parallel_hooks(model) + + dummy_input = torch.randn(1, 4, 8) + model.layers[0].self_attn(dummy_input, some_other_kwarg=42) + + kwargs = model.layers[0].self_attn.last_kwargs + assert "attention_mask" not in kwargs + assert "is_causal" not in kwargs + assert kwargs["some_other_kwarg"] == 42 + + +def test_attach_context_parallel_hooks_skips_non_self_attn(): + """Modules not ending with 'self_attn' should have no hooks added.""" + model = _FakeModel() + _cu.attach_context_parallel_hooks(model) + + # The top-level model and the layers list should not get hooks + assert len(model._forward_pre_hooks) == 0 + assert len(model.layers._forward_pre_hooks) == 0 + for layer in model.layers: + assert len(layer._forward_pre_hooks) == 0 + + # ============================================================================ # Tests for make_cp_batch_for_te # ============================================================================ @@ -183,7 +328,8 @@ def thd_get_partitioned_indices(cu_seqlens_padded, total_tokens, cp_size, cp_ran # Mock at the module level where it's imported import sys - sys.modules['transformer_engine_torch'] = MockTex + + sys.modules["transformer_engine_torch"] = MockTex monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank) @@ -208,6 +354,38 @@ def thd_get_partitioned_indices(cu_seqlens_padded, total_tokens, 
cp_size, cp_ran assert result["cu_seqlens"].dtype == torch.int32 +def test_shard_thd_chunk_skips_missing_padding_mask(monkeypatch): + """Test that _shard_thd_chunk_for_te handles missing padding_mask gracefully.""" + cp_mesh = _DummySubMesh(size=2) + + def mock_get_rank(group=None): + return 0 + + class MockTex: + @staticmethod + def thd_get_partitioned_indices(cu_seqlens_padded, total_tokens, cp_size, cp_rank): + return torch.arange(total_tokens) + + import sys + sys.modules['transformer_engine_torch'] = MockTex + + monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank) + + # Batch without padding_mask — should not raise KeyError + batch = { + "input_ids": torch.tensor([1, 2, 3, 4]), + "labels": torch.tensor([10, 20, 30, 40]), + "position_ids": torch.tensor([0, 1, 2, 3]), + "cu_seqlens": torch.tensor([0, 4], dtype=torch.int32), + "cu_seqlens_padded": torch.tensor([0, 4], dtype=torch.int32), + } + + result = _cu._shard_thd_chunk_for_te(batch, cp_mesh, "thd", -1000, 0) + + assert "input_ids" in result + assert "attention_mask" not in result + + def test_make_cp_batch_for_te_unsupported_format(): """Test that unsupported qvk_format raises ValueError.""" cp_mesh = _DummySubMesh(size=2) diff --git a/tests/unit_tests/loggers/test_comet_utils.py b/tests/unit_tests/loggers/test_comet_utils.py new file mode 100644 index 000000000..6b6e728b5 --- /dev/null +++ b/tests/unit_tests/loggers/test_comet_utils.py @@ -0,0 +1,361 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import types + +import pytest +import torch + + +def _install_fake_comet_ml(): + """ + Install a minimal stub comet_ml package into sys.modules capturing calls. + """ + comet_ml = types.ModuleType("comet_ml") + + calls = { + "experiment_init": [], + "log_parameters": [], + "log_metrics": [], + "set_name": [], + "add_tags": [], + "end": 0, + } + + class _FakeExperiment: + url = "https://www.comet.com/test/fake-experiment" + + def __init__(self, **kwargs): + calls["experiment_init"].append(kwargs) + + def log_parameters(self, params): + calls["log_parameters"].append(params) + + def log_metrics(self, metrics, step=None): + calls["log_metrics"].append((metrics, step)) + + def set_name(self, name): + calls["set_name"].append(name) + + def add_tags(self, tags): + calls["add_tags"].append(tags) + + def end(self): + calls["end"] += 1 + + comet_ml.Experiment = _FakeExperiment + sys.modules["comet_ml"] = comet_ml + return calls + + +@pytest.fixture(autouse=True) +def _clean_sys_modules(): + yield + for name in list(sys.modules): + if name.startswith("comet") or "comet_utils" in name: + del sys.modules[name] + + +def test_build_comet_creates_experiment_with_config(monkeypatch): + calls = _install_fake_comet_ml() + + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import build_comet + + class CometCfg: + def __init__(self): + self._data = { + "project_name": "test-project", + "workspace": "test-workspace", + "experiment_name": "test-run", + "api_key": None, + "tags": ["finetune", "llama"], + "auto_metric_logging": False, + } + + def get(self, key, default=None): + return self._data.get(key, default) + + class ModelCfg: + pretrained_model_name_or_path = "org/my-model" + + class 
Cfg: + def __init__(self): + self.comet = CometCfg() + self.model = ModelCfg() + + def get(self, key, default=None): + return getattr(self, key, default) + + def to_dict(self): + return {"model": "org/my-model"} + + cfg = Cfg() + logger = build_comet(cfg) + assert logger is not None + + assert calls["experiment_init"], "comet_ml.Experiment should have been called" + init_kwargs = calls["experiment_init"][-1] + assert init_kwargs["project_name"] == "test-project" + assert init_kwargs["workspace"] == "test-workspace" + assert init_kwargs["auto_metric_logging"] is False + + assert calls["set_name"] == ["test-run"] + assert calls["add_tags"], "add_tags should have been called" + tags = calls["add_tags"][-1] + assert "finetune" in tags + assert "llama" in tags + assert "model:org/my-model" in tags + + +def test_build_comet_auto_generates_experiment_name(monkeypatch): + calls = _install_fake_comet_ml() + + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import build_comet + + class CometCfg: + def __init__(self): + self._data = {"project_name": "test-project"} + + def get(self, key, default=None): + return self._data.get(key, default) + + class ModelCfg: + pretrained_model_name_or_path = "org/my-model" + + class Cfg: + def __init__(self): + self.comet = CometCfg() + self.model = ModelCfg() + + def get(self, key, default=None): + return getattr(self, key, default) + + build_comet(Cfg()) + assert calls["set_name"] == ["org_my-model"] + + +def test_build_comet_raises_without_project_name(monkeypatch): + _install_fake_comet_ml() + + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import build_comet + + class 
CometCfg: + def __init__(self): + self._data = {"workspace": "test"} + + def get(self, key, default=None): + return self._data.get(key, default) + + class Cfg: + def __init__(self): + self.comet = CometCfg() + + def get(self, key, default=None): + return getattr(self, key, default) + + with pytest.raises(ValueError, match="comet.project_name is required"): + build_comet(Cfg()) + + +def test_build_comet_raises_without_config(monkeypatch): + _install_fake_comet_ml() + + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import build_comet + + class Cfg: + def get(self, key, default=None): + return default + + with pytest.raises(ValueError, match="Comet configuration not found"): + build_comet(Cfg()) + + +def test_log_params_delegates_to_experiment(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + logger.log_params({"lr": 0.001, "batch_size": 8}) + + assert calls["log_parameters"], "experiment.log_parameters should have been called" + params = calls["log_parameters"][-1] + assert params["lr"] == 0.001 + assert params["batch_size"] == 8 + + +def test_log_metrics_converts_types_and_uses_step(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + metrics = { + "int_val": 3, + "float_val": 2.5, + "tensor_scalar": 
torch.tensor(4.0), + "tensor_vec": torch.tensor([1.0, 3.0]), + "skip_obj": object(), + } + logger.log_metrics(metrics, step=5) + + assert calls["log_metrics"], "experiment.log_metrics should have been called" + logged_metrics, step = calls["log_metrics"][-1] + assert step == 5 + assert isinstance(logged_metrics["int_val"], float) and logged_metrics["int_val"] == 3.0 + assert isinstance(logged_metrics["float_val"], float) and logged_metrics["float_val"] == 2.5 + assert isinstance(logged_metrics["tensor_scalar"], float) and logged_metrics["tensor_scalar"] == 4.0 + assert isinstance(logged_metrics["tensor_vec"], float), "tensor vectors should be averaged to float" + assert "skip_obj" not in logged_metrics + + +def test_log_metrics_without_step(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + logger.log_metrics({"loss": 0.5}) + + logged_metrics, step = calls["log_metrics"][-1] + assert step is None + assert logged_metrics["loss"] == 0.5 + + +def test_rank_guard_blocks_non_rank_zero(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + # Switch to non-zero rank -> calls should NO-OP + monkeypatch.setattr(dist, "get_rank", lambda: 1, raising=False) + logger.log_metrics({"a": 1.0}, step=1) + assert not calls["log_metrics"], "log_metrics should not be called on non-main rank" + + logger.log_params({"x": 1}) + assert not calls["log_parameters"], "log_parameters should not be called on non-main 
rank" + + +def test_experiment_none_guard(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + logger.experiment = None + + logger.log_metrics({"a": 1.0}, step=1) + assert not calls["log_metrics"], "log_metrics should not be called when experiment is None" + + logger.log_params({"x": 1}) + assert not calls["log_parameters"], "log_params should not be called when experiment is None" + + +def test_end_calls_experiment_end(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + logger.end() + assert calls["end"] == 1 + + +def test_end_noop_when_no_experiment(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + logger.experiment = None + logger.end() + assert calls["end"] == 0 + + +def test_context_manager_calls_end(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 0, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + with CometLogger(project_name="p"): + pass + assert calls["end"] == 1 + + +def 
test_no_experiment_created_on_non_rank_zero(monkeypatch): + calls = _install_fake_comet_ml() + import torch.distributed as dist + + monkeypatch.setattr(dist, "is_initialized", lambda: True, raising=False) + monkeypatch.setattr(dist, "get_rank", lambda: 1, raising=False) + + from nemo_automodel.components.loggers.comet_utils import CometLogger + + logger = CometLogger(project_name="p") + assert logger.experiment is None + assert not calls["experiment_init"], "Experiment should not be created on non-rank-0" diff --git a/tests/unit_tests/moe/test_parallelizer.py b/tests/unit_tests/moe/test_parallelizer.py index 9a0f2df5b..303bcdf64 100644 --- a/tests/unit_tests/moe/test_parallelizer.py +++ b/tests/unit_tests/moe/test_parallelizer.py @@ -1816,3 +1816,117 @@ def __init__(self): apply_fsdp_mock.assert_called_once() _, kwargs = apply_fsdp_mock.call_args assert kwargs.get("mp_policy") is None + + +# ============================================================================ +# Tests for apply_cp – skip non-TE attention modules instead of asserting +# ============================================================================ + + +class _FakeAttnModule: + """Non-TE attention module (e.g. 
SDPA).""" + + pass + + +class _FakeSelfAttn: + def __init__(self, attn_module): + self.attn_module = attn_module + + +class _FakeBlockWithAttn: + def __init__(self, attn_module, moe=None): + self.self_attn = _FakeSelfAttn(attn_module) + self.mlp = moe if moe is not None else object() + + +def test_apply_cp_skips_non_te_attention(monkeypatch): + """apply_cp should skip blocks whose attn_module is not DotProductAttention.""" + P = _import_parallelizer_with_stubs(monkeypatch) + + # Stub DotProductAttention in the TE import inside apply_cp + te_attn_stub = types.ModuleType("transformer_engine.pytorch.attention") + + class DotProductAttention: + pass + + te_attn_stub.DotProductAttention = DotProductAttention + monkeypatch.setitem(sys.modules, "transformer_engine", types.ModuleType("transformer_engine")) + monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", types.ModuleType("transformer_engine.pytorch")) + monkeypatch.setitem(sys.modules, "transformer_engine.pytorch.attention", te_attn_stub) + + non_te_attn = _FakeAttnModule() # not a DotProductAttention + block = _FakeBlockWithAttn(non_te_attn) + model = DummyModel([block]) + + cp_mesh = MagicMock() + cp_mesh.get_group.return_value = MagicMock() + + # Stub get_process_group_ranks to avoid real distributed calls + dist_stub = sys.modules["torch.distributed"] + dist_stub.get_process_group_ranks = MagicMock(return_value=[0, 1]) + + # Should not raise — just skip the non-TE block + P.apply_cp(model, cp_mesh) + + +def _setup_te_and_dist_stubs(monkeypatch, DotProductAttention): + """Register TE and torch.distributed stubs needed by apply_cp.""" + te_attn_stub = types.ModuleType("transformer_engine.pytorch.attention") + te_attn_stub.DotProductAttention = DotProductAttention + monkeypatch.setitem(sys.modules, "transformer_engine", types.ModuleType("transformer_engine")) + monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", types.ModuleType("transformer_engine.pytorch")) + monkeypatch.setitem(sys.modules, 
"transformer_engine.pytorch.attention", te_attn_stub) + + # apply_cp uses torch.distributed.get_process_group_ranks via attribute access + torch_mod = sys.modules["torch"] + dist_stub = sys.modules["torch.distributed"] + dist_stub.get_process_group_ranks = MagicMock(return_value=[0, 1]) + torch_mod.distributed = dist_stub + + +def test_apply_cp_configures_te_attention(monkeypatch): + """apply_cp should call set_context_parallel_group on TE DotProductAttention modules.""" + P = _import_parallelizer_with_stubs(monkeypatch) + + class DotProductAttention: + def __init__(self): + self.set_context_parallel_group = MagicMock() + + _setup_te_and_dist_stubs(monkeypatch, DotProductAttention) + + te_attn = DotProductAttention() + block = _FakeBlockWithAttn(te_attn) + model = DummyModel([block]) + + cp_mesh = MagicMock() + cp_mesh.get_group.return_value = MagicMock() + + P.apply_cp(model, cp_mesh) + + te_attn.set_context_parallel_group.assert_called_once() + + +def test_apply_cp_mixed_te_and_non_te(monkeypatch): + """apply_cp should configure TE blocks and skip non-TE blocks in the same model.""" + P = _import_parallelizer_with_stubs(monkeypatch) + + class DotProductAttention: + def __init__(self): + self.set_context_parallel_group = MagicMock() + + _setup_te_and_dist_stubs(monkeypatch, DotProductAttention) + + te_attn = DotProductAttention() + non_te_attn = _FakeAttnModule() + block_te = _FakeBlockWithAttn(te_attn) + block_non_te = _FakeBlockWithAttn(non_te_attn) + model = DummyModel([block_te, block_non_te]) + + cp_mesh = MagicMock() + cp_mesh.get_group.return_value = MagicMock() + + P.apply_cp(model, cp_mesh) + + # TE block configured, non-TE block skipped (no error) + te_attn.set_context_parallel_group.assert_called_once() diff --git a/tests/unit_tests/recipes/test_train_ft.py b/tests/unit_tests/recipes/test_train_ft.py index 2f546ea09..5c2814b80 100644 --- a/tests/unit_tests/recipes/test_train_ft.py +++ b/tests/unit_tests/recipes/test_train_ft.py @@ -37,6 +37,7 @@ 
build_optimizer, build_validation_dataloader, compute_trust_remote_code_from_model, + resolve_sdpa_method, ) @@ -142,8 +143,10 @@ def test_build_validation_dataloader_no_validation_keys(): assert result == {} mock_build.assert_not_called() + class DummyLinear(nn.Module): """Simple linear layer for testing""" + def __init__(self, in_features, out_features): super().__init__() self.weight = nn.Parameter(torch.randn(out_features, in_features)) @@ -153,6 +156,7 @@ def __init__(self, in_features, out_features): class DummyModel(nn.Module): """Simple model for testing PEFT + PP""" + def __init__(self): super().__init__() self.layer1 = DummyLinear(10, 10) @@ -168,6 +172,7 @@ def forward(self, x): class DummyPeftConfig: """Mock PEFT config""" + def __init__(self): self.use_triton = True self.dim = 8 @@ -177,12 +182,14 @@ def __init__(self): class DummyOptConfig: """Mock optimizer config""" + def instantiate(self, params): return torch.optim.SGD(params, lr=0.01) class DummyModelConfig: """Mock model config""" + def __init__(self): self.pretrained_model_name_or_path = None @@ -204,11 +211,16 @@ def test_peft_with_pipeline_parallelism_enabled(caplog): model = DummyModel() mock_autopipeline = MagicMock() - with patch('nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules') as mock_apply_lora: + with patch("nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules") as mock_apply_lora: with caplog.at_level(logging.INFO): _apply_peft_and_lower_precision( - model, tp_size=1, autopipeline=mock_autopipeline, - peft_config=cfg_peft, quantization_config=None, fp8_config=None, qat_quantizer=None, + model, + tp_size=1, + autopipeline=mock_autopipeline, + peft_config=cfg_peft, + quantization_config=None, + fp8_config=None, + qat_quantizer=None, ) assert mock_apply_lora.called, "apply_lora_to_linear_modules should be called" @@ -226,17 +238,17 @@ def test_peft_without_pipeline_parallelism(caplog): cfg_peft = DummyPeftConfig() # Mock the 
apply_lora_to_linear_modules function (now inside apply_model_infrastructure) - with patch('nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules') as mock_apply_lora: - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure._shard_ep_fsdp') as mock_shard: + with patch("nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules") as mock_apply_lora: + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure._shard_ep_fsdp") as mock_shard: # Return a DummyModel with lora_dummy_param so freeze doesn't remove all trainable params sharded_model = DummyModel() sharded_model.register_parameter( "lora_dummy_param", - nn.Parameter(torch.tensor(1.0, device=torch.device("cuda")), requires_grad=True) + nn.Parameter(torch.tensor(1.0, device=torch.device("cuda")), requires_grad=True), ) mock_shard.return_value = sharded_model with caplog.at_level(logging.INFO): @@ -263,11 +275,16 @@ def test_peft_with_tp_disables_triton(caplog): cfg_peft = DummyPeftConfig() model = DummyModel() - with patch('nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules'): + with patch("nemo_automodel._transformers.infrastructure.apply_lora_to_linear_modules"): with caplog.at_level(logging.INFO): _apply_peft_and_lower_precision( - 
model, tp_size=2, autopipeline=None, - peft_config=cfg_peft, quantization_config=None, fp8_config=None, qat_quantizer=None, + model, + tp_size=2, + autopipeline=None, + peft_config=cfg_peft, + quantization_config=None, + fp8_config=None, + qat_quantizer=None, ) assert cfg_peft.use_triton == False, "use_triton should be disabled for TP" @@ -328,12 +345,15 @@ def test_build_dataloader_iterable_shard_and_shuffle_removed_from_cfg(monkeypatc class _FlagCM(AbstractContextManager): """Simple context manager that flips a flag on enter/exit.""" + def __init__(self, flags, key): self.flags = flags self.key = key + def __enter__(self): self.flags[self.key] = True return self + def __exit__(self, exc_type, exc, tb): return False @@ -470,18 +490,37 @@ def _patch_setup_minimals(monkeypatch, patch_fn): "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._setup_qat", lambda *a, **k: (None, None, None), ) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction.load_checkpoint", lambda *a, **k: None) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_step_scheduler_details", lambda *a, **k: None) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction.load_checkpoint", + lambda *a, **k: None, + ) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._log_step_scheduler_details", + lambda *a, **k: None, + ) # Avoid CUDA calls monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.torch.cuda.reset_peak_memory_stats", lambda: None) # Make group/rank helpers trivial - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_rank", lambda self, include_cp=False: 0) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_group_size", lambda self, include_cp=False: 
1) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_cp_group_size", lambda self: 1) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_tp_rank", lambda self: 0) - monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_pp_rank", lambda self: 0) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_rank", + lambda self, include_cp=False: 0, + ) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_dp_group_size", + lambda self, include_cp=False: 1, + ) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_cp_group_size", + lambda self: 1, + ) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_tp_rank", lambda self: 0 + ) + monkeypatch.setattr( + "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction._get_pp_rank", lambda self: 0 + ) # Provide a dummy autonvtx module to satisfy import and capture patch calls dummy_autonvtx = types.ModuleType("nemo_automodel.autonvtx") @@ -553,7 +592,9 @@ class DummyAutoPipeline(SimpleNamespace): parts = [DummyModel(), DummyModel()] def _build_model_stub(*args, **kwargs): - return DummyAutoPipeline(parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None)) + return DummyAutoPipeline( + parts=parts, info=SimpleNamespace(has_last_stage=False, has_first_stage=False, schedule=None) + ) def _build_optimizer_stub(*args, **kwargs): dummy_opt = SimpleNamespace(param_groups=[{"lr": 0.01}], step=lambda: None, zero_grad=lambda: None) @@ -1041,17 +1082,16 @@ def get(self, key, default=None): @requires_cuda def test_build_model_state_dict_keys_uses_adapter(caplog): - """Test that state_dict_keys are 
transformed using _maybe_adapt_state_dict_to_hf when adapter is present. - """ + """Test that state_dict_keys are transformed using _maybe_adapt_state_dict_to_hf when adapter is present.""" cfg_model = DummyModelConfigWithAdapter() cfg_opt = DummyOptConfig() cfg_peft = None - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=cfg_peft, @@ -1072,10 +1112,10 @@ def test_build_model_state_dict_keys_without_adapter(): cfg_opt = DummyOptConfig() cfg_peft = None - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=cfg_peft, @@ 
-1111,10 +1151,10 @@ def get(self, key, default=None): cfg_model = DummyQuantizedModelConfig() - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=cfg_peft, @@ -1135,10 +1175,10 @@ def test_build_model_without_quant_config(): cfg_opt = DummyOptConfig() cfg_peft = None - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=cfg_peft, @@ -1170,10 +1210,10 @@ def test_build_optimizer_disables_foreach_with_tp(): mock_mesh.mesh_dim_names = ("dp", "tp") mock_mesh.__getitem__ = lambda self, key: mock_tp if key == "tp" else MagicMock() 
- with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=None, @@ -1192,10 +1232,10 @@ def test_build_model_and_optimizer_return_values(): cfg_model = DummyModelConfig() cfg_opt = DummyOptConfig() - with patch('nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True): - with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'): - with patch('nemo_automodel._transformers.infrastructure.print_trainable_parameters'): + with patch("nemo_automodel.recipes.llm.train_ft._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.infrastructure._supports_logits_to_keep", return_value=True): + with patch("nemo_automodel._transformers.auto_model._verify_sdpa_support"): + with patch("nemo_automodel._transformers.infrastructure.print_trainable_parameters"): model = build_model( cfg_model=cfg_model, cfg_peft=None, @@ -1450,14 +1490,17 @@ def mock_add_masks(batch, model_config=None): assert call_order == ["base", "masks"] -@pytest.mark.parametrize("cfg_attrs,expected", [ - # String config - ({"config": "org/model-name"}, "org/model-name"), - # Direct pretrained_model_name_or_path - 
({"pretrained_model_name_or_path": "direct/model"}, "direct/model"), - # Not found - returns None - ({}, None), -]) +@pytest.mark.parametrize( + "cfg_attrs,expected", + [ + # String config + ({"config": "org/model-name"}, "org/model-name"), + # Direct pretrained_model_name_or_path + ({"pretrained_model_name_or_path": "direct/model"}, "direct/model"), + # Not found - returns None + ({}, None), + ], +) def test_get_model_name(cfg_attrs, expected): """Test _get_model_name extracts model name from various config structures.""" from nemo_automodel.recipes.llm.train_ft import _get_model_name @@ -1565,9 +1608,7 @@ def test_log_moe_metrics_passes_top_k_from_config(): def test_log_moe_metrics_detailed_mode(): """Detailed mode should call compute_detailed_metrics (includes per-layer keys).""" loads = _make_moe_layer_loads([[100.0, 200.0], [300.0, 400.0]]) - trainer = _make_trainer_for_moe( - {"enabled": True, "mode": "detailed", "top_k_experts": 2}, layer_loads=loads - ) + trainer = _make_trainer_for_moe({"enabled": True, "mode": "detailed", "top_k_experts": 2}, layer_loads=loads) log_fn = MagicMock() trainer._log_moe_metrics(step=10, wandb_log_fn=log_fn) @@ -1610,7 +1651,6 @@ def teardown_method(self): MoEAuxLossAutoScaler.main_loss_backward_scale = None def _make_recipe(self, monkeypatch, pp_enabled, dp_group_size=4): - from nemo_automodel.components.config.loader import ConfigNode cfg = ConfigNode( @@ -1644,8 +1684,11 @@ def _make_recipe(self, monkeypatch, pp_enabled, dp_group_size=4): object.__setattr__(recipe, "pp_enabled", pp_enabled) object.__setattr__(recipe, "te_fp8", None) object.__setattr__(recipe, "model_parts", [nn.Linear(4, 4)]) - object.__setattr__(recipe, "optimizer", [SimpleNamespace(step=lambda: None, zero_grad=lambda: None, - param_groups=[{"lr": 0.01}])]) + object.__setattr__( + recipe, + "optimizer", + [SimpleNamespace(step=lambda: None, zero_grad=lambda: None, param_groups=[{"lr": 0.01}])], + ) object.__setattr__(recipe, "lr_schedulers", []) 
object.__setattr__(recipe, "step_scheduler", SimpleNamespace(step=1, epoch=0)) @@ -1657,7 +1700,11 @@ def _make_recipe(self, monkeypatch, pp_enabled, dp_group_size=4): object.__setattr__(recipe, "device_mesh", SimpleNamespace(mesh=mock_mesh)) object.__setattr__(recipe, "tokenizer", SimpleNamespace(pad_token_id=0)) - monkeypatch.setattr(recipe, "_dp_allreduce", lambda val, include_cp=False: val if isinstance(val, torch.Tensor) else torch.tensor(val)) + monkeypatch.setattr( + recipe, + "_dp_allreduce", + lambda val, include_cp=False: val if isinstance(val, torch.Tensor) else torch.tensor(val), + ) monkeypatch.setattr(recipe, "_get_dp_group_size", lambda include_cp=False: dp_group_size) monkeypatch.setattr(recipe, "_get_cp_group_size", lambda: 1) @@ -1669,15 +1716,9 @@ def mock_forward_backward_step(idx, batch, *, loss_buffer, num_label_tokens, num "nemo_automodel.recipes.llm.train_ft.scale_grads_and_clip_grad_norm", lambda *a, **k: torch.tensor(1.0), ) - monkeypatch.setattr( - "nemo_automodel.recipes.llm.train_ft.prepare_for_grad_accumulation", lambda *a, **k: None - ) - monkeypatch.setattr( - "nemo_automodel.recipes.llm.train_ft.prepare_for_final_backward", lambda *a, **k: None - ) - monkeypatch.setattr( - "nemo_automodel.recipes.llm.train_ft.prepare_after_first_microbatch", lambda *a, **k: None - ) + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.prepare_for_grad_accumulation", lambda *a, **k: None) + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.prepare_for_final_backward", lambda *a, **k: None) + monkeypatch.setattr("nemo_automodel.recipes.llm.train_ft.prepare_after_first_microbatch", lambda *a, **k: None) object.__setattr__(recipe, "checkpointer", SimpleNamespace(maybe_wait_for_staging=lambda: None)) object.__setattr__(recipe, "lr_scheduler", None) object.__setattr__(recipe, "timestamp", 0.0) @@ -1786,3 +1827,58 @@ def test_rope_fusion_stays_false_when_already_disabled(monkeypatch): trainer.setup() assert cfg.model.backend.rope_fusion is 
False + + +# ============================================================================ +# Tests for resolve_sdpa_method +# ============================================================================ + + +class TestResolveSdpaMethod: + """Tests for resolve_sdpa_method helper.""" + + def test_explicit_strings_converted_to_backends(self): + from torch.nn.attention import SDPBackend + + result = resolve_sdpa_method(["flash_attention", "math"]) + assert result == [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH] + + def test_case_insensitive(self): + from torch.nn.attention import SDPBackend + + result = resolve_sdpa_method(["Flash_Attention", "EFFICIENT_ATTENTION"]) + assert result == [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] + + def test_invalid_backend_raises(self): + with pytest.raises(ValueError, match="Unknown SDPA backend 'bogus'"): + resolve_sdpa_method(["bogus"]) + + def test_none_with_no_constraints_returns_none(self): + assert resolve_sdpa_method(None) is None + + def test_auto_cp_restricts_backends(self): + from torch.nn.attention import SDPBackend + + mesh = MagicMock() + mesh.mesh_dim_names = ("dp", "cp") + mesh.__getitem__ = lambda self, key: MagicMock(size=lambda: 2) if key == "cp" else MagicMock(size=lambda: 1) + + result = resolve_sdpa_method(None, device_mesh=mesh) + assert result == [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] + + def test_auto_activation_checkpointing_restricts_backends(self): + from torch.nn.attention import SDPBackend + + result = resolve_sdpa_method(None, activation_checkpointing=True) + assert result == [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + + def test_explicit_overrides_auto(self): + """When cfg_sdpa_method is provided, auto-selection is bypassed.""" + from torch.nn.attention import SDPBackend + + mesh = MagicMock() + mesh.mesh_dim_names = ("dp", "cp") + mesh.__getitem__ = lambda self, key: MagicMock(size=lambda: 2) if key == "cp" else 
MagicMock(size=lambda: 1) + + result = resolve_sdpa_method(["math"], device_mesh=mesh, activation_checkpointing=True) + assert result == [SDPBackend.MATH] diff --git a/tests/unit_tests/training/test_train_ft_mlflow_logging.py b/tests/unit_tests/training/test_train_ft_mlflow_logging.py index 212ee4af1..1789559c2 100644 --- a/tests/unit_tests/training/test_train_ft_mlflow_logging.py +++ b/tests/unit_tests/training/test_train_ft_mlflow_logging.py @@ -67,13 +67,26 @@ def test_log_train_metrics_calls_mlflow(monkeypatch): recipe.metric_logger_train = types.SimpleNamespace(log=lambda x: None) mlflow_mock = Mock() recipe.mlflow_logger = types.SimpleNamespace(log_metrics=mlflow_mock) + recipe.comet_logger = None # Avoid cuda calls on environments without GPUs import torch.cuda monkeypatch.setattr(torch.cuda, "reset_peak_memory_stats", lambda: None, raising=False) - log_data = MetricsSample(step=7, epoch=1, metrics={"loss": 1.23, "grad_norm": 0.5, "lr": 1e-3, "mem": 0.1, "tps": 10.0, "tps_per_gpu": 5.0, "num_label_tokens": 42}) + log_data = MetricsSample( + step=7, + epoch=1, + metrics={ + "loss": 1.23, + "grad_norm": 0.5, + "lr": 1e-3, + "mem": 0.1, + "tps": 10.0, + "tps_per_gpu": 5.0, + "num_label_tokens": 42, + }, + ) recipe.log_train_metrics(log_data) mlflow_mock.assert_called_once() @@ -90,13 +103,14 @@ def test_log_val_metrics_calls_mlflow(monkeypatch): recipe.dist_env = types.SimpleNamespace(is_main=True) mlflow_mock = Mock() recipe.mlflow_logger = types.SimpleNamespace(log_metrics=mlflow_mock) + recipe.comet_logger = None # No JSONL logger passed (None) to keep test minimal - log_data = MetricsSample(step=3, epoch=0, metrics={"val_loss": 0.99, "lr": 5e-4, "num_label_tokens": 100, "mem": 0.2}) + log_data = MetricsSample( + step=3, epoch=0, metrics={"val_loss": 0.99, "lr": 5e-4, "num_label_tokens": 100, "mem": 0.2} + ) recipe.log_val_metrics("default", log_data, metric_logger=None) mlflow_mock.assert_called_once() args, kwargs = mlflow_mock.call_args assert 
isinstance(args[0], dict) and kwargs.get("step") == log_data.step - -