
Commit 5352b1c

Avishek Goswami committed
Add MSE vs MinMax observer comparison tests
- Add comprehensive test suite comparing MSE and MinMax observers
- Test on random tensors with various distributions
- Test on real model weights from transformers
- Add 'slow' pytest marker to pyproject.toml for long-running tests

Signed-off-by: Avishek Goswami <[email protected]>
1 parent 6cf8d29 commit 5352b1c

2 files changed: +310 -0 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -23,5 +23,6 @@ markers = [
     "unit: tests to ensure code correctness and regression test functionality",
     "example: tests for content in the 'examples' folder",
     "multi_gpu: tests that require multiple GPUs",
+    "slow: tests that take a long time to run (e.g., downloading models)",
 ]
 tmp_path_retention_policy = "failed"
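
As a usage note (not part of the diff): with the "slow" marker registered, the long-running model-download tests added below can be deselected during a quick local run via pytest's standard marker selection:

    pytest -m "not slow"

Conversely, pytest -m slow selects only the marked tests.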
Lines changed: 309 additions & 0 deletions (new test file)
@@ -0,0 +1,309 @@
"""
Test to verify that the MSE observer performs as well as or better than the
MinMax observer on various tensor distributions, including normal distributions
(similar to real weights) and actual model weights.

These tests check that the quantization error (MSE) from using the MSE observer
is less than or equal to the error from using the MinMax observer.
"""

import pytest
import torch
from compressed_tensors.quantization import fake_quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

from llmcompressor.observers import Observer


def _create_base_quantization_args(num_bits, strategy, symmetric, group_size):
    """Helper to create base QuantizationArgs without the observer field."""
    return QuantizationArgs(
        num_bits=num_bits,
        strategy=strategy,
        symmetric=symmetric,
        group_size=group_size,
    )


def _run_observer_test(
    tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None
):
    """
    Helper function to run an observer and compute the quantization error.

    Returns: (scale, zero_point, quantized_tensor, mse, global_scale)
    """
    weights = _create_base_quantization_args(num_bits, strategy, symmetric, group_size)
    weights.observer = observer_name

    observer = Observer.load_from_registry(
        observer_name, base_name="weight", args=weights, module=module
    )

    global_scale = None
    if strategy == "tensor_group" and module is not None:
        global_scale = observer.get_global_scale(tensor)
        module.weight_global_scale = global_scale

    scale, zero_point = observer(tensor)

    # Sanity check: scales should be non-negative
    assert (scale >= 0).all(), "Scale values should be non-negative"

    weights_clean = _create_base_quantization_args(
        num_bits, strategy, symmetric, group_size
    )
    quantized = fake_quantize(
        tensor,
        scale,
        zero_point,
        weights_clean,
        global_scale=global_scale if strategy == "tensor_group" else None,
    )
    mse = torch.nn.functional.mse_loss(quantized, tensor)

    return scale, zero_point, quantized, mse, global_scale


def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False):
    """
    Assert MSE observer performance with appropriate slack.

    For tensor + symmetric: strict assertion (MSE should be better).
    For other configurations: allow slack (10% for synthetic data, 20% for real
    weights). An epsilon also handles cases where minmax_mse is near 0.
    """
    epsilon = 1e-8
    slack = 1.20 if is_real_weights else 1.10

    if strategy == "tensor" and symmetric:
        # Cases where MSE SHOULD be better
        assert mse_mse <= minmax_mse + epsilon, (
            f"MSE observer performed worse than MinMax observer!\n"
            f"Strategy: {strategy}, Symmetric: {symmetric}\n"
            f"MinMax MSE: {minmax_mse.item():.6e}\n"
            f"MSE Observer MSE: {mse_mse.item():.6e}\n"
            f"Difference: {(mse_mse - minmax_mse).item():.6e}"
        )
    else:
        # Not guaranteed, but ensure it is not catastrophically worse
        assert mse_mse <= minmax_mse * slack + epsilon, (
            f"MSE observer performed significantly worse than MinMax observer!\n"
            f"Strategy: {strategy}, Symmetric: {symmetric}\n"
            f"MinMax MSE: {minmax_mse.item():.6e}\n"
            f"MSE Observer MSE: {mse_mse.item():.6e}\n"
            f"Difference: {(mse_mse - minmax_mse).item():.6e}\n"
            f"Ratio: {(mse_mse / (minmax_mse + epsilon)).item():.4f}x"
        )


@pytest.mark.parametrize(
    "strategy,symmetric,num_bits",
    [
        ("tensor", True, 8),
        ("tensor", False, 8),
        ("channel", True, 8),
        ("channel", False, 8),
        ("tensor_group", True, 4),
        ("tensor_group", False, 4),
        ("channel", True, 4),
        ("channel", False, 4),
    ],
)
@pytest.mark.parametrize(
    "std",
    [0.05, 0.2, 1.0],
    ids=["narrow", "medium", "wide"],
)
def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
    """
    Test that the MSE observer produces quantization error <= the MinMax
    observer on random tensors with a normal distribution (similar to real
    model weights).

    Real model weights typically follow a normal distribution with:
    - Mean near 0
    - Standard deviation around 0.02-0.1 for initialized weights
    - Range roughly [-0.5, 0.5] for most layers

    Testing with different std values exposes cases where MinMax performs
    poorly on wide or heavy-tailed distributions, where MSE should shine.
    """
    # Generate a random tensor with a normal distribution similar to real
    # weights; different std values test various distribution widths
    torch.manual_seed(42)
    tensor = torch.randn(128, 256) * std

    group_size = 32 if strategy == "tensor_group" else None

    # Create separate modules for tensor_group to avoid shared mutable state
    module_minmax = None
    module_mse = None
    if strategy == "tensor_group":
        module_minmax = torch.nn.Linear(256, 128)
        module_minmax.weight.data = tensor.T
        module_mse = torch.nn.Linear(256, 128)
        module_mse.weight.data = tensor.T

    # Test with the MinMax observer
    _, _, _, minmax_mse, _ = _run_observer_test(
        tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
    )

    # Test with the MSE observer
    _, _, _, mse_mse, _ = _run_observer_test(
        tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
    )

    # Assert with appropriate slack for synthetic data
    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False)


@pytest.mark.parametrize(
    "tensor_shape",
    [
        (64, 128),
        (128, 256),
        (256, 512),
        (32, 64, 128),  # 3D tensor
    ],
)
def test_mse_vs_minmax_various_shapes(tensor_shape):
    """
    Test MSE vs MinMax on tensors of various shapes with a normal distribution.
    Uses realistic weight distribution parameters.
    """
    torch.manual_seed(42)
    # Use a realistic weight distribution: mean=0, std=0.05
    tensor = torch.randn(*tensor_shape) * 0.05

    # MinMax
    _, _, _, minmax_mse, _ = _run_observer_test(
        tensor, "memoryless_minmax", "channel", True, 8, None, None
    )

    # MSE
    _, _, _, mse_mse, _ = _run_observer_test(
        tensor, "memoryless_mse", "channel", True, 8, None, None
    )

    # Channel quantization: MSE not guaranteed better, allow 10% slack
    _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)


def test_mse_vs_minmax_extreme_values():
    """Test MSE vs MinMax on tensors with extreme values."""
    torch.manual_seed(42)

    # Test with very small values
    tensor_small = torch.randn(64, 128) * 0.01
    # Test with very large values
    tensor_large = torch.randn(64, 128) * 100.0
    # Test with a skewed distribution
    tensor_skewed = torch.cat(
        [torch.randn(64, 100) * 0.1, torch.randn(64, 28) * 10.0], dim=1
    )

    for tensor, name in [
        (tensor_small, "small"),
        (tensor_large, "large"),
        (tensor_skewed, "skewed"),
    ]:
        # MinMax
        _, _, _, minmax_mse, _ = _run_observer_test(
            tensor, "memoryless_minmax", "channel", True, 8, None, None
        )

        # MSE
        _, _, _, mse_mse, _ = _run_observer_test(
            tensor, "memoryless_mse", "channel", True, 8, None, None
        )

        # Channel quantization: MSE not guaranteed better, allow 10% slack
        _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)


@pytest.mark.slow
@pytest.mark.parametrize(
    "strategy,symmetric,num_bits",
    [
        ("channel", True, 8),
        ("channel", False, 8),
        ("tensor_group", True, 4),
        ("tensor_group", False, 4),
    ],
)
def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
    """
    Test that the MSE observer produces quantization error <= the MinMax
    observer on actual model weights from a real neural network.

    This test loads weights from a small model to verify observer behavior
    on real weight distributions, which may differ from synthetic data.
    """
    try:
        from transformers import AutoModelForCausalLM
    except ImportError:
        pytest.skip("transformers not available")

    # Use a small, publicly available model for testing
    model_id = "nm-testing/tinysmokellama-3.2"

    try:
        # Load the model and extract a weight tensor; use a no_grad context
        # to avoid unnecessary gradient computation
        with torch.no_grad():
            model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype=torch.float32
            )

            # Get a representative weight tensor (e.g., from the first Linear layer)
            weight_tensor = None
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear) and weight_tensor is None:
                    weight_tensor = module.weight.data.clone()
                    break

            if weight_tensor is None:
                pytest.skip("No Linear layer found in model")

            # Flatten or reshape to 2D if needed for testing
            if weight_tensor.dim() > 2:
                weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1])
            elif weight_tensor.dim() == 1:
                weight_tensor = weight_tensor.unsqueeze(0)

            # Limit size for faster testing
            if weight_tensor.shape[0] > 512:
                weight_tensor = weight_tensor[:512, :]
            if weight_tensor.shape[1] > 512:
                weight_tensor = weight_tensor[:, :512]

    except Exception as e:
        pytest.skip(f"Could not load model {model_id}: {e}")

    group_size = 32 if strategy == "tensor_group" else None

    # Create separate modules for tensor_group to avoid shared mutable state
    module_minmax = None
    module_mse = None
    if strategy == "tensor_group":
        module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
        module_minmax.weight.data = weight_tensor.T
        module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
        module_mse.weight.data = weight_tensor.T

    # Test with the MinMax observer
    _, _, _, minmax_mse, _ = _run_observer_test(
        weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax
    )

    # Test with the MSE observer
    _, _, _, mse_mse, _ = _run_observer_test(
        weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse
    )

    # For channel and tensor_group strategies, MSE is not guaranteed to be
    # better; allow 20% slack for real model weights (more structure and
    # extreme channels)
    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True)
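
For intuition about the property these tests assert, the following is a minimal, self-contained sketch (plain torch, symmetric per-tensor int8). The helpers quantize_dequantize, minmax_scale, and mse_search_scale are illustrative stand-ins, not the llmcompressor observer implementations. Because the grid search over clipping ratios includes the unclipped ratio 1.0, i.e. the MinMax scale itself, the searched scale can tie but never lose on reconstruction MSE in this setting, which is why the test above asserts strictly only for strategy="tensor" with symmetric=True.

    import torch

    def quantize_dequantize(x, scale, qmax=127):
        # Symmetric fake-quantization: map to the int grid, round, clamp, rescale
        q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
        return q * scale

    def minmax_scale(x, qmax=127):
        # MinMax-style scale: cover the full observed range, no clipping
        return x.abs().max().item() / qmax

    def mse_search_scale(x, qmax=127, steps=100):
        # MSE-style scale: try shrinking the range by candidate clipping ratios
        # and keep whichever minimizes reconstruction MSE; ratio 1.0 (the MinMax
        # scale) is in the grid, so the result can only tie or improve
        base = minmax_scale(x, qmax)
        best_scale, best_err = base, float("inf")
        for i in range(steps):
            scale = base * (1.0 - i / (2 * steps))  # ratios in (0.5, 1.0]
            err = torch.mean((quantize_dequantize(x, scale, qmax) - x) ** 2).item()
            if err < best_err:
                best_scale, best_err = scale, err
        return best_scale

    torch.manual_seed(0)
    x = torch.randn(128, 256) * 0.05  # same shape/std family as the tests above
    err_minmax = torch.mean((quantize_dequantize(x, minmax_scale(x)) - x) ** 2)
    err_mse = torch.mean((quantize_dequantize(x, mse_search_scale(x)) - x) ** 2)
    assert err_mse <= err_minmax
    print(f"MinMax MSE: {err_minmax.item():.3e}, MSE-search MSE: {err_mse.item():.3e}")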
