Skip to content

Commit a3cba2b

Browse files
szyszyzys and facebook-github-bot
authored and committed
Move codebook (LUT) generation methods into common utils. Update functions to be more compatible with CoreML. (#2772)
Summary: Pull Request resolved: #2772 Reviewed By: metascroy Differential Revision: D79595460
1 parent 8812365 commit a3cba2b

File tree

8 files changed

+940
-706
lines changed

8 files changed

+940
-706
lines changed

test/prototype/test_codebook_coreml.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,41 @@ def test_quantize_api(self):
7575
)
7676
assert type(m[0].weight) == CodebookQuantizedTensor
7777

78+
def test_choose_qparams_codebook_row_grouping(self):
79+
# Test with a block_size that forces row-wise grouping: [10, -1], i.e. [10, 256]
80+
# Input tensor is (100, 256)
81+
row_grouped_block_size = [10, -1]
82+
num_row_groups = (
83+
self.input.shape[0] // row_grouped_block_size[0]
84+
) # 100 // 10 = 10
85+
86+
codebook, wq = choose_qparams_and_quantize_codebook_coreml(
87+
self.input,
88+
self.code_dtype,
89+
row_grouped_block_size,
90+
)
91+
92+
# Expected shape for row-wise grouping is (num_row_groups, 1, 2**nbits, 1)
93+
self.assertEqual(codebook.shape, (num_row_groups, 1, 2**self.nbits, 1))
94+
self.assertEqual(wq.shape, (100, 256))
95+
96+
self.assertFalse(torch.isnan(codebook).any())
97+
self.assertFalse(torch.isnan(wq).any())
98+
99+
def test_codebook_quantized_tensor_from_float_row_grouping(self):
100+
# Test end-to-end quantization/dequantization with row grouping
101+
row_grouped_block_size = [20, -1] # 100 is divisible by 20
102+
cqt = CodebookQuantizedTensor.from_float(
103+
self.input,
104+
self.code_dtype,
105+
row_grouped_block_size,
106+
)
107+
108+
dequant = cqt.dequantize()
109+
# The SQNR will be different from column grouping, but should still be high
110+
sqnr = compute_error(dequant, self.input)
111+
self.assertGreater(sqnr, 30)
112+
78113
def test_export(self):
79114
m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)
80115
quantize_(m, CodebookWeightOnlyConfig(self.code_dtype, self.block_size))
Lines changed: 170 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
import tempfile
9+
import unittest
10+
11+
import torch
12+
import torch.nn as nn
13+
from parameterized import param, parameterized
14+
from torch import uint1, uint2, uint3, uint4
15+
16+
from torchao.prototype.quantization.codebook_groupwise.api import (
17+
GroupwiseLutWeightConfig,
18+
)
19+
from torchao.prototype.quantization.codebook_utils.codebook_utils import (
20+
group_size_to_block_shapes,
21+
)
22+
from torchao.quantization.quant_api import quantize_
23+
24+
25+
class TestGroupwiseLowbitWeightLut(unittest.TestCase):
26+
"""
27+
Test suite for the GroupwiseLutWeight quantization scheme, updated for the
28+
new simplified API.
29+
"""
30+
31+
TEST_CASES = [
32+
param(
33+
code_dtype=code_dtype,
34+
lut_group_size=lut_group_size,
35+
weight_dtype=weight_dtype,
36+
has_bias=has_bias,
37+
)
38+
for code_dtype in [uint1, uint2, uint3, uint4]
39+
for lut_group_size in [256, 512]
40+
for weight_dtype in [torch.float32]
41+
for has_bias in [True, False]
42+
]
43+
44+
# --------------------------------------------------------------------------
45+
# Test 1: End-to-End Model Accuracy
46+
# --------------------------------------------------------------------------
47+
@parameterized.expand(TEST_CASES)
48+
def test_e2e_accuracy_vs_reference(
49+
self,
50+
code_dtype,
51+
lut_group_size,
52+
weight_dtype,
53+
has_bias,
54+
):
55+
"""
56+
Tests the numerical accuracy of the full quantized model against a reference.
57+
This now uses the `use_qdq_reference` flag instead of layout objects.
58+
"""
59+
m, k, n = 3, 64, 32
60+
activations = torch.randn(m, k, dtype=weight_dtype)
61+
model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype)
62+
63+
# --- Derive the LUT block shape; tensor_shape is the (n, k) shape of the Linear weight ---
64+
lut_block_shape = group_size_to_block_shapes(
65+
lut_group_size=lut_group_size, tensor_shape=(n, k)
66+
)
67+
68+
# --- Quantize using C++ ops ---
69+
quantized_model = copy.deepcopy(model)
70+
perf_config = GroupwiseLutWeightConfig(
71+
code_dtype=code_dtype,
72+
weight_dtype=weight_dtype,
73+
lut_block_shape=lut_block_shape,
74+
use_qdq_reference=False,
75+
)
76+
quantize_(quantized_model, perf_config)
77+
with torch.no_grad():
78+
actual_result = quantized_model(activations)
79+
80+
# --- Quantize for Reference (using Python ops) ---
81+
reference_model = copy.deepcopy(model)
82+
ref_config = GroupwiseLutWeightConfig(
83+
code_dtype=code_dtype,
84+
weight_dtype=weight_dtype,
85+
lut_block_shape=lut_block_shape,
86+
use_qdq_reference=True,
87+
)
88+
quantize_(reference_model, ref_config)
89+
with torch.no_grad():
90+
expected_result = reference_model(activations)
91+
# Compare results
92+
self.assertTrue(
93+
torch.allclose(actual_result, expected_result, atol=1e-2, rtol=1e-2)
94+
)
95+
96+
def tearDown(self):
97+
"""
98+
Clear the TorchDynamo cache after each test case to prevent
99+
recompilation errors in parameterized tests.
100+
"""
101+
super().tearDown()
102+
torch._dynamo.reset()
103+
104+
# --------------------------------------------------------------------------
105+
# Test 2: Deployment Readiness (Updated for new API)
106+
# --------------------------------------------------------------------------
107+
@parameterized.expand(TEST_CASES)
108+
def test_export_compile_aoti(
109+
self,
110+
code_dtype,
111+
lut_group_size,
112+
weight_dtype,
113+
has_bias,
114+
):
115+
"""
116+
Tests that the quantized model can be exported and compiled.
117+
"""
118+
k, n = 64, 32
119+
activations = torch.randn(2, k, dtype=weight_dtype)
120+
model = (
121+
nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype).eval()
122+
)
123+
lut_block_shape = group_size_to_block_shapes(
124+
lut_group_size=lut_group_size,
125+
tensor_shape=(n, k),
126+
)
127+
128+
# Configure the quantization using the new API
129+
config = GroupwiseLutWeightConfig(
130+
code_dtype=code_dtype,
131+
weight_dtype=weight_dtype,
132+
lut_block_shape=lut_block_shape,
133+
use_qdq_reference=False,
134+
)
135+
quantize_(model, config)
136+
137+
with torch.no_grad():
138+
eager_results = model(activations)
139+
140+
# Export and Compile
141+
exported_model = torch.export.export(model, (activations,))
142+
compiled_model = torch.compile(model, fullgraph=True)
143+
144+
with tempfile.TemporaryDirectory() as tmpdir, torch.no_grad():
145+
# Check exported model
146+
exported_results = exported_model.module()(activations)
147+
self.assertTrue(
148+
torch.allclose(eager_results, exported_results, atol=1e-3, rtol=1e-3)
149+
)
150+
151+
# Check compiled model
152+
compiled_results = compiled_model(activations)
153+
self.assertTrue(
154+
torch.allclose(eager_results, compiled_results, atol=1e-3, rtol=1e-3)
155+
)
156+
157+
# Check AOTI compiled model using the packaging API
158+
package_path = f"{tmpdir}/model.pt2"
159+
torch._inductor.aoti_compile_and_package(
160+
exported_model, package_path=package_path
161+
)
162+
aoti_model = torch._inductor.aoti_load_package(package_path)
163+
aoti_results = aoti_model(activations)
164+
self.assertTrue(
165+
torch.allclose(eager_results, aoti_results, atol=1e-3, rtol=1e-3)
166+
)
167+
168+
169+
if __name__ == "__main__":
170+
unittest.main()

