diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 409c726d10c..7a3b826dc0e 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -31,6 +31,7 @@ ) from megatron.core.inference.engines.abstract_engine import AbstractEngine from megatron.core.inference.headers import Headers, UnknownHeaderError +from megatron.core.inference.inference_flops import InferenceFLOPsCalculator from megatron.core.inference.inference_request import ( DynamicInferenceEvent, DynamicInferenceEventType, @@ -226,6 +227,22 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.unified_memory_level = inference_config.unified_memory_level self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = model_config.cuda_graph_scope + + # Initialize inference FLOPs calculator and GPU peak for MFU reporting. + self.flops_calculator = None + self.gpu_peak_tflops = 0.0 + self.cumulative_inference_flops = 0.0 + self.cumulative_inference_time = 0.0 + try: + from megatron.training.global_vars import get_args + from megatron.training.gpu_peak_flops import get_gpu_peak_tflops + + args = get_args() + self.flops_calculator = InferenceFLOPsCalculator.from_args(args) + self.gpu_peak_tflops = get_gpu_peak_tflops() + except Exception as e: + logging.warning(f"Could not initialize inference FLOPs calculator: {e}") + # Initialize engine. self.reset() @@ -1656,6 +1673,35 @@ async def async_bookkeep( self.socket_for_receiving_requests.send(payload) range_pop() + # Compute inference FLOPs for this step. 
+ step_flops_info = None + if self.flops_calculator is not None: + batch_dims = self.context.batch_dimensions + decode_tokens = batch_dims.decode_req_count if batch_dims else 0 + prefill_reqs = batch_dims.prefill_req_count if batch_dims else 0 + total_tokens = batch_dims.token_count if batch_dims else 0 + prefill_tokens = total_tokens - decode_tokens + + step_flops_info = self.flops_calculator.compute_step_flops( + decode_tokens=decode_tokens, + prefill_tokens=prefill_tokens, + total_tokens=total_tokens, + active_blocks=context_state["total_active_used_blocks"], + active_reqs=context_state["total_request_count"] + - context_state["paused_request_count"], + num_prefill_reqs=prefill_reqs, + ) + self.cumulative_inference_flops += step_flops_info['total_flops'] + self.cumulative_inference_time += step_time + try: + from megatron.training.mfu_tracker import get_mfu_tracker + + get_mfu_tracker().add_inference_flops( + step_flops_info['total_flops'], step_time, tokens=total_tokens + ) + except Exception: + pass + # Log KV cache utilization stats to W&B if context_state["kv_stats"] is not None: # Prepare metrics dictionary with all stats @@ -1668,6 +1714,32 @@ async def async_bookkeep( 'inference/waiting_queue_len': int(len(self.waiting_request_ids)), 'inference/total_requests_dict_size': int(len(self.requests)), } + + batch_dims = self.context.batch_dimensions + total_tokens = batch_dims.token_count if batch_dims else 0 + if step_time > 0 and total_tokens > 0: + metrics['inference/tokens_per_sec_per_gpu'] = float(total_tokens / step_time) + + if step_flops_info is not None: + step_tflops = step_flops_info['total_flops'] / 1e12 + step_throughput = step_tflops / step_time if step_time > 0 else 0 + metrics['inference/step_flops_tflop'] = float(step_tflops) + metrics['inference/throughput_tflops_per_gpu'] = float(step_throughput) + metrics['inference/t_avg'] = float(step_flops_info['t_avg']) + metrics['inference/cumulative_flops_tflop'] = float( + 
self.cumulative_inference_flops / 1e12 + ) + if self.gpu_peak_tflops > 0: + mfu = step_throughput / self.gpu_peak_tflops * 100.0 + cumulative_throughput = ( + (self.cumulative_inference_flops / 1e12) / self.cumulative_inference_time + if self.cumulative_inference_time > 0 + else 0 + ) + cumulative_mfu = cumulative_throughput / self.gpu_peak_tflops * 100.0 + metrics['inference/mfu_percent'] = float(mfu) + metrics['inference/cumulative_mfu_percent'] = float(cumulative_mfu) + # Add KV stats with inference/ prefix # Convert utilization metrics from 0-1 range to 0-100 percentage range for better visualization for key, value in context_state["kv_stats"].items(): @@ -1742,6 +1814,18 @@ async def async_bookkeep( self._spec_tokens_proposed, self._spec_steps, ) + batch_dims = self.context.batch_dimensions + total_tokens = batch_dims.token_count if batch_dims else 0 + if step_time > 0 and total_tokens > 0: + toks_per_sec_per_gpu = total_tokens / step_time + output_str += f" toks/s/GPU: {toks_per_sec_per_gpu:.0f}," + if step_flops_info is not None: + step_tflops = step_flops_info['total_flops'] / 1e12 + step_throughput = step_tflops / step_time if step_time > 0 else 0 + output_str += f" {step_throughput:.1f} TFLOP/s/GPU" + if self.gpu_peak_tflops > 0: + mfu = step_throughput / self.gpu_peak_tflops * 100.0 + output_str += f", MFU: {mfu:.1f}%" if context_state["is_decode_only"]: output_str = f"\033[94m{output_str}\033[0m" logging.info(output_str) diff --git a/megatron/core/inference/inference_flops.py b/megatron/core/inference/inference_flops.py new file mode 100644 index 00000000000..1b783a49bee --- /dev/null +++ b/megatron/core/inference/inference_flops.py @@ -0,0 +1,245 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Inference FLOPs calculator for hybrid Mamba/Attention/MoE models. + +Computes forward-pass FLOPs per inference step using model architecture +parameters. 
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""Inference FLOPs calculator for hybrid Mamba/Attention/MoE models.

