diff --git a/nemo_automodel/_transformers/__init__.py b/nemo_automodel/_transformers/__init__.py index f4c78c06e..cb8a45515 100644 --- a/nemo_automodel/_transformers/__init__.py +++ b/nemo_automodel/_transformers/__init__.py @@ -28,6 +28,7 @@ "NeMoAutoModelForTextToWaveform": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelForTextToWaveform"), "NeMoAutoModelBiencoder": ("nemo_automodel._transformers.auto_model", "NeMoAutoModelBiencoder"), "NeMoAutoTokenizer": ("nemo_automodel._transformers.auto_tokenizer", "NeMoAutoTokenizer"), + "AutoMFU": ("nemo_automodel._transformers.mfu", "AutoMFU"), } __all__ = [ @@ -38,6 +39,7 @@ "NeMoAutoModelForTextToWaveform", "NeMoAutoModelBiencoder", "NeMoAutoTokenizer", + "AutoMFU", ] diff --git a/nemo_automodel/_transformers/mfu.py b/nemo_automodel/_transformers/mfu.py new file mode 100644 index 000000000..4dd169616 --- /dev/null +++ b/nemo_automodel/_transformers/mfu.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AutoMFU: Automatic Model FLOPs Utilization calculator. + +Similar interface to HuggingFace AutoModel, this module provides automatic +MFU calculation for various model architectures. +""" + +import logging +from os import PathLike +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + from transformers import PretrainedConfig + +from nemo_automodel.components.utils.flops_utils import ( + calculate_mfu, + get_flops_formula_for_hf_config, +) + +logger = logging.getLogger(__name__) + +# Device theoretical FLOPS (FLOPs/s) adapted from https://github.com/verl-project/verl/blob/main/verl/utils/flops_counter.py#L22-L85 +_DEVICE_FLOPS: Dict[str, float] = { + "CPU": 448e9, + "GB200": 2.5e15, + "B200": 2.25e15, + "MI300X": 1336e12, + "H100": 989e12, + "H800": 989e12, + "H200": 989e12, + "A100": 312e12, + "A800": 312e12, + "L40S": 362.05e12, + "L40": 181.05e12, + "A40": 149.7e12, + "L20": 119.5e12, + "H20": 148e12, + "910B": 354e12, + "Ascend910": 354e12, + "RTX 3070 Ti": 21.75e12, +} + +_UNIT_TO_SCALE = { + "B": 1e9, + "K": 1e3, + "M": 1e6, + "G": 1e9, + "T": 1e12, + "P": 1e15, +} + + +def get_device_flops(unit: str = "T", device_name: Optional[str] = None) -> float: + """Get theoretical device FLOPS in a requested unit. + + Args: + unit: One of ``B/K/M/G/T/P``. Default ``T`` (TFLOPs/s). + device_name: Optional explicit device name for lookup. If ``None``, + the current torch device name is inferred. + + Returns: + Theoretical FLOPS in requested unit. Returns ``float("inf")`` for + unknown devices. + """ + unit = unit.upper() + if unit not in _UNIT_TO_SCALE: + supported = ", ".join(_UNIT_TO_SCALE.keys()) + raise ValueError(f"Unsupported unit '{unit}'. Supported units: {supported}") + + if device_name is None: + if torch.cuda.is_available(): + device_name = torch.cuda.get_device_name(torch.cuda.current_device()) + else: + device_name = "CPU" + + flops = float("inf") + normalized_device = str(device_name).lower() + for key, value in sorted(_DEVICE_FLOPS.items(), key=lambda kv: len(kv[0]), reverse=True): + if key.lower() in normalized_device: + flops = value + break + + return flops / _UNIT_TO_SCALE[unit] + + +class AutoMFU: + """Auto MFU calculator - provides MFU calculation for various model architectures. + + This class provides a HuggingFace AutoModel-like interface for calculating + Model FLOPs Utilization (MFU) during training. + """ + def __init__(self, config: "PretrainedConfig", device: str = "h100"): + """Initialize AutoMFU with a model config. + + Args: + config: HuggingFace PretrainedConfig object + device: Device name (e.g. ``"h100"``) + """ + self.config = config + self.flops_formula = get_flops_formula_for_hf_config(config) + self.reference_mfu = get_device_flops(unit="T", device_name=device) + + @classmethod + def register_device(cls, device: str, peak_tflops: float) -> None: + """Register or override a device peak TFLOPs entry used for MFU calculation.""" + _DEVICE_FLOPS[str(device)] = float(peak_tflops) * 1e12 + + @classmethod + def from_config( + cls, + config_or_path_or_model: Union["PretrainedConfig", str, PathLike[str], object], + device: str = "h100", + **kwargs, + ) -> "AutoMFU": + """Create AutoMFU from a config object, model object, or model path/ID. + + Args: + config_or_path_or_model: Either a PretrainedConfig object, a model object + (the .config attribute will be extracted), or a model ID/local path. + device: Device name (e.g. ``"h100"``) + **kwargs: Additional arguments passed to AutoConfig.from_pretrained + when loading from model ID/path. + + Returns: + AutoMFU instance + """ + config = config_or_path_or_model + if hasattr(config, "config"): + config = config.config + elif isinstance(config, (str, PathLike)): + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(str(config), **kwargs) + return cls(config, device=device) + + @classmethod + def from_pretrained( + cls, + model_id_or_local_path_or_model: Union[str, PathLike[str], object], + device: str = "h100", + **kwargs, + ) -> "AutoMFU": + """Create AutoMFU from model ID, local path, or a model object. + + Args: + model_id_or_local_path_or_model: Model ID (e.g., "meta-llama/llama-3-70b"), + local path, or model object (the .config attribute will be extracted) + device: Device name (e.g. ``"h100"``) + **kwargs: Additional arguments passed to AutoConfig.from_pretrained + + Returns: + AutoMFU instance + """ + return cls.from_config(model_id_or_local_path_or_model, device=device, **kwargs) + + def __call__( + self, + input_ids_or_tensor: Union[torch.Tensor, Tuple[int, int]], + time_delta: float, + world_size: int, + ) -> Optional[float]: + """Calculate MFU percentage. + + Args: + input_ids_or_tensor: Either a tensor (batch_size, seq_len) or + a tuple of (batch_size, seq_len) + time_delta: Time taken for forward/backward pass in seconds + world_size: Number of GPUs used for training + + Returns: + MFU as a percentage, or None if model not supported + """ + if self.flops_formula is None: + return None + + if hasattr(input_ids_or_tensor, "shape"): + batch_size, seq_len = input_ids_or_tensor.shape[:2] + else: + batch_size, seq_len = input_ids_or_tensor + + flops = self.flops_formula(self.config, gbs=batch_size, seq_len=seq_len) + tflops = flops / 1e12 + return calculate_mfu(tflops, world_size, time_delta, reference_mfu=self.reference_mfu) + + def get_flops( + self, + input_ids_or_tensor: Union[torch.Tensor, Tuple[int, int]], + ) -> Optional[float]: + """Calculate FLOPs for given input shape. + + Args: + input_ids_or_tensor: Either a tensor (batch_size, seq_len) or + a tuple of (batch_size, seq_len) + + Returns: + FLOPs as a float, or None if model not supported + """ + if self.flops_formula is None: + return None + + if hasattr(input_ids_or_tensor, "shape"): + batch_size, seq_len = input_ids_or_tensor.shape[:2] + else: + batch_size, seq_len = input_ids_or_tensor + + return self.flops_formula(self.config, gbs=batch_size, seq_len=seq_len) diff --git a/nemo_automodel/recipes/llm/train_ft.py b/nemo_automodel/recipes/llm/train_ft.py index f13e2b44c..9fc6d3abb 100644 --- a/nemo_automodel/recipes/llm/train_ft.py +++ b/nemo_automodel/recipes/llm/train_ft.py @@ -38,6 +38,7 @@ apply_model_infrastructure, instantiate_infrastructure, ) +from nemo_automodel._transformers.mfu import AutoMFU from nemo_automodel._transformers.utils import apply_cache_compatibility_patches from nemo_automodel.components.checkpoint.checkpointing import ( Checkpointer, @@ -79,6 +80,7 @@ from nemo_automodel.components.utils.compile_utils import ( build_compile_config, ) +from nemo_automodel.components.utils.flops_utils import calculate_mfu from nemo_automodel.components.utils.model_utils import ( _supports_logits_to_keep, _supports_seq_lens, @@ -1043,6 +1045,8 @@ def setup(self): for mp in self.model_parts: enable_load_balance_tracking(mp) + self.mfu_calculator = AutoMFU.from_config(self.model_parts[0]) + restore_from = self.cfg.get("checkpoint.restore_from", None) # Initialize JSONL loggers self.metric_logger_train = build_metric_logger( @@ -1411,6 +1415,28 @@ def _run_train_optim_step(self, batches, max_grad_norm: Optional[float] = None): time_delta = t - self.timestamp self.timestamp = t tps = num_tokens_in_batch / time_delta + + mfu = None + if batches: + step_flops = 0.0 + flops_supported = True + for batch in batches: + input_ids = batch.get("input_ids") + if input_ids is None: + flops_supported = False + break + batch_flops = self.mfu_calculator.get_flops(input_ids) + if batch_flops is None: + flops_supported = False + break + step_flops += float(batch_flops) + + if flops_supported: + step_flops = self._dp_allreduce( + torch.tensor(step_flops, dtype=torch.float64, device=self.dist_env.device), include_cp=True + ).item() + mfu = calculate_mfu(step_flops / 1e12, self.dist_env.world_size, time_delta) + reporting_loss = torch.sum(torch.stack(loss_buffer)) reporting_loss = self._dp_allreduce(reporting_loss, include_cp=True) if self.pp_enabled: @@ -1436,6 +1462,7 @@ def _run_train_optim_step(self, batches, max_grad_norm: Optional[float] = None): "mem": torch.cuda.max_memory_allocated() / 1024**3, "tps": tps, "tps_per_gpu": tps / self._get_cp_group_size() / max(self._get_dp_group_size(), 1), + "mfu": mfu, "num_tokens_per_step": num_tokens_in_batch, "num_label_tokens": num_label_tokens, }, diff --git a/nemo_automodel/recipes/llm/train_seq_cls.py b/nemo_automodel/recipes/llm/train_seq_cls.py index dc1b8f17e..c5ecd513e 100644 --- a/nemo_automodel/recipes/llm/train_seq_cls.py +++ b/nemo_automodel/recipes/llm/train_seq_cls.py @@ -21,6 +21,7 @@ import torch import wandb +from nemo_automodel._transformers.mfu import AutoMFU from nemo_automodel._transformers.utils import apply_cache_compatibility_patches from nemo_automodel.components.config._arg_parser import parse_args_and_load_config from nemo_automodel.components.loggers.log_utils import setup_logging @@ -28,6 +29,7 @@ from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages from nemo_automodel.components.training.rng import StatefulRNG from nemo_automodel.components.training.utils import clip_grad_norm +from nemo_automodel.components.utils.flops_utils import calculate_mfu from nemo_automodel.recipes._dist_setup import setup_distributed from nemo_automodel.recipes.base_recipe import BaseRecipe from nemo_automodel.recipes.llm.train_ft import ( @@ -119,6 +121,7 @@ def setup(self): self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) self.model_parts = [model] + self.mfu_calculator = AutoMFU.from_config(self.model_parts[0]) self.dataloader, self.tokenizer = build_dataloader( self.cfg.dataset, @@ -275,6 +278,27 @@ def _run_train_optim_step(self, batches): self.timestamp = t tps = num_tokens_in_batch / time_delta + mfu = None + if batches: + step_flops = 0.0 + flops_supported = True + for batch in batches: + input_ids = batch.get("input_ids") + if input_ids is None: + flops_supported = False + break + batch_flops = self.mfu_calculator.get_flops(input_ids) + if batch_flops is None: + flops_supported = False + break + step_flops += float(batch_flops) + + if flops_supported: + step_flops = self._dp_allreduce( + torch.tensor(step_flops, dtype=torch.float64, device=self.dist_env.device), include_cp=True + ).item() + mfu = calculate_mfu(step_flops / 1e12, self.dist_env.world_size, time_delta) + total_loss = torch.sum(torch.stack(losses)) total_loss = self._dp_allreduce(total_loss, include_cp=True).detach() loss = total_loss / len(batches) @@ -290,6 +314,7 @@ def _run_train_optim_step(self, batches): "mem": torch.cuda.max_memory_allocated() / 1024**3, "tps": tps, "tps_per_gpu": tps / self._get_cp_group_size() / max(self._get_dp_group_size(), 1), + "mfu": mfu, }, )