torchao/prototype/quantization/codebook_coreml/codebook_ops.py

Lines changed: 66 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -57,65 +57,83 @@ def choose_qparams_and_quantize_codebook_coreml(
5757
assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
5858
nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
5959
assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"
60-
61-
assert len(block_size) == input_tensor.dim()
62-
block_size = block_size.copy()
63-
for i in range(len(block_size)):
64-
if block_size[i] == -1:
65-
block_size[i] = input_tensor.shape[i]
66-
assert block_size[i] >= 1 and input_tensor.shape[i] % block_size[i] == 0, (
67-
"block_size[i] must divide input_tensor.shape[i]"
68-
)
69-
7060
assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
71-
assert block_size[0] == input_tensor.shape[0], (
72-
"Currently only support per-grouped channel granularity"
73-
)
7461
assert cluster_dim == 1, (
7562
f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
7663
)
7764

78-
num_lut = input_tensor.shape[1] // block_size[1]
79-
group_size = block_size[1]
80-
81-
# for converting to numpy
82-
input_tensor = input_tensor.detach()
8365
original_shape = input_tensor.shape
66+
N, K = original_shape
67+
input_tensor = input_tensor.detach()
8468

85-
# reshape to (N, K // group_size, group_size)
86-
input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
87-
from coremltools.models.neural_network.quantization_utils import (
88-
_get_kmeans_lookup_table_and_weight,
69+
# --- Process block_size ---
70+
assert len(block_size) == 2
71+
processed_block_size = block_size.copy()
72+
if processed_block_size[0] == -1:
73+
processed_block_size[0] = N
74+
if processed_block_size[1] == -1:
75+
processed_block_size[1] = K
76+
77+
row_block_size, col_block_size = processed_block_size
78+
assert N % row_block_size == 0, (
79+
f"Tensor rows ({N}) not divisible by row block size ({row_block_size})"
80+
)
81+
assert K % col_block_size == 0, (
82+
f"Tensor cols ({K}) not divisible by col block size ({col_block_size})"
8983
)
9084

91-
res_lut = []
92-
# each res_w[:, i, :] will use the same lookup table
93-
# res_w: (N, K // group_size, group_size)
94-
res_w = torch.zeros_like(input_tensor, dtype=torch.uint8)
95-
for i in range(num_lut):
96-
# lut: (2**nbits, 1)
97-
# w: (N * group_size)
98-
lut, w = _get_kmeans_lookup_table_and_weight(
99-
nbits, input_tensor[:, i, :], force_kmeans1d, cluster_dim, vector_axis
100-
)
101-
res_lut.append(torch.from_numpy(lut))
102-
res_w[:, i, :] = torch.from_numpy(w.reshape(input_tensor.shape[0], group_size))
103-
104-
# directly stack all lookup tables along dim 0
105-
# res_lut: (K // group_size, 2 ** nbits)
106-
res_lut = torch.stack(res_lut, dim=0)
107-
108-
# The final LUT should have dimension equal to input_tensor.dim() + 2
109-
# The first input_tensor.dim() dimensions index over the tables,
110-
# input_tensor.dim() + 1 indexes over the nbit indices
111-
# input_tensor.dim() + 2 are the look up values (shape = 1 for scalar)
112-
# res_lut: (N, K // group_size, 2 ** nbits, group_size)
113-
res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)
85+
# --- Determine and execute grouping strategy ---
86+
assert row_block_size == N or col_block_size == K
87+
is_col_grouping = row_block_size == N
11488

115-
# reshape back to (N, K)
116-
res_w = res_w.reshape(*original_shape)
89+
res_lut_list = []
90+
from coremltools.models.neural_network.quantization_utils import (
91+
_get_kmeans_lookup_table_and_weight,
92+
)
11793

118-
return res_lut, res_w
94+
if is_col_grouping:
95+
# STRATEGY 1: Group by COLUMNS
96+
num_luts = K // col_block_size
97+
reshaped_tensor = input_tensor.reshape(N, num_luts, col_block_size)
98+
res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8)
99+
100+
for i in range(num_luts):
101+
block_to_quantize = reshaped_tensor[:, i, :]
102+
lut, w = _get_kmeans_lookup_table_and_weight(
103+
nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis
104+
)
105+
res_lut_list.append(torch.from_numpy(lut))
106+
res_codes[:, i, :] = torch.from_numpy(w.reshape(N, col_block_size))
107+
108+
# Shape to match CoreML spec: (1, num_luts, 2**nbits, 1)
109+
final_luts = torch.stack(res_lut_list, dim=0).reshape(1, num_luts, 2**nbits, 1)
110+
111+
else: # is_row_grouping
112+
# STRATEGY 2: Group by ROWS
113+
num_luts = N // row_block_size
114+
reshaped_tensor = input_tensor.reshape(num_luts, row_block_size, K)
115+
res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8)
116+
117+
for i in range(num_luts):
118+
block_to_quantize = reshaped_tensor[i, :, :]
119+
lut, w = _get_kmeans_lookup_table_and_weight(
120+
nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis
121+
)
122+
res_lut_list.append(torch.from_numpy(lut))
123+
res_codes[i, :, :] = torch.from_numpy(w.reshape(row_block_size, K))
124+
125+
final_luts_stacked = torch.stack(
126+
res_lut_list, dim=0
127+
) # Shape: (num_luts, 2**nbits, 1)
128+
129+
# Reshape to the consistent 4D format
130+
# The shape is (num_row_groups, 1, 2**nbits, 1)
131+
final_luts = final_luts_stacked.reshape(num_luts, 1, 2**nbits, 1)
132+
133+
# Reshape codes back to the original tensor shape
134+
final_codes = res_codes.reshape(*original_shape)
135+
136+
return final_luts, final_codes
119137

120138

121139
@register_custom_op
Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
from .api import GroupwiseLutWeightConfig
2-
from .codebook_quantized_tensor import GroupwiseLutQuantizedTensor
2+
from .codebook_quantized_tensor import CodebookQuantizedPackedTensor
33

4-
__all__ = ["GroupwiseLutQuantizedTensor", "GroupwiseLutWeightConfig"]
4+
__all__ = ["CodebookQuantizedPackedTensor", "GroupwiseLutWeightConfig"]

0 commit comments

Comments
 (0)