Computes forward-pass FLOPs per inference step from model architecture
parameters. Used by the dynamic inference engine to report per-step
FLOPs and MFU (Model FLOPs Utilization).

Reference: nemotron6_3b_moe_flops_equations.md
"""

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class InferenceFLOPsConfig:
    """Model architecture parameters needed for FLOPs calculation."""

    hidden_size: int = 0  # model hidden dimension (h)
    padded_vocab_size: int = 0  # vocab size of the output (LM head) projection
    num_attention_heads: int = 0
    num_query_groups: int = 0  # GQA: number of KV groups
    kv_channels: int = 128  # per-head dimension (d_h)
    mamba_num_heads: int = 0
    mamba_head_dim: int = 64
    mamba_state_dim: int = 128
    mamba_num_groups: int = 8
    d_conv: int = 4  # Mamba conv1d kernel width
    num_experts: int = 0
    moe_router_topk: int = 1
    moe_ffn_hidden_size: int = 0
    moe_shared_expert_intermediate_size: int = 0
    ffn_hidden_size: int = 0
    swiglu: bool = False  # gated MLP: fc1 counts 1.5x in FLOPs terms
    num_mamba_layers: int = 0
    num_attention_layers: int = 0
    num_moe_layers: int = 0
    num_mlp_layers: int = 0
    block_size: int = 256  # KV-cache block size (tokens per block)


class InferenceFLOPsCalculator:
    """Computes forward-pass FLOPs per inference step.

    All constant per-token FLOPs terms are precomputed at init time so that
    `compute_step_flops()` is cheap enough to call on every engine step.
    """

    def __init__(self, config: InferenceFLOPsConfig):
        self.config = config
        h = config.hidden_size

        # Mamba layer FLOPs per token (constant, no seq-length dependence).
        d_inner = config.mamba_num_heads * config.mamba_head_dim
        in_proj_dim = (
            2 * d_inner
            + 2 * config.mamba_num_groups * config.mamba_state_dim
            + config.mamba_num_heads
        )
        conv_channels = d_inner + 2 * config.mamba_num_groups * config.mamba_state_dim

        self.f_mamba_per_token = (
            2 * h * in_proj_dim  # in_proj
            + 2 * conv_channels * config.d_conv  # conv1d
            + 5 * config.mamba_num_heads * config.mamba_state_dim * config.mamba_head_dim  # SSM
            + 2 * d_inner * h  # out_proj
        )

        # Attention layer FLOPs per token (fixed part, excluding Q·K^T and attn·V).
        qkv_dim = (
            config.num_attention_heads * config.kv_channels
            + 2 * config.num_query_groups * config.kv_channels
        )
        q_proj_size = config.num_attention_heads * config.kv_channels
        self.f_attn_fixed_per_token = (
            2 * h * qkv_dim  # QKV projection
            + 2 * q_proj_size * h  # output projection
        )

        # Attention variable FLOPs coefficient: 4 * n_h * d_h per context position.
        self.f_attn_per_t = 4 * config.num_attention_heads * config.kv_channels

        # MoE layer FLOPs per token.
        scale_factor = 3.0 / 2.0 if config.swiglu else 1.0
        moe_ffn = (
            config.moe_ffn_hidden_size if config.moe_ffn_hidden_size else config.ffn_hidden_size
        )
        self.f_moe_per_token = (
            2 * h * config.num_experts  # router
            + 4 * h * moe_ffn * config.moe_router_topk * scale_factor  # routed experts fc1+fc2
            + 4
            * h
            * config.moe_shared_expert_intermediate_size
            * scale_factor  # shared expert fc1+fc2
        )

        # Dense MLP layer FLOPs per token (for hybrid models with '-' pattern layers).
        self.f_mlp_per_token = 4 * h * config.ffn_hidden_size * scale_factor

        # Output (LM head) FLOPs per token.
        self.f_output_per_token = 2 * h * config.padded_vocab_size

        # Total fixed FLOPs per token (no attention variable term).
        self.f_fixed_per_token = (
            config.num_mamba_layers * self.f_mamba_per_token
            + config.num_attention_layers * self.f_attn_fixed_per_token
            + config.num_moe_layers * self.f_moe_per_token
            + config.num_mlp_layers * self.f_mlp_per_token
            + self.f_output_per_token
        )

        # Total attention variable coefficient per token.
        self.f_attn_var_coeff = config.num_attention_layers * self.f_attn_per_t

        self.block_size = config.block_size

        logger.info(
            f"InferenceFLOPsCalculator initialized: "
            f"F_fixed={self.f_fixed_per_token/1e9:.2f}B/tok, "
            f"F_attn_var={self.f_attn_var_coeff:,}/t, "
            f"layers: {config.num_mamba_layers}M+{config.num_attention_layers}A+"
            f"{config.num_moe_layers}E+{config.num_mlp_layers}D"
        )

    def compute_step_flops(
        self,
        decode_tokens: int,
        prefill_tokens: int,
        total_tokens: int,
        active_blocks: int,
        active_reqs: int,
        num_prefill_reqs: int = 0,
    ) -> dict:
        """Compute FLOPs for a single inference step.

        Args:
            decode_tokens: Number of decode tokens (= number of decode requests).
            prefill_tokens: Number of prefill tokens (= total_tokens - decode_tokens).
            total_tokens: Total tokens processed this step.
            active_blocks: Number of active KV-cache blocks.
            active_reqs: Number of active requests.
            num_prefill_reqs: Number of prefill requests.

        Returns:
            dict with 'decode_flops', 'prefill_flops', 'total_flops', 't_avg'.
        """
        # Estimate the average context length from KV-cache block occupancy.
        # NOTE: this over-estimates by up to block_size-1 per request, since a
        # partially-filled block counts as a full one.
        t_avg = (active_blocks * self.block_size) / active_reqs if active_reqs > 0 else 0

        # Decode FLOPs: each decode token attends over ~t_avg context positions.
        decode_flops = decode_tokens * (self.f_fixed_per_token + self.f_attn_var_coeff * t_avg)

        # Prefill FLOPs: linear per-token term plus the quadratic causal
        # attention term, estimated from the average prompt length.
        prefill_flops = 0.0
        if prefill_tokens > 0:
            prefill_flops_linear = prefill_tokens * self.f_fixed_per_token
            if num_prefill_reqs > 0:
                avg_prompt_len = prefill_tokens / num_prefill_reqs
                prefill_attn_quad = (
                    self.config.num_attention_layers
                    * num_prefill_reqs
                    * 2
                    * self.config.num_attention_heads
                    * self.config.kv_channels
                    * avg_prompt_len**2
                )
            else:
                prefill_attn_quad = 0
            prefill_flops = prefill_flops_linear + prefill_attn_quad

        total_flops = decode_flops + prefill_flops
        return {
            'decode_flops': decode_flops,
            'prefill_flops': prefill_flops,
            'total_flops': total_flops,
            't_avg': t_avg,
        }

    @classmethod
    def from_args(cls, args) -> "InferenceFLOPsCalculator":
        """Create calculator from megatron args (get_args()).

        Automatically detects per-component layer counts from
        hybrid_override_pattern when present; otherwise falls back to the
        hybrid ratios, and finally to a standard transformer layout.
        """
        num_attn = 0
        num_mamba = 0
        num_mlp = 0
        num_moe = 0

        if getattr(args, 'hybrid_override_pattern', None):
            from megatron.core.ssm.mamba_hybrid_layer_allocation import parse_hybrid_pattern

            parsed = parse_hybrid_pattern(args.hybrid_override_pattern)
            counts = {'M': 0, '*': 0, '-': 0, 'E': 0}
            if parsed.main_pattern:
                for lt in parsed.main_pattern:
                    if lt in counts:
                        counts[lt] += 1
            num_attn, num_mamba, num_mlp, num_moe = (
                counts['*'],
                counts['M'],
                counts['-'],
                counts['E'],
            )
        elif getattr(args, 'is_hybrid_model', False):
            num_attn = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba = args.num_layers - num_attn - num_mlp
        else:
            # Standard transformer: every layer has self-attention AND an FFN
            # (dense MLP or MoE). Counting only the attention component here
            # would silently drop all FFN/MoE FLOPs for non-hybrid models.
            num_attn = args.num_layers
            num_mamba = 0
            if getattr(args, 'num_experts', 0):
                num_moe = args.num_layers
                num_mlp = 0
            else:
                num_moe = 0
                num_mlp = args.num_layers

        block_size = getattr(args, 'inference_dynamic_batching_block_size', 256)

        config = InferenceFLOPsConfig(
            hidden_size=args.hidden_size,
            padded_vocab_size=args.padded_vocab_size,
            num_attention_heads=args.num_attention_heads,
            num_query_groups=getattr(args, 'num_query_groups', args.num_attention_heads),
            kv_channels=getattr(args, 'kv_channels', args.hidden_size // args.num_attention_heads),
            mamba_num_heads=getattr(args, 'mamba_num_heads', 0) or 0,
            mamba_head_dim=getattr(args, 'mamba_head_dim', 64) or 64,
            mamba_state_dim=getattr(args, 'mamba_state_dim', 128) or 128,
            mamba_num_groups=getattr(args, 'mamba_num_groups', 8) or 8,
            d_conv=getattr(args, 'mamba_d_conv', 4) or 4,
            num_experts=getattr(args, 'num_experts', 0) or 0,
            moe_router_topk=getattr(args, 'moe_router_topk', 1) or 1,
            moe_ffn_hidden_size=getattr(args, 'moe_ffn_hidden_size', 0) or 0,
            moe_shared_expert_intermediate_size=getattr(
                args, 'moe_shared_expert_intermediate_size', 0
            )
            or 0,
            ffn_hidden_size=args.ffn_hidden_size,
            swiglu=getattr(args, 'swiglu', False),
            num_mamba_layers=num_mamba,
            num_attention_layers=num_attn,
            num_moe_layers=num_moe,
            num_mlp_layers=num_mlp,
            block_size=block_size,
        )
        return cls(config)
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""GPU peak FLOPS lookup for MFU (Model FLOPs Utilization) calculation.

Peak BF16 dense tensor core FLOPS for supported GPU architectures.
These values are per-GPU and do NOT include structured sparsity.
"""

import logging

import torch

logger = logging.getLogger(__name__)

# BF16 dense tensor core peak TFLOP/s per GPU.
# Dense BF16 tensor core peak (no structured sparsity);
# sparse (2:4) peak is 2x these values.
_GPU_PEAK_TFLOPS_BF16 = {
    "A100-SXM": 312,
    "A100-PCIE": 312,
    "H100-SXM": 989,
    "H100-PCIE": 756,
    "H200": 989,
    "B100": 1750,
    "B200": 2250,
    "GB200": 2250,
}


def get_gpu_peak_tflops(dtype=torch.bfloat16):
    """Detect GPU model and return peak TFLOP/s for the given dtype.

    Currently only supports BF16. Returns the dense (non-sparse) tensor core
    peak. Falls back to a conservative per-architecture-generation estimate
    when the specific model is not in the lookup table.

    Returns:
        float: Peak TFLOP/s for one GPU, or 0.0 if detection fails.
    """
    # Only BF16 is tabulated; any other dtype is unsupported.
    if dtype != torch.bfloat16:
        logger.warning(f"MFU peak lookup only supports BF16, got {dtype}. Returning 0.")
        return 0.0

    try:
        device_name = torch.cuda.get_device_name(0)
    except Exception:
        logger.warning("Could not detect GPU. MFU will not be reported.")
        return 0.0

    # Compare with dashes/spaces stripped so e.g. "A100-SXM4" matches "A100 SXM".
    normalized_name = device_name.upper().replace("-", "").replace(" ", "")
    for model, peak in _GPU_PEAK_TFLOPS_BF16.items():
        if model.upper().replace("-", "") not in normalized_name:
            continue
        logger.info(f"GPU detected: {device_name} -> peak BF16: {peak} TFLOP/s")
        return float(peak)

    # Unknown model: fall back on the compute-capability generation.
    try:
        major = torch.cuda.get_device_properties(0).major
        arch_defaults = {8: 312, 9: 989, 10: 2250}  # Ampere, Hopper, Blackwell
        if major in arch_defaults:
            peak = arch_defaults[major]
            logger.warning(
                f"GPU '{device_name}' not in lookup table. "
                f"Using arch-generation fallback (sm_{major}0): {peak} TFLOP/s"
            )
            return float(peak)
    except Exception:
        pass

    logger.warning(f"Unknown GPU '{device_name}'. MFU will not be reported.")
    return 0.0
