Skip to content
4 changes: 4 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@
if torch.cuda.get_device_capability() >= (8, 9):
from .fp4_kernel_hopper import *

# OMNIML-5072 — per-expert axis-0 fake-quant via tensor-of-pointers.
# Generic Triton + CUDA; no special hardware. See VALIDATION_TODO.md.
from .grouped_axis0_fakequant import *

IS_AVAILABLE = True
257 changes: 257 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/grouped_axis0_fakequant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Fused per-expert axis-0 fake-quant Triton kernels for TEGroupedLinear.

Replaces the stack-then-quantize-then-unbind pattern in modelopt's TEGrouped
plugin (`te_grouped_quantized_linear_fn`) with a single Triton launch that
processes N expert weights in place, with no contiguous-tensor staging.

Design — tensor of pointers
---------------------------

The N expert weights live as separate Parameters (one per expert), so they're
NOT contiguous in HBM. To avoid a `torch.stack` memcopy (the cost AC5
characterized on OMNIML-5064), we feed the kernel a `[N]` int64 tensor of
expert base pointers. Each Triton program reads its expert's pointer first,
then strides through a block of elements at that address.

Grid: (N, num_blocks_per_expert).
Program 0 of axis 0 → expert 0, program 1 → expert 1, etc.

See OMNIML-5072 AC5 (Option B follow-up) for the motivation.

VALIDATION STATUS (2026-06-11): kernel implemented, numerical fidelity NOT
yet validated against modelopt's reference `fake_quant_impl`, and bench
performance NOT yet measured. See VALIDATION_TODO.md in this directory.
"""

from __future__ import annotations

import torch
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice

__all__ = ["grouped_axis0_fakequant", "grouped_axis0_fakequant_backward"]


def _torch_dtype_to_tl(dtype: torch.dtype):
"""Map a torch dtype to its Triton-language equivalent."""
return {
torch.float32: tl.float32,
torch.bfloat16: tl.bfloat16,
torch.float16: tl.float16,
}[dtype]


@triton.jit
def _grouped_axis0_fakequant_fwd_kernel(
weight_ptrs_buf, # int64 [N] — N expert base pointers (cast from .data_ptr())
output_ptrs_buf, # int64 [N] — N output base pointers
amax_vec_ptr, # [N, 1, 1] (or anything with N as the leading dim)
elements_per_expert,
num_bits,
narrow_range: tl.constexpr,
DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
expert_idx = tl.program_id(axis=0)
block_idx = tl.program_id(axis=1)

# Per-expert base pointers (loaded once per program).
w_int = tl.load(weight_ptrs_buf + expert_idx)
out_int = tl.load(output_ptrs_buf + expert_idx)
w_ptr = w_int.to(tl.pointer_type(DTYPE))
out_ptr = out_int.to(tl.pointer_type(DTYPE))

# Per-expert amax → quant scale.
# amax is stored as fp32; convert to working precision.
amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32)
# qmax = 2^(num_bits-1) - 1 when narrow_range else 2^(num_bits-1)
# For num_bits=8 narrow_range=True (modelopt default): qmax=127
qmax = ((1 << (num_bits - 1)) - 1) if narrow_range else (1 << (num_bits - 1))
qmin = -qmax if narrow_range else -qmax # signed symmetric
scale = amax / qmax

# Block of elements within this expert.
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < elements_per_expert

x = tl.load(w_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

# Fake-quant: round(clip(x / scale)) * scale.
# Use scale guards to avoid div-by-zero before _amax is calibrated (early
# batches may carry an _amax of 0; matches modelopt's fake_tensor_quant
# behavior of passing through unchanged).
safe_scale = tl.where(scale > 0.0, scale, 1.0)
q = x / safe_scale
q = tl.maximum(tl.minimum(q, qmax), qmin)
# Round-half-to-even (banker's), matching cuda_ext.fake_tensor_quant exactly.
# libdevice.rint is CUDA's __rint* builtin. Imported via the same path that
# modelopt's nvfp4_quant.py uses (triton.language.extra.cuda.libdevice).
q_rounded = libdevice.rint(q)
out = tl.where(scale > 0.0, q_rounded * scale, x)

tl.store(out_ptr + offsets, out.to(DTYPE), mask=mask)


@triton.jit
def _grouped_axis0_fakequant_bwd_kernel(
weight_ptrs_buf, # int64 [N] — same buffer as fwd
grad_out_ptrs_buf, # int64 [N] — upstream grad pointers (per expert)
grad_in_ptrs_buf, # int64 [N] — output: downstream grad pointers
amax_vec_ptr, # [N, ...] — same buffer as fwd
elements_per_expert,
DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Clip-aware STE backward.

For each expert i:
grad_in[i] = grad_out[i] if |w[i]| <= amax[i] else 0
matches modelopt's `_fake_tensor_quant_backward` semantics.
"""
expert_idx = tl.program_id(axis=0)
block_idx = tl.program_id(axis=1)

