12 changes: 7 additions & 5 deletions modelopt/torch/quantization/model_calib.py
@@ -219,11 +219,11 @@ def mse_calibrate(
    for name, module in model.named_modules():
        if isinstance(module, TensorQuantizer) and not module._disabled:
            # Static block quantization is not supported by MseCalibrator
            if module.is_static_block_quant:
                raise ValueError(
                    f"MSE calibration does not support static block quantization. "
                    f"Found static block quantization at {name}."
                )
            # if module.is_static_block_quant:
            #     raise ValueError(
            #         f"MSE calibration does not support static block quantization. "
            #         f"Found static block quantization at {name}."
            #     )
            if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
                # Get the initial amax from max calibration
                initial_amax = module._amax.clone().detach()
@@ -237,7 +237,9 @@ def quant_func(x, amax, quantizer=module):
                        disable_calib(quantizer),
                        enable_fake_quant(quantizer),
                    ):
                        quantizer._keep_shape = True
                        xq = quantizer(x)
                        quantizer._keep_shape = False

                    if original_amax is not None:
                        quantizer._amax = original_amax
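The two quantizer._keep_shape assignments above leave the fake-quantized output in its block-reshaped layout during the MSE calibration call instead of restoring the original shape. A minimal sketch of that usage pattern, assuming a TensorQuantizer already configured for static block quantization; the helper name and the try/finally wrapper are illustrative, not part of this diff:

import torch
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer


def fake_quant_keep_block_shape(quantizer: TensorQuantizer, x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize x without reshaping the output back to x's original layout."""
    quantizer._keep_shape = True
    try:
        # With _keep_shape set, forward() skips _reset_to_original_shape(), so the
        # output stays in quantizer._block_reshape_size layout.
        return quantizer(x)
    finally:
        quantizer._keep_shape = False  # restore the default for regular forward passes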
19 changes: 13 additions & 6 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -128,6 +128,7 @@ def __init__(
        self._enable_pre_quant_scale = True
        self._dequantize = False
        self._input_dtype = None
        self._keep_shape = False

        # Lazy initialize the bias calibrator for KV cache quantization
        self._bias_calibrator = None
@@ -653,6 +654,12 @@ def _fake_quantize(self, inputs):
                getattr(self, "_onnx_quantizer_type", None),
                self._pass_through_bwd,
            )
        elif self._num_bits == (2, 1):
            from modelopt.torch.quantization.triton.fp4_kernel import (
                launch_blockwise_fp4_fake_quant,
            )

            outputs = launch_blockwise_fp4_fake_quant(inputs, amax / 6.0, out_dtype=inputs.dtype)
        elif isinstance(self._num_bits, tuple):
            # Float-point quantization, e.g., FP8
            E, M = self._num_bits  # noqa: N806
@@ -783,11 +790,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor):
        if hasattr(self, "_padding"):
            inputs = F.pad(inputs, self._padding, "constant", 0)

        if inputs.shape != self._original_shape:
            raise ValueError(
                f"Input shape has changed from {self._original_shape} to {inputs.shape}."
                " Block-quantization requires a fixed input shape."
            )
        # if inputs.shape != self._original_shape:
        #     print(
        #         f"Input shape has changed from {self._original_shape} to {inputs.shape}."
        #         " Block-quantization requires a fixed input shape."
        #     )
        inputs = inputs.reshape(self._block_reshape_size)
        return inputs

@@ -941,7 +948,7 @@ def forward(self, inputs):
"This case should have been handled."
)

if self.is_static_block_quant:
if self.is_static_block_quant and not self._keep_shape:
outputs = self._reset_to_original_shape(outputs)

return outputs
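With the new branch in _fake_quantize, num_bits == (2, 1) (FP4 E2M1) is routed to the Triton kernel, and the per-block scale is amax / 6.0 since 6.0 is the largest representable E2M1 magnitude. A rough sketch of how the launcher is fed, assuming the inputs have already been reshaped to [num_blocks, block_size] by _process_for_blockquant; the block size of 16 and tensor shapes are illustrative assumptions:

import torch

from modelopt.torch.quantization.triton.fp4_kernel import launch_blockwise_fp4_fake_quant

num_blocks, block_size = 128, 16  # illustrative FP4 block layout
x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.bfloat16)

# Per-block amax, shape [num_blocks, 1]; the kernel receives amax / 6.0 as the scale.
amax = x.abs().amax(dim=-1, keepdim=True)
xq = launch_blockwise_fp4_fake_quant(x, amax / 6.0, out_dtype=x.dtype)

assert xq.shape == x.shape and xq.dtype == x.dtype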
154 changes: 154 additions & 0 deletions modelopt/torch/quantization/triton/fp4_kernel.py
@@ -345,3 +345,157 @@ def fp4_dequantize(
    )

    return output


@triton.jit
def blockwise_fp4_fake_quant_kernel(
    x_ptr,  # [NUM_FP4_BLOCKS * BLOCK_SIZE]
    y_ptr,  # [NUM_FP4_BLOCKS * BLOCK_SIZE]
    scale_ptr,  # [NUM_FP4_BLOCKS]
    NUM_FP4_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    if pid >= NUM_FP4_BLOCKS:
        return

    block_offset = pid * BLOCK_SIZE
    idx = block_offset + tl.arange(0, BLOCK_SIZE)

    scale = tl.load(scale_ptr + pid).to(tl.float32)

    x = tl.load(x_ptr + idx).to(tl.float32)

    x_abs = tl.abs(x)
    scale_safe = tl.where(scale >= 1e-5, scale, 1.0)
    abs_scaled = x_abs / scale_safe

    # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
    q_val = tl.where(
        abs_scaled <= 0.25,
        0.0,
        tl.where(
            abs_scaled < 0.75,
            0.5,
            tl.where(
                abs_scaled <= 1.25,
                1.0,
                tl.where(
                    abs_scaled < 1.75,
                    1.5,
                    tl.where(
                        abs_scaled <= 2.5,
                        2.0,
                        tl.where(
                            abs_scaled < 3.5,
                            3.0,
                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
                        ),
                    ),
                ),
            ),
        ),
    )

    x_rescaled = q_val * scale_safe
    x_dequant = tl.where(x >= 0, x_rescaled, -x_rescaled)

    tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE))


def launch_blockwise_fp4_fake_quant(
    x: torch.Tensor,
    scale: torch.Tensor,
    out_dtype: torch.dtype = torch.float16,
):
    """Launch Triton kernel for blockwise FP4 fake quantization.

    x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
    scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
    """
    assert x.ndim == 2
    NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape

    x_flat = x.contiguous().view(-1)
    y_flat = torch.empty_like(x_flat, dtype=out_dtype)
    scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous()

    tl_out_dtype = _torch_dtype_to_tl(out_dtype)

    grid = (NUM_FP4_BLOCKS,)

    # Ensure we're running on the correct CUDA device
    with torch.cuda.device(x.device):
        blockwise_fp4_fake_quant_kernel[grid](
            x_flat,
            y_flat,
            scale_flat,
            NUM_FP4_BLOCKS,
            BLOCK_SIZE,
            OUT_DTYPE=tl_out_dtype,
        )

    return y_flat.view_as(x)


def blockwise_fp4_fake_quant_reference(
    x: torch.Tensor,
    scale: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Reference implementation of blockwise FP4 fake quantization.

    x: [NUM_FP4_BLOCKS, BLOCK_SIZE].
    scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1].

    Uses FP4 quantization levels: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
    """
    assert x.ndim == 2
    num_blocks, block_size = x.shape

    if scale.ndim == 1:
        scale = scale.view(num_blocks, 1)
    assert scale.shape == (num_blocks, 1)

    x_f = x.to(torch.float32)
    s_f = scale.to(torch.float32)

    s_f = torch.where(s_f >= 1e-5, s_f, torch.ones_like(s_f))

    x_abs = torch.abs(x_f)
    abs_scaled = x_abs / s_f

    q_val = torch.where(
        abs_scaled <= 0.25,
        torch.zeros_like(abs_scaled),
        torch.where(
            abs_scaled < 0.75,
            torch.full_like(abs_scaled, 0.5),
            torch.where(
                abs_scaled <= 1.25,
                torch.ones_like(abs_scaled),
                torch.where(
                    abs_scaled < 1.75,
                    torch.full_like(abs_scaled, 1.5),
                    torch.where(
                        abs_scaled <= 2.5,
                        torch.full_like(abs_scaled, 2.0),
                        torch.where(
                            abs_scaled < 3.5,
                            torch.full_like(abs_scaled, 3.0),
                            torch.where(
                                abs_scaled <= 5.0,
                                torch.full_like(abs_scaled, 4.0),
                                torch.full_like(abs_scaled, 6.0),
                            ),
                        ),
                    ),
                ),
            ),
        ),
    )

    x_rescaled = q_val * s_f
    x_dequant = torch.where(x_f >= 0, x_rescaled, -x_rescaled)
    return x_dequant.to(out_dtype)
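Since the file now contains both the Triton launcher and a pure-PyTorch reference, a quick parity check on random data is a natural smoke test. The snippet below is an illustrative example, not part of the diff; both paths compute in float32 with the same thresholds, so the outputs should match:

import torch

from modelopt.torch.quantization.triton.fp4_kernel import (
    blockwise_fp4_fake_quant_reference,
    launch_blockwise_fp4_fake_quant,
)

x = torch.randn(64, 16, device="cuda", dtype=torch.float32)
scale = x.abs().amax(dim=-1, keepdim=True) / 6.0  # per-block FP4 scale

y_triton = launch_blockwise_fp4_fake_quant(x, scale, out_dtype=torch.float32)
y_ref = blockwise_fp4_fake_quant_reference(x, scale, out_dtype=torch.float32)

# Elementwise agreement between the kernel and the reference implementation.
torch.testing.assert_close(y_triton, y_ref)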