+ +Usage: + from megatron.training.mfu_tracker import get_mfu_tracker + tracker = get_mfu_tracker() + tracker.add_inference_flops(flops, time_s) + tracker.add_training_flops(flops, time_s) + report = tracker.get_report(gpu_peak_tflops) +""" + +import threading + + +class MFUTracker: + """Thread-safe tracker for inference and training FLOPs/time.""" + + def __init__(self): + self._lock = threading.Lock() + self.reset() + + def reset(self): + with self._lock: + self._inference_flops = 0.0 + self._inference_time = 0.0 + self._inference_tokens = 0 + self._training_flops = 0.0 + self._training_time = 0.0 + self._training_tokens = 0 + # Per-iteration accumulators (reset each RL iteration) + self._iter_inference_flops = 0.0 + self._iter_inference_time = 0.0 + self._iter_inference_tokens = 0 + self._iter_logprob_time = 0.0 + self._iter_real_training_tokens = 0 + + def add_inference_flops(self, flops: float, time_s: float, tokens: int = 0): + """Called by the inference engine each step.""" + with self._lock: + self._inference_flops += flops + self._inference_time += time_s + self._inference_tokens += tokens + self._iter_inference_flops += flops + self._iter_inference_time += time_s + self._iter_inference_tokens += tokens + + def add_training_flops(self, flops: float, time_s: float, tokens: int = 0): + """Called by the training loop each iteration.""" + with self._lock: + self._training_flops += flops + self._training_time += time_s + self._training_tokens += tokens + + def get_iter_inference_flops(self) -> float: + """Get inference FLOPs accumulated since last reset_iter().""" + with self._lock: + return self._iter_inference_flops + + def get_iter_inference_time(self) -> float: + """Get inference time accumulated since last reset_iter().""" + with self._lock: + return self._iter_inference_time + + def get_iter_inference_tokens(self) -> int: + """Get inference tokens accumulated since last reset_iter().""" + with self._lock: + return self._iter_inference_tokens + + def 
add_logprob_time(self, time_s: float): + """Called after the compute-logprobs phase each RL iteration.""" + with self._lock: + self._iter_logprob_time += time_s + + def get_iter_logprob_time(self) -> float: + with self._lock: + return self._iter_logprob_time + + def set_iter_real_training_tokens(self, tokens: int): + """Set the real (non-padding) training token count for this iteration.""" + with self._lock: + self._iter_real_training_tokens = tokens + + def get_iter_real_training_tokens(self) -> int: + with self._lock: + return self._iter_real_training_tokens + + def reset_iter(self): + """Reset per-iteration accumulators.""" + with self._lock: + self._iter_inference_flops = 0.0 + self._iter_inference_time = 0.0 + self._iter_inference_tokens = 0 + self._iter_logprob_time = 0.0 + self._iter_real_training_tokens = 0 + + def save_iter(self) -> dict: + """Snapshot per-iteration accumulators so they can be restored later. + + Used around evaluation to prevent eval inference from polluting + training throughput metrics. + """ + with self._lock: + return { + 'inference_flops': self._iter_inference_flops, + 'inference_time': self._iter_inference_time, + 'inference_tokens': self._iter_inference_tokens, + 'logprob_time': self._iter_logprob_time, + 'real_training_tokens': self._iter_real_training_tokens, + } + + def restore_iter(self, snapshot: dict): + """Restore per-iteration accumulators from a previous snapshot.""" + with self._lock: + self._iter_inference_flops = snapshot['inference_flops'] + self._iter_inference_time = snapshot['inference_time'] + self._iter_inference_tokens = snapshot['inference_tokens'] + self._iter_logprob_time = snapshot['logprob_time'] + self._iter_real_training_tokens = snapshot['real_training_tokens'] + + def get_report(self, gpu_peak_tflops: float) -> dict: + """Compute MFU breakdown. + + All FLOPs stored in this tracker are per-GPU. + + Args: + gpu_peak_tflops: Peak BF16 TFLOP/s for one GPU. 
+ + Returns: + dict with keys: inference_tflops, inference_time, inference_mfu, + training_tflops, training_time, training_mfu, + total_tflops, total_time, total_mfu. + """ + with self._lock: + inf_tflops = self._inference_flops / 1e12 + inf_time = self._inference_time + inf_tokens = self._inference_tokens + train_tflops = self._training_flops / 1e12 + train_time = self._training_time + train_tokens = self._training_tokens + + total_tflops = inf_tflops + train_tflops + total_time = inf_time + train_time + total_tokens = inf_tokens + train_tokens + + def _mfu(tflops, time_s): + if time_s <= 0 or gpu_peak_tflops <= 0: + return 0.0 + return tflops / time_s / gpu_peak_tflops * 100.0 + + def _toks_per_sec(tokens, time_s): + if time_s <= 0: + return 0.0 + return tokens / time_s + + return { + 'inference_tflops': inf_tflops, + 'inference_time': inf_time, + 'inference_tokens': inf_tokens, + 'inference_throughput': inf_tflops / inf_time if inf_time > 0 else 0, + 'inference_mfu': _mfu(inf_tflops, inf_time), + 'inference_toks_per_sec_per_gpu': _toks_per_sec(inf_tokens, inf_time), + 'training_tflops': train_tflops, + 'training_time': train_time, + 'training_tokens': train_tokens, + 'training_throughput': train_tflops / train_time if train_time > 0 else 0, + 'training_mfu': _mfu(train_tflops, train_time), + 'training_toks_per_sec_per_gpu': _toks_per_sec(train_tokens, train_time), + 'total_tflops': total_tflops, + 'total_time': total_time, + 'total_tokens': total_tokens, + 'total_throughput': total_tflops / total_time if total_time > 0 else 0, + 'total_mfu': _mfu(total_tflops, total_time), + 'total_toks_per_sec_per_gpu': _toks_per_sec(total_tokens, total_time), + } + + +_GLOBAL_TRACKER = None +_TRACKER_LOCK = threading.Lock() + + +def get_mfu_tracker() -> MFUTracker: + """Get or create the global MFU tracker singleton.""" + global _GLOBAL_TRACKER + if _GLOBAL_TRACKER is None: + with _TRACKER_LOCK: + if _GLOBAL_TRACKER is None: + _GLOBAL_TRACKER = MFUTracker() + return 
_GLOBAL_TRACKER diff --git a/megatron/training/training.py b/megatron/training/training.py index 468f46622a1..1170b7c30ed 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -65,6 +65,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): has_rl_utils = True except ImportError: has_rl_utils = False + from megatron.rl.parallel_utils import build_inference_pg_collection try: from modelopt.torch.distill.plugins.megatron import ( @@ -2125,11 +2126,131 @@ def training_log( ) if args.log_throughput: log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' + + tokens_this_iter = batch_size * args.seq_length + + # Compute and log MFU (Model FLOPs Utilization) + if not hasattr(args, '_gpu_peak_tflops'): + try: + from megatron.training.gpu_peak_flops import get_gpu_peak_tflops + args._gpu_peak_tflops = get_gpu_peak_tflops() + except Exception: + args._gpu_peak_tflops = 0.0 + + training_mfu = 0.0 + inference_mfu = 0.0 + total_mfu = 0.0 + has_tracker = False + iter_inference_tokens = 0 + iter_inference_time = 0.0 + iter_logprob_time = 0.0 + training_only_time = elapsed_time_per_iteration + training_flops = 0.0 + iter_inference_flops = 0.0 + effective_tokens = tokens_this_iter + + # Read compute-logprobs time from the existing Megatron timer + try: + iter_logprob_time = ( + timers('compute-logprobs').elapsed(reset=False, barrier=False) + / total_iterations + ) + except Exception: + pass + + if args._gpu_peak_tflops > 0: + try: + from megatron.training.mfu_tracker import get_mfu_tracker + tracker = get_mfu_tracker() + # Normalize to per-GPU to match inference FLOPs (already per-GPU). + training_flops = num_floating_point_operations(args, batch_size) / args.world_size + # Normalize tracker totals to per-iteration averages (the + # tracker accumulates across the entire log_interval window, + # but training_only_time is computed per-iteration). 
+ iter_inference_time = tracker.get_iter_inference_time() / total_iterations + iter_inference_flops = tracker.get_iter_inference_flops() / total_iterations + iter_inference_tokens = tracker.get_iter_inference_tokens() / total_iterations + real_training_tokens = tracker.get_iter_real_training_tokens() + if real_training_tokens > 0: + effective_tokens = real_training_tokens + training_only_time = max( + elapsed_time_per_iteration - iter_inference_time - iter_logprob_time, 1e-6 + ) + tracker.add_training_flops( + training_flops, training_only_time, tokens=effective_tokens + ) + tracker.reset_iter() + has_tracker = True + except Exception: + has_tracker = False + + training_mfu = throughput / args._gpu_peak_tflops * 100.0 + + ws = args.world_size + + # Per-iteration toks/s/GPU breakdown (uses real tokens when seq packing is active) + train_tps = effective_tokens / (training_only_time * ws) if training_only_time > 0 else 0.0 + inf_tps = iter_inference_tokens / (iter_inference_time * ws) if iter_inference_time > 0 else 0.0 + total_tps = (effective_tokens + iter_inference_tokens) / (elapsed_time_per_iteration * ws) if elapsed_time_per_iteration > 0 else 0.0 + e2e_tps = effective_tokens / (elapsed_time_per_iteration * ws) if elapsed_time_per_iteration > 0 else 0.0 + + if has_tracker: + log_string += ( + f' toks/s/GPU: train {train_tps:.0f}' + f', infer {inf_tps:.0f}' + f', total {total_tps:.0f}' + f', e2e {e2e_tps:.0f} |' + ) + + # Per-iteration MFU breakdown + if args._gpu_peak_tflops > 0: + log_string += f' MFU: train {training_mfu:.1f}%' + if has_tracker: + if iter_inference_time > 0: + inference_mfu = ( + iter_inference_flops / iter_inference_time + / 1e12 / args._gpu_peak_tflops * 100.0 + ) + if elapsed_time_per_iteration > 0: + total_mfu = ( + (training_flops + iter_inference_flops) + / elapsed_time_per_iteration + / 1e12 / args._gpu_peak_tflops * 100.0 + ) + log_string += ( + f', infer {inference_mfu:.1f}%' + f', total {total_mfu:.1f}%' + ) + log_string += ' |' + if 
args.log_timers_to_tensorboard: if writer: writer.add_scalar('throughput', throughput, iteration) + writer.add_scalar('toks_per_sec_per_gpu/e2e', e2e_tps, iteration) + if has_tracker: + writer.add_scalar('toks_per_sec_per_gpu/training', train_tps, iteration) + writer.add_scalar('toks_per_sec_per_gpu/inference', inf_tps, iteration) + writer.add_scalar('toks_per_sec_per_gpu/total', total_tps, iteration) + if args._gpu_peak_tflops > 0: + writer.add_scalar('mfu/training_percent', training_mfu, iteration) + if has_tracker: + writer.add_scalar('mfu/inference_percent', inference_mfu, iteration) + writer.add_scalar('mfu/total_percent', total_mfu, iteration) if wandb_writer: - wandb_writer.log({'throughput': throughput}, iteration) + wandb_log = { + 'throughput': throughput, + 'toks_per_sec_per_gpu/e2e': e2e_tps, + } + if has_tracker: + wandb_log['toks_per_sec_per_gpu/training'] = train_tps + wandb_log['toks_per_sec_per_gpu/inference'] = inf_tps + wandb_log['toks_per_sec_per_gpu/total'] = total_tps + if args._gpu_peak_tflops > 0: + wandb_log['mfu/training_percent'] = training_mfu + if has_tracker: + wandb_log['mfu/inference_percent'] = inference_mfu + wandb_log['mfu/total_percent'] = total_mfu + wandb_writer.log(wandb_log, iteration) if args.log_energy: energy = (energy_monitor.lap() / total_iterations) / args.world_size power = energy / elapsed_time_per_iteration @@ -3049,6 +3170,17 @@ def trace_handler(p): if args.log_energy: energy_monitor.pause() timers('interval-time').stop() + + # Save MFU tracker per-iteration state before eval so that + # eval inference (which goes through the same engine) does not + # pollute training throughput / MFU metrics. 
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from types import SimpleNamespace

import pytest

from megatron.core.inference.inference_flops import InferenceFLOPsCalculator
from megatron.training.mfu_tracker import MFUTracker


class TestMFUTracking:
    """End-to-end tests for inference FLOPs calculation and MFU tracking."""

    @pytest.mark.parametrize(
        "decode_tokens,prefill_tokens,active_blocks,active_reqs,num_prefill_reqs",
        [
            (32, 0, 4, 32, 0),  # decode-only
            (0, 128, 2, 2, 2),  # prefill-only
            (16, 64, 4, 18, 2),  # mixed
            (0, 0, 0, 0, 0),  # empty step
        ],
        ids=["decode", "prefill", "mixed", "empty"],
    )
    def test_inference_flops_e2e(
        self, decode_tokens, prefill_tokens, active_blocks, active_reqs, num_prefill_reqs
    ):
        """InferenceFLOPsCalculator.from_args -> compute_step_flops produces consistent results."""
        args = SimpleNamespace(
            hidden_size=256,
            padded_vocab_size=1024,
            num_attention_heads=8,
            num_query_groups=4,
            kv_channels=32,
            ffn_hidden_size=512,
            num_layers=8,
            swiglu=True,
            hybrid_override_pattern="M*M*-EMM*",
            num_experts=4,
            moe_router_topk=2,
            moe_ffn_hidden_size=512,
            moe_shared_expert_intermediate_size=256,
            mamba_num_heads=4,
            mamba_head_dim=64,
            mamba_state_dim=128,
            mamba_num_groups=8,
            mamba_d_conv=4,
            inference_dynamic_batching_block_size=256,
        )
        calc = InferenceFLOPsCalculator.from_args(args)
        total_tokens = decode_tokens + prefill_tokens
        result = calc.compute_step_flops(
            decode_tokens=decode_tokens,
            prefill_tokens=prefill_tokens,
            total_tokens=total_tokens,
            active_blocks=active_blocks,
            active_reqs=active_reqs,
            num_prefill_reqs=num_prefill_reqs,
        )
        # total is always the sum of its two components.
        assert result['total_flops'] == result['decode_flops'] + result['prefill_flops']
        if total_tokens == 0:
            assert result['total_flops'] == 0.0
        else:
            assert result['total_flops'] > 0
        if active_reqs > 0:
            assert result['t_avg'] == active_blocks * 256 / active_reqs
        # Prefill should grow super-linearly (quadratic attention term).
        if prefill_tokens > 0 and num_prefill_reqs > 0:
            r2 = calc.compute_step_flops(
                decode_tokens=0,
                prefill_tokens=prefill_tokens * 2,
                total_tokens=prefill_tokens * 2,
                active_blocks=active_blocks,
                active_reqs=active_reqs,
                num_prefill_reqs=num_prefill_reqs,
            )
            assert r2['prefill_flops'] / result['prefill_flops'] > 2.0

    @pytest.mark.parametrize(
        "inf_flops,inf_time,inf_tokens,train_flops,train_time,train_tokens,peak",
        [
            (50e12, 5.0, 1000, 50e12, 5.0, 2000, 100.0),  # balanced
            (0, 0, 0, 100e12, 10.0, 5000, 100.0),  # training only
            (100e12, 10.0, 3000, 0, 0, 0, 100.0),  # inference only
            (1e12, 1.0, 100, 1e12, 1.0, 100, 0.0),  # zero peak
        ],
        ids=["balanced", "train-only", "infer-only", "zero-peak"],
    )
    def test_mfu_tracker_e2e(
        self, inf_flops, inf_time, inf_tokens, train_flops, train_time, train_tokens, peak
    ):
        """MFUTracker accumulates per-GPU FLOPs and computes correct MFU."""
        tracker = MFUTracker()
        if inf_flops:
            tracker.add_inference_flops(inf_flops, inf_time, tokens=inf_tokens)
        if train_flops:
            tracker.add_training_flops(train_flops, train_time, tokens=train_tokens)

        # Per-iteration accessors match what was added.
        assert tracker.get_iter_inference_flops() == inf_flops
        assert tracker.get_iter_inference_time() == inf_time
        assert tracker.get_iter_inference_tokens() == inf_tokens

        report = tracker.get_report(gpu_peak_tflops=peak)
        total_time = inf_time + train_time
        total_tflops = inf_flops / 1e12 + train_flops / 1e12

        if peak > 0 and total_time > 0:
            expected_mfu = total_tflops / total_time / peak * 100.0
            assert abs(report['total_mfu'] - expected_mfu) < 1e-9
        else:
            assert report['total_mfu'] == 0.0

        # No world_size division — FLOPs are already per-GPU.
        if total_time > 0:
            assert abs(report['total_throughput'] - total_tflops / total_time) < 1e-9

        # reset_iter clears per-iteration but keeps cumulative.
        tracker.reset_iter()
        assert tracker.get_iter_inference_flops() == 0.0
        assert tracker.get_report(gpu_peak_tflops=peak)['total_tflops'] == total_tflops

        # save_iter / restore_iter round-trips correctly.
        tracker.add_inference_flops(1e12, 1.0, tokens=10)
        snapshot = tracker.save_iter()
        tracker.add_inference_flops(9e12, 9.0, tokens=90)
        tracker.restore_iter(snapshot)
        assert tracker.get_iter_inference_flops() == 1e12