w_ptr = tl.load(weight_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))
grad_out_ptr = tl.load(grad_out_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))
grad_in_ptr = tl.load(grad_in_ptrs_buf + expert_idx).to(tl.pointer_type(DTYPE))

amax = tl.load(amax_vec_ptr + expert_idx).to(tl.float32)

offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < elements_per_expert

# Stay in DTYPE (bf16/fp16) throughout — eliminates fp32 round-trip seen in
# the Btriton2 baseline that capped bwd bandwidth at ~4.2 TB/s vs cuda_ext's
# ~8 TB/s on B300. amax cast to DTYPE once; comparison done in low precision
# (amax values are O(1)-O(10), well within bf16 range).
w = tl.load(w_ptr + offsets, mask=mask, other=0.0)
g = tl.load(grad_out_ptr + offsets, mask=mask, other=0.0)
amax_dt = amax.to(DTYPE)

# Clip-aware STE: pass through gradient where |w| <= amax, else zero.
pass_through = tl.abs(w) <= amax_dt
grad_in = tl.where(pass_through, g, 0.0)

tl.store(grad_in_ptr + offsets, grad_in, mask=mask)


def _build_ptr_buf(tensors: list[torch.Tensor]) -> torch.Tensor:
"""Pack a list of tensors' .data_ptr() into a single int64 tensor on the same device."""
return torch.tensor(
[t.data_ptr() for t in tensors],
dtype=torch.int64,
device=tensors[0].device,
)


def grouped_axis0_fakequant(
weights: list[torch.Tensor],
amax_vec: torch.Tensor,
num_bits: int = 8,
narrow_range: bool = True,
) -> list[torch.Tensor]:
"""Apply per-expert axis-0 fake-quant in a single Triton launch.

Args:
weights: List of N expert weight tensors. Each must have the same shape
`[out, in]` and same dtype.
amax_vec: Per-expert amax buffer of shape `[N, 1, 1]` (or any shape where
element `i` is expert `i`'s amax). dtype should be float32 for
numerical headroom; the kernel casts to fp32 internally.
num_bits: integer bit-width for the fake-quant.
narrow_range: if True, output range is [-qmax, +qmax]; else [-qmax, +qmax-1].
modelopt's default is True.

Returns:
List of N quantized weight tensors, each the same shape and dtype as
the corresponding input.
"""
assert len(weights) >= 1, "grouped_axis0_fakequant requires at least one expert"
N = len(weights)
shape0 = weights[0].shape
dtype0 = weights[0].dtype
device0 = weights[0].device
elements_per_expert = weights[0].numel()
for w in weights[1:]:
assert w.shape == shape0, "all expert weights must share the same shape"
assert w.dtype == dtype0, "all expert weights must share the same dtype"
assert w.device == device0, "all expert weights must share the same device"

outputs = [torch.empty_like(w) for w in weights]

weight_ptrs = _build_ptr_buf(weights)
output_ptrs = _build_ptr_buf(outputs)

# BLOCK_SIZE=2048 was empirically best in the Btriton2 sweep — larger blocks
# (16384 + num_warps=8) regressed both fwd and bwd, likely from worse warp
# occupancy and load coalescing on B300.
BLOCK_SIZE = 2048
num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE)
grid = (N, num_blocks_per_expert)

with torch.cuda.device(device0):
_grouped_axis0_fakequant_fwd_kernel[grid](
weight_ptrs,
output_ptrs,
amax_vec,
elements_per_expert,
num_bits,
narrow_range=narrow_range,
DTYPE=_torch_dtype_to_tl(dtype0),
BLOCK_SIZE=BLOCK_SIZE,
)

return outputs


def grouped_axis0_fakequant_backward(
weights: list[torch.Tensor],
grad_outputs: list[torch.Tensor],
amax_vec: torch.Tensor,
) -> list[torch.Tensor]:
"""Apply per-expert clip-aware STE backward in a single Triton launch.

Matches modelopt's `_fake_tensor_quant_backward` semantics — gradient
passes through where `|w[i]| <= amax[i]`, else zero.

Args:
weights: List of N expert weight tensors (the original fwd inputs).
grad_outputs: List of N upstream gradients, one per expert.
amax_vec: Per-expert amax buffer (same shape as in fwd).

Returns:
List of N downstream gradients, one per expert.
"""
N = len(weights)
assert len(grad_outputs) == N
shape0 = weights[0].shape
dtype0 = weights[0].dtype
device0 = weights[0].device
elements_per_expert = weights[0].numel()

grad_inputs = [torch.empty_like(w) for w in weights]

weight_ptrs = _build_ptr_buf(weights)
grad_out_ptrs = _build_ptr_buf(grad_outputs)
grad_in_ptrs = _build_ptr_buf(grad_inputs)

BLOCK_SIZE = 2048
num_blocks_per_expert = triton.cdiv(elements_per_expert, BLOCK_SIZE)
grid = (N, num_blocks_per_expert)

