From 898bf6a19fb08aabbf0d082d7eb35831f66fbdd0 Mon Sep 17 00:00:00 2001
From: Riyad Islam
Date: Mon, 3 Nov 2025 12:55:09 -0800
Subject: [PATCH 1/2] PyTorch Geometric quantization support

Signed-off-by: Riyad Islam
---
 CHANGELOG.rst                                 |   9 +
 .../torch/quantization/plugins/__init__.py    |   4 +
 .../quantization/plugins/pytorch_geometric.py |  94 +++++++++
 setup.py                                      |   1 +
 .../plugins/test_pytorch_geometric_plugin.py  | 189 ++++++++++++++++++
 5 files changed, 297 insertions(+)
 create mode 100644 modelopt/torch/quantization/plugins/pytorch_geometric.py
 create mode 100644 tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 886ab7499..c0bc1fb9a 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,5 +1,14 @@
 Model Optimizer Changelog (Linux)
 =================================
+
+0.41 (2025-12-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add support for PyTorch Geometric quantization.
 
 0.40 (2025-12-xx)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py
index 531c2e476..ef90dfda5 100644
--- a/modelopt/torch/quantization/plugins/__init__.py
+++ b/modelopt/torch/quantization/plugins/__init__.py
@@ -25,6 +25,7 @@
 - :meth:`huggingface`
 - :meth:`megatron`
 - :meth:`peft`
+- :meth:`pytorch_geometric`
 - :meth:`transformer_engine`
 """
 
@@ -57,6 +58,9 @@
 with import_plugin("peft"):
     from .peft import *
 
+with import_plugin("torch_geometric"):
+    from .pytorch_geometric import *
+
 with import_plugin("transformer_engine"):
     from .transformer_engine import *
 
diff --git a/modelopt/torch/quantization/plugins/pytorch_geometric.py b/modelopt/torch/quantization/plugins/pytorch_geometric.py
new file mode 100644
index 000000000..5b2013a34
--- /dev/null
+++ b/modelopt/torch/quantization/plugins/pytorch_geometric.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Geometric quantization plugin.
+
+This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
+PyG's custom Linear layer with ModelOpt's quantization registry.
+
+Example:
+    >>> import torch.nn as nn
+    >>> import modelopt.torch.quantization as mtq
+    >>> from torch_geometric.nn import GATConv
+    >>>
+    >>> # Create a model with PyG layers
+    >>> class GATModel(nn.Module):
+    ...     def __init__(self):
+    ...         super().__init__()
+    ...         self.gat1 = GATConv(10, 64, heads=4)
+    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
+    ...
+    ...     def forward(self, x, edge_index):
+    ...         return self.gat2(self.gat1(x, edge_index).relu(), edge_index)
+    >>> model = GATModel()
+    >>> # PyG layers are now automatically quantizable; `calibrate` below is a
+    >>> # user-defined function that feeds calibration batches through the model.
+    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+"""
+
+import torch
+from torch_geometric.nn.dense.linear import Linear as PyGLinear
+
+from modelopt.torch.quantization.nn.modules.quant_module import (
+    QuantLinearConvBase,
+    QuantModuleRegistry,
+)
+from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
+
+
+class QuantPyGLinear(QuantLinearConvBase):
+    """Quantized version of PyTorch Geometric's Linear layer.
+
+    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
+    torch.nn.Linear but has a different API (in_channels/out_channels instead of
+    in_features/out_features). This class enables quantization of PyG Linear layers.
+
+    Note:
+        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
+        internally use PyG Linear layers, so registering this class enables quantization
+        for a wide range of graph neural network layers.
+    """
+
+    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
+
+    def forward(self, input, *args, **kwargs):
+        """Forward pass with quantization.
+
+        Args:
+            input: Input tensor to the linear layer
+            *args: Additional positional arguments
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            Quantized output tensor
+        """
+        # Quantize input activations
+        input_q = self.input_quantizer(input)
+
+        # Quantize weights
+        weight_q = self.weight_quantizer(self.weight)
+
+        # Perform linear operation
+        output = torch.nn.functional.linear(
+            input_q,
+            weight_q,
+            self.bias if hasattr(self, "bias") and self.bias is not None else None,
+        )
+
+        # Quantize output (typically disabled by default)
+        return self.output_quantizer(output)
+
+
+QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
diff --git a/setup.py b/setup.py
index 67bf114ae..ad97b7589 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,7 @@
         "pytest-timeout",
         "timm",
         "torchvision",
+        "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
     ],
diff --git a/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
new file mode 100644
index 000000000..6fed35d0e
--- /dev/null
+++ b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
@@ -0,0 +1,189 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for PyTorch Geometric quantization plugin.""" + +import pytest +import torch +import torch.nn as nn +from _test_utils.torch.misc import set_seed +from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv + +import modelopt.torch.quantization as mtq + + +class TestPyTorchGeometricPlugin: + """Test PyTorch Geometric quantization support.""" + + @pytest.fixture(autouse=True) + def setup_seed(self): + """Set seed before each test function.""" + set_seed() + + @pytest.fixture + def device(self): + """Get test device.""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"): + """Create sample graph data for testing.""" + x = torch.randn(batch_size * num_nodes, in_channels, device=device) + # Create batch assignment + batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)]) + + # Create edge indices for each graph + edge_list = [] + offset = 0 + for _ in range(batch_size): + # Create random edges within each graph + src = torch.randint(0, num_nodes, (50,), device=device) + offset + dst = torch.randint(0, num_nodes, (50,), device=device) + offset + edge_list.append(torch.stack([src, dst])) + offset += num_nodes + + edge_index = torch.cat(edge_list, dim=1) + edge_attr = torch.randn(edge_index.size(1), 32, device=device) + + return x, edge_index, edge_attr, batch + + def test_gat_conv_quantization(self, device): + """Test GATConv layer quantization.""" + + class GATModel(nn.Module): + def __init__(self): + super().__init__() + self.gat1 = GATConv(16, 64, heads=4, edge_dim=32) + self.gat2 = GATConv(256, 32, heads=1, edge_dim=32) + + def forward(self, x, edge_index, edge_attr): + x = torch.relu(self.gat1(x, edge_index, edge_attr)) + return self.gat2(x, edge_index, edge_attr) + + model = GATModel().to(device) + + # Calibration function + def calibrate(m): + m.eval() + with torch.no_grad(): + for _ in range(5): + x, edge_index, edge_attr, _ = self.create_graph_data(device=device) + _ = m(x, edge_index, edge_attr) + + # Quantize model + quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate) + + # Verify quantization + quantizer_count = sum( + 1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower() + ) + assert quantizer_count > 0, "No quantizers were inserted" + + # Test forward pass + x, edge_index, edge_attr, _ = self.create_graph_data(device=device) + with torch.no_grad(): + output = quantized(x, edge_index, edge_attr) + assert output is not None + + def test_multiple_layer_types(self, device): + """Test quantization of multiple PyG layer types.""" + + class MultiLayerGNN(nn.Module): + def __init__(self): + super().__init__() + self.gcn = GCNConv(16, 32) + self.sage = SAGEConv(32, 64) + self.transformer = TransformerConv(64, 32, heads=2) + + def forward(self, x, edge_index): + x = torch.relu(self.gcn(x, edge_index)) + x = torch.relu(self.sage(x, edge_index)) + return self.transformer(x, edge_index) + + model = MultiLayerGNN().to(device) + + # Calibration + def calibrate(m): + m.eval() + with torch.no_grad(): + for _ in range(3): + x = torch.randn(50, 16, device=device) + edge_index = torch.randint(0, 50, (2, 100), device=device) + _ = m(x, edge_index) + + # Quantize + quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate) + + # Check that PyG Linear layers were quantized + pyg_linear_count = 0 + for name, module in model.named_modules(): + if hasattr(module, "lin") and "torch_geometric" 
+                pyg_linear_count += 1
+
+        quantizer_count = sum(
+            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
+        )
+
+        # Each PyG linear should have at least 2 quantizers (input, weight)
+        assert quantizer_count >= pyg_linear_count * 2, (
+            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
+        )
+
+    def test_quantization_accuracy(self, device):
+        """Test that quantization maintains reasonable accuracy."""
+        # Set seed for this test specifically to ensure reproducibility
+        set_seed()
+
+        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)
+
+        # Create test data
+        x, edge_index, edge_attr, _ = self.create_graph_data(
+            batch_size=1, in_channels=16, device=device
+        )
+        edge_attr = edge_attr[:, :16]  # Match edge_dim
+
+        # Get original output
+        model.eval()
+        with torch.no_grad():
+            original_output = model(x, edge_index, edge_attr)
+
+        # Calibration with multiple samples for more stable quantization
+        def calibrate(m):
+            m.eval()
+            with torch.no_grad():
+                # Use multiple calibration samples for better stability
+                for _ in range(5):
+                    x_cal, edge_index_cal, edge_attr_cal, _ = self.create_graph_data(
+                        batch_size=1, in_channels=16, device=device
+                    )
+                    edge_attr_cal = edge_attr_cal[:, :16]  # Match edge_dim
+                    _ = m(x_cal, edge_index_cal, edge_attr_cal)
+
+        # Quantize
+        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+
+        # Get quantized output
+        with torch.no_grad():
+            quantized_output = quantized(x, edge_index, edge_attr)
+
+        # Check relative error
+        abs_diff = torch.abs(original_output - quantized_output)
+        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
+        mean_relative_error = relative_error.mean().item()
+
+        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

From 15022ad335a91e32302bf9fa694aab8397a8c185 Mon Sep 17 00:00:00 2001
From: Riyad Islam
Date: Sun, 9 Nov 2025 00:39:32 -0800
Subject: [PATCH 2/2] Address review feedback

Signed-off-by: Riyad Islam
---
 .../quantization/plugins/pytorch_geometric.py | 37 +++++----------------
 .../plugins/test_pytorch_geometric_plugin.py  |  4 ----
 2 files changed, 8 insertions(+), 33 deletions(-)

diff --git a/modelopt/torch/quantization/plugins/pytorch_geometric.py b/modelopt/torch/quantization/plugins/pytorch_geometric.py
index 5b2013a34..35ed887cb 100644
--- a/modelopt/torch/quantization/plugins/pytorch_geometric.py
+++ b/modelopt/torch/quantization/plugins/pytorch_geometric.py
@@ -38,7 +38,6 @@
     >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
 """
 
-import torch
 from torch_geometric.nn.dense.linear import Linear as PyGLinear
 
 from modelopt.torch.quantization.nn.modules.quant_module import (
@@ -53,7 +52,14 @@ class QuantPyGLinear(QuantLinearConvBase):
 
     PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
     torch.nn.Linear but has a different API (in_channels/out_channels instead of
-    in_features/out_features). This class enables quantization of PyG Linear layers.
+    in_features/out_features). This class enables quantization of PyG Linear layers
+    by inheriting from QuantLinearConvBase, which handles all quantization logic.
+
+    The quantization is handled automatically by the base classes:
+
+    - Input quantization: Handled by QuantInputBase.forward()
+    - Weight quantization: Handled by QuantLinearConvBase's dynamic weight attribute
+    - Output quantization: Handled by QuantInputBase.forward()
 
     Note:
         Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
@@ -63,32 +69,5 @@ class QuantPyGLinear(QuantLinearConvBase):
 
     default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
 
-    def forward(self, input, *args, **kwargs):
-        """Forward pass with quantization.
-
-        Args:
-            input: Input tensor to the linear layer
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            Quantized output tensor
-        """
-        # Quantize input activations
-        input_q = self.input_quantizer(input)
-
-        # Quantize weights
-        weight_q = self.weight_quantizer(self.weight)
-
-        # Perform linear operation
-        output = torch.nn.functional.linear(
-            input_q,
-            weight_q,
-            self.bias if hasattr(self, "bias") and self.bias is not None else None,
-        )
-
-        # Quantize output (typically disabled by default)
-        return self.output_quantizer(output)
-
 
 QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
diff --git a/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
index 6fed35d0e..a3abb7195 100644
--- a/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
+++ b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
@@ -183,7 +183,3 @@ def calibrate(m):
         mean_relative_error = relative_error.mean().item()
 
         assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
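
For reviewers who want to try the feature, below is a minimal end-to-end sketch of the
workflow these two commits enable, assembled from the plugin docstring and the unit tests
above. The TinyGCN model, its layer sizes, and the random calibration graphs are
illustrative placeholders, not part of the patch; the sketch assumes torch-geometric is
installed and that importing modelopt.torch.quantization loads the plugin, as in the tests.

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

# Importing ModelOpt quantization registers the PyG plugin when
# torch-geometric is available, so PyG Linear layers become quantizable.
import modelopt.torch.quantization as mtq


class TinyGCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(16, 32)
        self.conv2 = GCNConv(32, 8)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)


model = TinyGCN()


def calibrate(m):
    # Random graphs stand in for real calibration data in this sketch.
    m.eval()
    with torch.no_grad():
        for _ in range(8):
            x = torch.randn(50, 16)
            edge_index = torch.randint(0, 50, (2, 100))
            m(x, edge_index)


# Each GCNConv's internal PyG Linear (`lin`) is converted to QuantPyGLinear and
# calibrated; TensorQuantizer submodules then show up in the module tree.
quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
print(quantized)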