Skip to content

Commit b55713a

Browse files
authored
Add PerBlock to safe globals (#3370)
**Summary:** Add PerBlock to safe globals so users don't have to do this themselves when they load config.json with PerBlock. ``` WeightsUnpickler error: Unsupported global: GLOBAL torchao.quantization.granularity.PerBlock was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torchao.quantization.granularity.PerBlock])` or the `torch.serialization.safe_globals([torchao.quantization.granularity.PerBlock])` context manager to allowlist this global if you trust this class/function. ``` **Test Plan:** ``` python test/core/test_config.py -k test_granularity_serialization ```
1 parent 2ff1eb2 commit b55713a

File tree

5 files changed

+54
-13
lines changed

5 files changed

+54
-13
lines changed

test/core/test_config.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import json
88
import os
9+
import subprocess
910
import tempfile
1011
import warnings
1112
from dataclasses import dataclass
@@ -23,7 +24,11 @@
2324
AWQConfig,
2425
AWQStep,
2526
)
26-
from torchao.quantization import PerBlock
27+
from torchao.quantization import (
28+
PerBlock,
29+
PerRow,
30+
PerTensor,
31+
)
2732
from torchao.quantization.quant_api import (
2833
Float8DynamicActivationFloat8WeightConfig,
2934
Float8DynamicActivationInt4WeightConfig,
@@ -36,10 +41,11 @@
3641
Int8DynamicActivationInt8WeightConfig,
3742
Int8WeightOnlyConfig,
3843
ModuleFqnToConfig,
39-
PerRow,
4044
UIntXWeightOnlyConfig,
45+
quantize_,
4146
)
4247
from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
48+
from torchao.utils import is_sm_at_least_89
4349

4450
# Define test configurations as fixtures
4551
configs = [
@@ -155,6 +161,43 @@ def test_reconstructable_dict_file_round_trip(config):
155161
os.unlink(temp_file_path)
156162

157163

164+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
165+
@pytest.mark.skipif(not is_sm_at_least_89(), reason="needs CUDA capability 8.9+")
166+
@pytest.mark.parametrize(
167+
"granularity",
168+
[
169+
PerTensor(),
170+
PerRow(),
171+
(PerBlock([1, 128]), PerBlock([128, 128])),
172+
],
173+
)
174+
def test_granularity_serialization(granularity):
175+
"""
176+
Ensure that only `import torchao` is needed to load granularities used
177+
in `Float8DynamicActivationFloat8WeightConfig`.
178+
"""
179+
180+
m = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
181+
fname = None
182+
with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
183+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
184+
quantize_(m, config=config)
185+
torch.save(m.state_dict(), f.name)
186+
fname = f.name
187+
188+
assert fname is not None
189+
190+
code = f"""
191+
import torch
192+
import torchao
193+
_ = torch.load('{fname}', weights_only=True)
194+
"""
195+
196+
subprocess_out = subprocess.run(["python"], input=code, text=True)
197+
os.remove(fname)
198+
assert subprocess_out.returncode == 0, "failed weights-only load"
199+
200+
158201
# Define a dummy config in a non-allowed module
159202
@dataclass
160203
class DummyNonAllowedConfig(AOBaseConfig):

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
Int4WeightOnlyConfig,
2222
Int8DynamicActivationInt8WeightConfig,
2323
Int8WeightOnlyConfig,
24+
PerRow,
25+
PerTensor,
2426
)
25-
from torchao.quantization.observer import PerRow, PerTensor
2627
from torchao.quantization.quant_api import quantize_
2728

2829
if common_utils.SEED is None:

torchao/quantization/granularity.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from dataclasses import dataclass
88

9+
import torch
10+
911

1012
@dataclass(frozen=True)
1113
class Granularity:
@@ -138,3 +140,6 @@ class PerBlock(Granularity):
138140
# list. Example error:
139141
# https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
140142
block_size: tuple[int, ...]
143+
144+
145+
torch.serialization.add_safe_globals([PerBlock, PerRow, PerTensor])

torchao/quantization/observer.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212

1313
from torchao.quantization.quant_primitives import _fake_quantize_affine
1414

15-
from .granularity import (
16-
Granularity,
17-
PerRow,
18-
PerTensor,
19-
)
15+
from .granularity import Granularity
2016
from .quant_primitives import (
2117
MappingType,
2218
ZeroPointDomain,
@@ -350,7 +346,3 @@ def calculate_qparams(self):
350346
self.preserve_zero,
351347
self.zero_point_domain,
352348
)
353-
354-
355-
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
356-
torch.serialization.add_safe_globals([PerRow, PerTensor])

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,10 @@ def insert_observers_(
332332
```
333333
import torch
334334
import torch.nn as nn
335+
from torchao.quantization import PerTensor
335336
from torchao.quantization.linear_observer_tensor import insert_observers_
336337
from torchao.quantization.observer import (
337338
AffineQuantizedMinMaxObserver,
338-
PerTensor,
339339
MappingType
340340
)
341341

0 commit comments

Comments (0)