with torch.cuda.device(device0):
_grouped_axis0_fakequant_bwd_kernel[grid](
weight_ptrs,
grad_out_ptrs,
grad_in_ptrs,
amax_vec,
elements_per_expert,
DTYPE=_torch_dtype_to_tl(dtype0),
BLOCK_SIZE=BLOCK_SIZE,
)

return grad_inputs
93 changes: 92 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
from modelopt.torch.utils.distributed import ParallelState

from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn import QuantModule, QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from ..utils import sync_moe_expert_amax
Expand Down Expand Up @@ -638,6 +638,62 @@ class _QuantTELayerNormColumnParallelLinear(

# Quantized subclasses to support TEGroupedMLP quantization
class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
def _ep_group(self):
# Return the expert_model_parallel_group iff it's initialized AND has >1 rank.
# _MegatronTEGroupedMLP._setup populates parallel_state with the EP group;
# outside that wrapping it may be unset (e.g. ad-hoc unit tests).
ps = getattr(self, "parallel_state", None)
if ps is None:
return None
ep = ps.expert_model_parallel_group
if not ep.is_initialized() or ep.world_size() <= 1:
return None
return ep

def _gather_global_per_expert_amax_n_modules(self):
# Per-expert N-modules path (OMNIML-5072 AC4): each rank's N local
# submodules (weight_quantizer_0..weight_quantizer_{N-1}) carry one
# scalar _amax each. Stack them into [N_local] and all-gather across
# the EP group to get the global [N_global] vector.
#
# Returns None when not in per-expert mode or any quantizer lacks
# _amax (caller falls back to the legacy per-quantizer path).
if not getattr(self, "_per_expert_weight_quantizer", False):
return None
local = []
for i in range(self.num_gemms):
q = self._get_weight_quantizer(i)
q_inner = q[0] if isinstance(q, SequentialQuantizer) else q
amax = getattr(q_inner, "_amax", None)
if amax is None:
return None
local.append(amax.view(()).to(torch.float32))
v_local = torch.stack(local) # [N_local]
ep = self._ep_group()
if ep is None:
return v_local # EP=1: local IS global.
gathered = [torch.empty_like(v_local) for _ in range(ep.world_size())]
torch.distributed.all_gather(gathered, v_local, group=ep.group)
return torch.cat(gathered, dim=0) # [N_global]

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
# Gather the global per-expert amax ONCE before Megatron's save traversal.
# Stash on a temporary attribute that _process_quantizer_amax reads; the
# live per-expert weight_quantizer_i._amax scalars stay untouched so
# forward keeps working.
#
# Done at the top of sharded_state_dict (not inside _process_quantizer_amax)
# so the EP collective completes BEFORE Megatron's dist-checkpoint save
# kicks off its own default-PG ALLGATHER metadata exchanges. Interleaving
# EP gathers with default-PG collectives deadlocks NCCL.
self._cached_global_per_expert_amax_n_modules = (
self._gather_global_per_expert_amax_n_modules()
)
try:
return super().sharded_state_dict(prefix, sharded_offsets, metadata)
finally:
self._cached_global_per_expert_amax_n_modules = None

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
# sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
Expand All @@ -648,9 +704,44 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
for k, v in state_dict.items()
if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
}
# Per-expert N-modules path: each weight_quantizer_i._amax key in the
# saved state-dict carries the gathered [N_global] vector (replicated
# across all EP ranks). Narrow it to this rank's local scalar — pull
# element (ep_rank * N_local + i) out of the global tensor and
# reshape to scalar for the i-th submodule's _amax buffer.
if getattr(self, "_per_expert_weight_quantizer", False):
import re
ep = self._ep_group()
ep_rank = ep.rank() if ep is not None else 0
ep_size = ep.world_size() if ep is not None else 1
global_size = self.num_gemms * ep_size
offset = ep_rank * self.num_gemms
pattern = re.compile(r"weight_quantizer_(\d+)\._amax$")
for k in list(filtered_state_dict.keys()):
m = pattern.search(k)
if m is None:
continue
local_i = int(m.group(1))
v = filtered_state_dict[k]
if v.numel() == global_size:
filtered_state_dict[k] = (
v.view(global_size)[offset + local_i].view(())
)
# else: leave as-is (legacy or EP=1 save format).
return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
# Per-expert N-modules path: emit the gathered global per-expert
# amax vector under every weight_quantizer_i._amax key (replicated
# across EP ranks). On load, _load_from_state_dict narrows back to
# the local scalar per submodule. Suboptimal disk usage (N copies
# of the same vector per layer) but mirrors B's gather-once-cache
# pattern and avoids surgery into the base-class state-dict
# iteration.
cached = getattr(self, "_cached_global_per_expert_amax_n_modules", None)
if cached is not None and "weight_quantizer_" in k and k.endswith("_amax"):
quantizer_state_dict[k] = cached.view(cached.numel())
return
assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
quantizer_state_dict[k] = v.view(-1)

Expand Down
Loading