diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index 70f729cffb0..1f22fdf98a9 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -38,4 +38,8 @@ if torch.cuda.get_device_capability() >= (8, 9): from .fp4_kernel_hopper import * + # OMNIML-5072 — per-expert axis-0 fake-quant via tensor-of-pointers. + # Generic Triton + CUDA; no special hardware. See VALIDATION_TODO.md. + from .grouped_axis0_fakequant import * + IS_AVAILABLE = True diff --git a/modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py b/modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py new file mode 100644 index 00000000000..9c6cc01959e --- /dev/null +++ b/modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Fused per-expert axis-0 fake-quant Triton kernels for TEGroupedLinear. + +Replaces the stack-then-quantize-then-unbind pattern in modelopt's TEGrouped +plugin (`te_grouped_quantized_linear_fn`) with a single Triton launch that +processes N expert weights in place, with no contiguous-tensor staging. + +Design — tensor of pointers +--------------------------- + +The N expert weights live as separate Parameters (one per expert), so they're +NOT contiguous in HBM. To avoid a `torch.stack` memcopy (the cost AC5 +characterized on OMNIML-5064), we feed the kernel a `[N]` int64 tensor of +expert base pointers. Each Triton program reads its expert's pointer first, +then strides through a block of elements at that address. + +Grid: (N, num_blocks_per_expert). +Program 0 of axis 0 → expert 0, program 1 → expert 1, etc. + +See OMNIML-5072 AC5 (Option B follow-up) for the motivation. + +VALIDATION STATUS (2026-06-11): kernel implemented, numerical fidelity NOT +yet validated against modelopt's reference `fake_quant_impl`, and bench +performance NOT yet measured. See VALIDATION_TODO.md in this directory. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +__all__ = ["grouped_axis0_fakequant", "grouped_axis0_fakequant_backward"] + + +def _torch_dtype_to_tl(dtype: torch.dtype): + """Map a torch dtype to its Triton-language equivalent.""" + return { + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16, + torch.float16: tl.float16, + }[dtype] + + +@triton.jit +def _grouped_axis0_fakequant_fwd_kernel( + weight_ptrs_buf, # int64 [N] — N expert base pointers (cast from .data_ptr()) + output_ptrs_buf, # int64 [N] — N output base pointers + amax_vec_ptr, # [N, 1, 1] (or anything with N as the leading dim) + elements_per_expert, + num_bits, + narrow_range: tl.constexpr, + DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + expert_idx = tl.program_id(axis=0) + block_idx = tl.program_id(axis=1) + + # Per-expert base pointers (loaded once per program). + w_int = tl.load(weight_ptrs_buf + expert_idx) + out_int = tl.load(output_ptrs_buf + expert_idx) + w_ptr = w_int.to(tl.pointer_type(DTYPE)) + out_ptr = out_int.to(tl.pointer_type(DTYPE)) + + # Per-expert amax → quant scale. + # amax is stored as fp32; convert to working precision. + amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32) + # qmax = 2^(num_bits-1) - 1 when narrow_range else 2^(num_bits-1) + # For num_bits=8 narrow_range=True (modelopt default): qmax=127 + qmax = ((1 << (num_bits - 1)) - 1) if narrow_range else (1 << (num_bits - 1)) + qmin = -qmax if narrow_range else -qmax # signed symmetric + scale = amax / qmax + + # Block of elements within this expert. + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < elements_per_expert + + x = tl.load(w_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + # Fake-quant: round(clip(x / scale)) * scale. + # Use scale guards to avoid div-by-zero before _amax is calibrated (early + # batches may carry an _amax of 0; matches modelopt's fake_tensor_quant + # behavior of passing through unchanged). + safe_scale = tl.where(scale > 0.0, scale, 1.0) + q = x / safe_scale + q = tl.maximum(tl.minimum(q, qmax), qmin) + # Round-half-to-even (banker's), matching cuda_ext.fake_tensor_quant exactly. + # libdevice.rint is CUDA's __rint* builtin. Imported via the same path that + # modelopt's nvfp4_quant.py uses (triton.language.extra.cuda.libdevice). + q_rounded = libdevice.rint(q) + out = tl.where(scale > 0.0, q_rounded * scale, x) + + tl.store(out_ptr + offsets, out.to(DTYPE), mask=mask) + + +@triton.jit +def _grouped_axis0_fakequant_bwd_kernel( + weight_ptrs_buf, # int64 [N] — same buffer as fwd + grad_out_ptrs_buf, # int64 [N] — upstream grad pointers (per expert) + grad_in_ptrs_buf, # int64 [N] — output: downstream grad pointers + amax_vec_ptr, # [N, ...] — same buffer as fwd + elements_per_expert, + DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Clip-aware STE backward. + + For each expert i: + grad_in[i] = grad_out[i] if |w[i]| <= amax[i] else 0 + matches modelopt's `_fake_tensor_quant_backward` semantics. + """ + expert_idx = tl.program_id(axis=0) + block_idx = tl.program_id(axis=1) + + w_ptr = tl.load(weight_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) + grad_out_ptr = tl.load(grad_out_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) + grad_in_ptr = tl.load(grad_in_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE)) + + amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32) + + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < elements_per_expert + + # Stay in DTYPE (bf16/fp16) throughout — eliminates fp32 round-trip seen in + # the Btriton2 baseline that capped bwd bandwidth at ~4.2 TB/s vs cuda_ext's + # ~8 TB/s on B300. amax cast to DTYPE once; comparison done in low precision + # (amax values are O(1)-O(10), well within bf16 range). + w = tl.load(w_ptr + offsets, mask=mask, other=0.0) + g = tl.load(grad_out_ptr + offsets, mask=mask, other=0.0) + amax_dt = amax.to(DTYPE) + + # Clip-aware STE: pass through gradient where |w| <= amax, else zero. + pass_through = tl.abs(w) <= amax_dt + grad_in = tl.where(pass_through, g, 0.0) + + tl.store(grad_in_ptr + offsets, grad_in, mask=mask) + + +def _build_ptr_buf(tensors: list[torch.Tensor]) -> torch.Tensor: + """Pack a list of tensors' .data_ptr() into a single int64 tensor on the same device.""" + return torch.tensor( + [t.data_ptr() for t in tensors], + dtype=torch.int64, + device=tensors[0].device, + ) + + +def grouped_axis0_fakequant( + weights: list[torch.Tensor], + amax_vec: torch.Tensor, + num_bits: int = 8, + narrow_range: bool = True, +) -> list[torch.Tensor]: + """Apply per-expert axis-0 fake-quant in a single Triton launch. + + Args: + weights: List of N expert weight tensors. Each must have the same shape + `[out, in]` and same dtype. + amax_vec: Per-expert amax buffer of shape `[N, 1, 1]` (or any shape where + element `i` is expert `i`'s amax). dtype should be float32 for + numerical headroom; the kernel casts to fp32 internally. + num_bits: integer bit-width for the fake-quant. + narrow_range: if True, output range is [-qmax, +qmax]; else [-qmax, +qmax-1]. + modelopt's default is True. + + Returns: + List of N quantized weight tensors, each the same shape and dtype as + the corresponding input. + """ + assert len(weights) >= 1, "grouped_axis0_fakequant requires at least one expert" + N = len(weights) + shape0 = weights[0].shape + dtype0 = weights[0].dtype + device0 = weights[0].device + elements_per_expert = weights[0].numel() + for w in weights[1:]: + assert w.shape == shape0, "all expert weights must share the same shape" + assert w.dtype == dtype0, "all expert weights must share the same dtype" + assert w.device == device0, "all expert weights must share the same device" + + outputs = [torch.empty_like(w) for w in weights] + + weight_ptrs = _build_ptr_buf(weights) + output_ptrs = _build_ptr_buf(outputs) + + # BLOCK_SIZE=2048 was empirically best in the Btriton2 sweep — larger blocks + # (16384 + num_warps=8) regressed both fwd and bwd, likely from worse warp + # occupancy and load coalescing on B300. + BLOCK_SIZE = 2048 + num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE) + grid = (N, num_blocks_per_expert) + + with torch.cuda.device(device0): + _grouped_axis0_fakequant_fwd_kernel[grid]( + weight_ptrs, + output_ptrs, + amax_vec, + elements_per_expert, + num_bits, + narrow_range=narrow_range, + DTYPE=_torch_dtype_to_tl(dtype0), + BLOCK_SIZE=BLOCK_SIZE, + ) + + return outputs + + +def grouped_axis0_fakequant_backward( + weights: list[torch.Tensor], + grad_outputs: list[torch.Tensor], + amax_vec: torch.Tensor, +) -> list[torch.Tensor]: + """Apply per-expert clip-aware STE backward in a single Triton launch. + + Matches modelopt's `_fake_tensor_quant_backward` semantics — gradient + passes through where `|w[i]| <= amax[i]`, else zero. + + Args: + weights: List of N expert weight tensors (the original fwd inputs). + grad_outputs: List of N upstream gradients, one per expert. + amax_vec: Per-expert amax buffer (same shape as in fwd). + + Returns: + List of N downstream gradients, one per expert. + """ + N = len(weights) + assert len(grad_outputs) == N + shape0 = weights[0].shape + dtype0 = weights[0].dtype + device0 = weights[0].device + elements_per_expert = weights[0].numel() + + grad_inputs = [torch.empty_like(w) for w in weights] + + weight_ptrs = _build_ptr_buf(weights) + grad_out_ptrs = _build_ptr_buf(grad_outputs) + grad_in_ptrs = _build_ptr_buf(grad_inputs) + + BLOCK_SIZE = 2048 + num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE) + grid = (N, num_blocks_per_expert) + + with torch.cuda.device(device0): + _grouped_axis0_fakequant_bwd_kernel[grid]( + weight_ptrs, + grad_out_ptrs, + grad_in_ptrs, + amax_vec, + elements_per_expert, + DTYPE=_torch_dtype_to_tl(dtype0), + BLOCK_SIZE=BLOCK_SIZE, + ) + + return grad_inputs diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0b50fd937ae..0c20e5dc8f4 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -40,7 +40,7 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer +from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from ..utils import sync_moe_expert_amax @@ -638,6 +638,62 @@ class _QuantTELayerNormColumnParallelLinear( # Quantized subclasses to support TEGroupedMLP quantization class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear): + def _ep_group(self): + # Return the expert_model_parallel_group iff it's initialized AND has >1 rank. + # _MegatronTEGroupedMLP._setup populates parallel_state with the EP group; + # outside that wrapping it may be unset (e.g. ad-hoc unit tests). + ps = getattr(self, "parallel_state", None) + if ps is None: + return None + ep = ps.expert_model_parallel_group + if not ep.is_initialized() or ep.world_size() <= 1: + return None + return ep + + def _gather_global_per_expert_amax_n_modules(self): + # Per-expert N-modules path (OMNIML-5072 AC4): each rank's N local + # submodules (weight_quantizer_0..weight_quantizer_{N-1}) carry one + # scalar _amax each. Stack them into [N_local] and all-gather across + # the EP group to get the global [N_global] vector. + # + # Returns None when not in per-expert mode or any quantizer lacks + # _amax (caller falls back to the legacy per-quantizer path). + if not getattr(self, "_per_expert_weight_quantizer", False): + return None + local = [] + for i in range(self.num_gemms): + q = self._get_weight_quantizer(i) + q_inner = q[0] if isinstance(q, SequentialQuantizer) else q + amax = getattr(q_inner, "_amax", None) + if amax is None: + return None + local.append(amax.view(()).to(torch.float32)) + v_local = torch.stack(local) # [N_local] + ep = self._ep_group() + if ep is None: + return v_local # EP=1: local IS global. + gathered = [torch.empty_like(v_local) for _ in range(ep.world_size())] + torch.distributed.all_gather(gathered, v_local, group=ep.group) + return torch.cat(gathered, dim=0) # [N_global] + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + # Gather the global per-expert amax ONCE before Megatron's save traversal. + # Stash on a temporary attribute that _process_quantizer_amax reads; the + # live per-expert weight_quantizer_i._amax scalars stay untouched so + # forward keeps working. + # + # Done at the top of sharded_state_dict (not inside _process_quantizer_amax) + # so the EP collective completes BEFORE Megatron's dist-checkpoint save + # kicks off its own default-PG ALLGATHER metadata exchanges. Interleaving + # EP gathers with default-PG collectives deadlocks NCCL. + self._cached_global_per_expert_amax_n_modules = ( + self._gather_global_per_expert_amax_n_modules() + ) + try: + return super().sharded_state_dict(prefix, sharded_offsets, metadata) + finally: + self._cached_global_per_expert_amax_n_modules = None + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for @@ -648,9 +704,44 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): for k, v in state_dict.items() if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) } + # Per-expert N-modules path: each weight_quantizer_i._amax key in the + # saved state-dict carries the gathered [N_global] vector (replicated + # across all EP ranks). Narrow it to this rank's local scalar — pull + # element (ep_rank * N_local + i) out of the global tensor and + # reshape to scalar for the i-th submodule's _amax buffer. + if getattr(self, "_per_expert_weight_quantizer", False): + import re + ep = self._ep_group() + ep_rank = ep.rank() if ep is not None else 0 + ep_size = ep.world_size() if ep is not None else 1 + global_size = self.num_gemms * ep_size + offset = ep_rank * self.num_gemms + pattern = re.compile(r"weight_quantizer_(\d+)\._amax$") + for k in list(filtered_state_dict.keys()): + m = pattern.search(k) + if m is None: + continue + local_i = int(m.group(1)) + v = filtered_state_dict[k] + if v.numel() == global_size: + filtered_state_dict[k] = ( + v.view(global_size)[offset + local_i].view(()) + ) + # else: leave as-is (legacy or EP=1 save format). return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) def _process_quantizer_amax(self, k, v, quantizer_state_dict): + # Per-expert N-modules path: emit the gathered global per-expert + # amax vector under every weight_quantizer_i._amax key (replicated + # across EP ranks). On load, _load_from_state_dict narrows back to + # the local scalar per submodule. Suboptimal disk usage (N copies + # of the same vector per layer) but mirrors B's gather-once-cache + # pattern and avoids surgery into the base-class state-dict + # iteration. + cached = getattr(self, "_cached_global_per_expert_amax_n_modules", None) + if cached is not None and "weight_quantizer_" in k and k.endswith("_amax"): + quantizer_state_dict[k] = cached.view(cached.numel()) + return assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" quantizer_state_dict[k] = v.view(-1) diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index e3a87927fd3..2fda9633394 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -29,6 +29,7 @@ from modelopt.torch.quantization.utils import replace_function +import modelopt.torch.kernels.quantization.gemm as _triton_kernels from ..nn import QuantModuleRegistry, SequentialQuantizer from .custom import _ParallelLinear @@ -110,6 +111,39 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs): _quantized_linear_fn = te_quantized_linear_fn +class _GroupedAxis0FakeQuantFn(torch.autograd.Function): + """Triton-backed per-expert fake-quant adapter for the N-modules path. + + Forward: single-launch Triton kernel over N expert weights (tensor-of-pointers, + no stack memcopy). Backward honors modelopt's default `pass_through_bwd=True` + — gradient flows back unchanged with zero kernel work. When False, the + clip-aware Triton STE backward kernel runs. See + `modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py`. + """ + + @staticmethod + def forward(ctx, amax_vec, num_bits, narrow_range, pass_through_bwd, *weights): + outputs = _triton_kernels.grouped_axis0_fakequant( + list(weights), amax_vec, num_bits=num_bits, narrow_range=narrow_range + ) + ctx.pass_through_bwd = pass_through_bwd + if not pass_through_bwd: + ctx.save_for_backward(amax_vec, *weights) + ctx.num_weights = len(weights) + return tuple(outputs) + + @staticmethod + def backward(ctx, *grad_outputs): + if ctx.pass_through_bwd: + return (None, None, None, None, *grad_outputs) + saved = ctx.saved_tensors + amax_vec, weights = saved[0], list(saved[1:]) + grad_inputs = _triton_kernels.grouped_axis0_fakequant_backward( + weights, list(grad_outputs), amax_vec + ) + return (None, None, None, None, *grad_inputs) + + # Register the public te.pytorch.GroupedLinear class @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) class _QuantTEGroupedLinear(_ParallelLinear): @@ -176,12 +210,78 @@ def modelopt_post_restore(self, prefix: str = ""): continue wq_i.reset_amax() max_calibrate(wq_i, lambda wq, w=weight_i: wq(w), distributed_sync=False) + # Re-calibration just changed every per-expert _amax. Drop the cache + # so the next forward rebuilds from the fresh values. + self._invalidate_per_expert_amax_cache() def _get_weight_quantizer(self, gemm_idx: int): if getattr(self, "_per_expert_weight_quantizer", False): return getattr(self, f"weight_quantizer_{gemm_idx}") return self.weight_quantizer + def _gather_per_expert_amax(self) -> torch.Tensor: + """Stack N per-expert weight_quantizer_i._amax scalars into a [N] fp32 vector. + + Matches the amax-input contract of `grouped_axis0_fakequant` — one + entry per expert, indexed by gemm_idx. + + Cached lazily: the per-expert _amax scalars don't change outside + calibration, and the gate (_can_use_triton_per_expert_path) only + admits this path when `q._if_calib` is False on every quantizer — + so once the cache is populated it stays valid for the lifetime of + the layer's calibrated state. The cache is invalidated explicitly + via _invalidate_per_expert_amax_cache (called from modelopt_post_restore) + in case checkpoint reload changes the amax values. + + Eliminates the O(N)-Python-overhead-per-forward walk over N submodules + observed in OMNIML-5072 AC3's microbench (the gap to Btriton5 grew + with N — 1.59x at N=32, 2.18x at N=128 — symptomatic of per-forward + scaling that disappears once the gathered tensor is reused). + """ + cached = getattr(self, "_per_expert_amax_cache", None) + if cached is not None: + return cached + amaxes = [] + for i in range(self.num_gemms): + q = self._get_weight_quantizer(i) + amaxes.append(q._amax.to(torch.float32).reshape(())) + stacked = torch.stack(amaxes).contiguous() + self._per_expert_amax_cache = stacked + return stacked + + def _invalidate_per_expert_amax_cache(self) -> None: + """Drop the cached _gather_per_expert_amax result. + + Called automatically from modelopt_post_restore (where dist-ckpt load + may have changed per-expert _amax buffers). Also callable by user code + after explicit re-calibration that mutates _amax outside the normal + calibration flow. + """ + if hasattr(self, "_per_expert_amax_cache"): + self._per_expert_amax_cache = None + + def _can_use_triton_per_expert_path(self, num_gemms: int) -> bool: + """Soft-gate the Triton dispatch on availability + ready-to-quantize state.""" + if not getattr(self, "_per_expert_weight_quantizer", False): + return False + if not _triton_kernels.IS_AVAILABLE: + return False + if not hasattr(_triton_kernels, "grouped_axis0_fakequant"): + return False + for i in range(num_gemms): + q = self._get_weight_quantizer(i) + # SequentialQuantizer (multi-stage) not supported on the Triton + # path; fall back to the cuda_ext per-quantizer loop. + if isinstance(q, SequentialQuantizer): + return False + if not hasattr(q, "_amax"): + return False + # During calibration each quantizer still needs the cuda_ext path + # so its _amax gets updated; skip Triton until calib finishes. + if getattr(q, "_if_calib", False): + return False + return True + def iter_weights_for_calibration(self): """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.""" for i in range(self.num_gemms): @@ -214,9 +314,26 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args): new_args = list(args) new_args[inp_pos] = self.input_quantizer(args[inp_pos]) - for gemm_idx in range(num_gemms): - pos = weights_start + gemm_idx - new_args[pos] = self._get_weight_quantizer(gemm_idx)(args[pos]) + if self._can_use_triton_per_expert_path(num_gemms): + # Single-launch Triton fakequant for the N expert weights. + # Replaces the per-expert cuda_ext loop (N kernel launches -> 1). + # All per-expert quantizers share the same config (each is a + # copy.deepcopy of the base weight_quantizer in _setup), so + # num_bits/narrow_range/pass_through_bwd read from quantizer 0 + # apply to the whole layer. + weights = list(args[weights_start : weights_start + num_gemms]) + amax_vec = self._gather_per_expert_amax() + q0 = self._get_weight_quantizer(0) + pass_through_bwd = getattr(q0, "_pass_through_bwd", True) + qweights = _GroupedAxis0FakeQuantFn.apply( + amax_vec, q0.num_bits, q0.narrow_range, pass_through_bwd, *weights + ) + for gemm_idx in range(num_gemms): + new_args[weights_start + gemm_idx] = qweights[gemm_idx] + else: + for gemm_idx in range(num_gemms): + pos = weights_start + gemm_idx + new_args[pos] = self._get_weight_quantizer(gemm_idx)(args[pos]) output = getattr(package, func_name)(*new_args) # TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`. # Only the activation tensor participates in output quantization. diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 5524634321d..35b98d1c621 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -145,6 +145,7 @@ def get_mcore_gpt_model( moe_ffn_hidden_size: int | None = None, moe_shared_expert_intermediate_size: int | None = None, num_moe_experts: int | None = None, + sequence_parallel: bool = False, **config_kwargs: dict, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] @@ -168,7 +169,7 @@ def squared_relu(x): pipeline_model_parallel_size=pipeline_model_parallel_size, expert_model_parallel_size=expert_model_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, - sequence_parallel=False, + sequence_parallel=sequence_parallel, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, diff --git a/tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py b/tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py new file mode 100644 index 00000000000..7bd8116b151 --- /dev/null +++ b/tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""OMNIML-5072 AC2 — ATriton vs A parity on the N-modules per-expert path. + +ATriton = A's `_per_expert_weight_quantizer == True` path with the Triton +fakequant dispatch added in OMNIML-5072 (see +`modelopt/torch/quantization/plugins/transformer_engine.py`). + +A = the same N-modules path but going through `FakeTensorQuantFunction.apply` +per expert (cuda_ext under the hood). + +Two checks at each shape: + + forward parity: max_abs_err <= 1 ULP. Known rounding-mode mismatch — Triton + rounds via `libdevice.rint` while cuda_ext rounds via its own builtin; + both are banker's rounding but disagree on one ULP at a fraction of + bf16 boundary values. + + backward parity (pass_through_bwd=True): bit-exact (max_abs_err == 0.0). + Under modelopt's default pass-through STE, both paths return grad_out + unchanged regardless of the forward kernel — so gradient identity is + required, not approximate. +""" + +import pytest +import torch + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="ATriton parity test requires a CUDA GPU", +) + + +def _parity_atriton_vs_a( + num_experts: int, + out_features: int, + in_features: int, + dtype: torch.dtype, +): + """Run both paths at a given shape; return (fwd_max_err, ulp_floor, bwd_max_err).""" + from modelopt.torch.kernels.quantization.gemm import ( + IS_AVAILABLE, + grouped_axis0_fakequant, + ) + from modelopt.torch.quantization.extensions import get_cuda_ext + from modelopt.torch.quantization.plugins.transformer_engine import ( + _GroupedAxis0FakeQuantFn, + ) + from modelopt.torch.quantization.tensor_quant import FakeTensorQuantFunction + + if not IS_AVAILABLE: + pytest.skip("triton kernels not loaded — IS_AVAILABLE is False") + + cuda_ext = get_cuda_ext() + device = "cuda" + + torch.manual_seed(0) + weights = [ + torch.randn(out_features, in_features, dtype=dtype, device=device) + for _ in range(num_experts) + ] + amax_scalars = [w.abs().amax().to(torch.float32) for w in weights] + amax_vec = torch.stack([a.view(1) for a in amax_scalars]).view(num_experts, 1, 1) + + # ---- forward parity ---- + q_a = [ + cuda_ext.fake_tensor_quant(w, a, 8, False, True) + for w, a in zip(weights, amax_scalars) + ] + q_t = grouped_axis0_fakequant(weights, amax_vec, num_bits=8, narrow_range=True) + + fwd_max_err = 0.0 + for i in range(num_experts): + diff = (q_a[i].float() - q_t[i].float()).abs() + fwd_max_err = max(fwd_max_err, float(diff.max().item())) + + # 1 ULP at this quant scale is amax / 127 (narrow_range, 8-bit). + ulp_floor = float(amax_vec.max().item() / 127.0) + + # ---- backward parity under pass_through_bwd=True ---- + # Use a sum-loss (no GEMM) so any divergence is from the quantizer's + # autograd wrapping, not GEMM determinism. + ws_a = [w.detach().clone().requires_grad_(True) for w in weights] + # FakeTensorQuantFunction signature: + # (inputs, amax, bias=None, num_bits, unsigned, narrow_range, + # trt_high_precision_dtype, pass_through_bwd, block_size, axis) + qs_a = [ + FakeTensorQuantFunction.apply(w, a, None, 8, False, True, None, True, None, None) + for w, a in zip(ws_a, amax_scalars) + ] + loss_a = sum(q.sum() for q in qs_a) + loss_a.backward() + + ws_t = [w.detach().clone().requires_grad_(True) for w in weights] + # _GroupedAxis0FakeQuantFn.apply(amax_vec, num_bits, narrow_range, pass_through_bwd, *weights) + qs_t = _GroupedAxis0FakeQuantFn.apply(amax_vec, 8, True, True, *ws_t) + loss_t = sum(q.sum() for q in qs_t) + loss_t.backward() + + bwd_max_err = 0.0 + for i in range(num_experts): + diff = (ws_a[i].grad.float() - ws_t[i].grad.float()).abs() + bwd_max_err = max(bwd_max_err, float(diff.max().item())) + + return fwd_max_err, ulp_floor, bwd_max_err + + +# Small-to-moderate shapes — fast CI signal across shape regimes. +@pytest.mark.parametrize( + "num_experts,out_features,in_features", + [ + (4, 64, 128), + (8, 128, 256), + (32, 512, 1024), + ], +) +def test_atriton_vs_a_parity(num_experts, out_features, in_features): + """ATriton vs A: fwd within 1 ULP, bwd bit-exact (pass_through_bwd=True).""" + fwd_err, ulp_floor, bwd_err = _parity_atriton_vs_a( + num_experts, out_features, in_features, torch.bfloat16 + ) + assert fwd_err <= ulp_floor + 1e-6, ( + f"fwd_max_abs_err={fwd_err:.6f} > 1-ULP floor {ulp_floor:.6f} " + f"at N={num_experts}, [out, in]=[{out_features}, {in_features}]" + ) + assert bwd_err == 0.0, ( + f"bwd_max_abs_err={bwd_err:.6f} expected 0.0 " + f"at N={num_experts}, [out, in]=[{out_features}, {in_features}]" + ) + + +@pytest.mark.slow +def test_atriton_vs_a_parity_ultra_production_shape(): + """AC2 — Ultra production shape (N=32, [5120, 8192] bf16). + + Marked `slow` because the unquantized + quantized + gradient copies of + 32 expert weights at [5120, 8192] bf16 use about 5 GB of GPU memory. + """ + fwd_err, ulp_floor, bwd_err = _parity_atriton_vs_a( + num_experts=32, + out_features=5120, + in_features=8192, + dtype=torch.bfloat16, + ) + assert fwd_err <= ulp_floor + 1e-6, ( + f"fwd_max_abs_err={fwd_err:.6f} > 1-ULP floor {ulp_floor:.6f} at Ultra shape" + ) + assert bwd_err == 0.0, ( + f"bwd_max_abs_err={bwd_err:.6f} expected 0.0 at Ultra shape" + ) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index 710725af04c..f6fa2a324f9 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -15,6 +15,7 @@ import copy import math +import os from functools import partial import pytest @@ -291,6 +292,7 @@ def _gpt_model_provider( use_cpu_initialization=meta_device, num_moe_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm, + sequence_parallel=(tp_size > 1), # OMNIML-5030: Required for MoE + TP (mirrors hybrid path) ) if not meta_device: @@ -820,6 +822,64 @@ def test_te_grouped_vs_sequential_default_loss(dist_workers_size_4, quant_cfg): ) +# OMNIML-5072 AC4: A's N-modules-per-expert weight amax should round-trip through +# sharded_state_dict / dist-checkpoint with bit-for-bit equality across EP. Mirror +# of B's test_te_grouped_per_expert_sharded_state_dict layout (dist_workers fixture, +# hidden_size=256, tp=2 ep=2 etp=1 num_moe_experts=4 moe_grouped_gemm=True), but +# A's path is triggered by the MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER env var +# instead of an axis=0 quant_cfg knob. +def _test_te_grouped_n_modules_sharded_state_dict_helper( + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, model_config, rank, size +): + # Per-rank env-var set so _QuantTEGroupedLinear._setup picks up the + # N-modules path. Must precede mtq.quantize() — that's where _setup runs. + os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"] = "1" + try: + _test_sharded_state_dict( + tmp_path, + config, + hidden_size, + modelopt_version, + compress, + meta_device, + model_config, + rank, + size, + ) + finally: + os.environ.pop("MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER", None) + + +@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]) +def test_te_grouped_n_modules_sharded_state_dict(dist_workers, need_4_gpus, tmp_path, config): + """A's N-modules per-expert weight amax round-trips through dist-checkpoint on TEGroupedMLP. + + OMNIML-5072 AC4. Each rank's `weight_quantizer_i._amax` scalars are gathered + across the EP group at save time, persisted as an `[N_global]` vector, and + narrowed back to per-submodule scalars on load. + """ + moe_config = { + "tp_size": 2, + "ep_size": 2, + "etp_size": 1, + "num_moe_experts": 4, + "moe_grouped_gemm": True, + "transformer_impl": "transformer_engine", + } + dist_workers.run( + partial( + _test_te_grouped_n_modules_sharded_state_dict_helper, + tmp_path, + copy.deepcopy(config), + 256, + None, + False, + False, + moe_config, + ), + ) + + @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("sync_weight_amax", [True, False]) def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax):