diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 886ab7499..c0bc1fb9a 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,5 +1,14 @@
 Model Optimizer Changelog (Linux)
 =================================
+
+0.41 (2025-12-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add support for PyTorch Geometric quantization.
 
 0.40 (2025-12-xx)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py
index 531c2e476..ef90dfda5 100644
--- a/modelopt/torch/quantization/plugins/__init__.py
+++ b/modelopt/torch/quantization/plugins/__init__.py
@@ -25,6 +25,7 @@
 - :meth:`huggingface`
 - :meth:`megatron`
 - :meth:`peft`
+- :meth:`pytorch_geometric`
 - :meth:`transformer_engine`
 """
 
@@ -57,6 +58,9 @@
 with import_plugin("peft"):
     from .peft import *
 
+with import_plugin("torch_geometric"):
+    from .pytorch_geometric import *
+
 with import_plugin("transformer_engine"):
     from .transformer_engine import *
 
diff --git a/modelopt/torch/quantization/plugins/pytorch_geometric.py b/modelopt/torch/quantization/plugins/pytorch_geometric.py
new file mode 100644
index 000000000..35ed887cb
--- /dev/null
+++ b/modelopt/torch/quantization/plugins/pytorch_geometric.py
@@ -0,0 +1,78 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Geometric quantization plugin.
+
+This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
+PyG's custom Linear layer with ModelOpt's quantization registry.
+
+Example:
+    >>> import torch.nn as nn
+    >>> import modelopt.torch.quantization as mtq
+    >>> from torch_geometric.nn import GATConv
+    >>>
+    >>> # Create a model with PyG layers
+    >>> class GATModel(nn.Module):
+    ...     def __init__(self):
+    ...         super().__init__()
+    ...         self.gat1 = GATConv(10, 64, heads=4)
+    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
+    ...
+    ...     def forward(self, x, edge_index):
+    ...         return self.gat2(self.gat1(x, edge_index).relu(), edge_index)
+    >>> model = GATModel()
+    >>>
+    >>> # Calibrate on representative graphs (``calib_graphs`` is user-provided)
+    >>> def calibrate(m):
+    ...     for x, edge_index in calib_graphs:
+    ...         m(x, edge_index)
+    >>>
+    >>> # PyG layers are now automatically quantizable!
+    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+"""
+
+from torch_geometric.nn.dense.linear import Linear as PyGLinear
+
+from modelopt.torch.quantization.nn.modules.quant_module import (
+    QuantLinearConvBase,
+    QuantModuleRegistry,
+)
+from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
+
+
+class QuantPyGLinear(QuantLinearConvBase):
+    """Quantized version of PyTorch Geometric's Linear layer.
+
+    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
+    torch.nn.Linear but has a different API (in_channels/out_channels instead of
+    in_features/out_features). This class enables quantization of PyG Linear layers
+    by inheriting from QuantLinearConvBase, which handles all quantization logic.
+
+    The quantization is handled automatically by the base classes:
+
+    - Input quantization: Handled by QuantInputBase.forward()
+    - Weight quantization: Handled by QuantLinearConvBase's dynamic weight attribute
+    - Output quantization: Handled by QuantInputBase.forward()
+
+    Note:
+        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
+        internally use PyG Linear layers, so registering this class enables quantization
+        for a wide range of graph neural network layers.
+    """
+
+    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
+
+
+QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
diff --git a/setup.py b/setup.py
index 67bf114ae..ad97b7589 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,7 @@
         "pytest-timeout",
         "timm",
         "torchvision",
+        "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
     ],
diff --git a/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
new file mode 100644
index 000000000..a3abb7195
--- /dev/null
+++ b/tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
@@ -0,0 +1,185 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for PyTorch Geometric quantization plugin."""
+
+import pytest
+import torch
+import torch.nn as nn
+from _test_utils.torch.misc import set_seed
+from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv
+from torch_geometric.nn.dense.linear import Linear as PyGLinear
+
+import modelopt.torch.quantization as mtq
+
+
+class TestPyTorchGeometricPlugin:
+    """Test PyTorch Geometric quantization support."""
+
+    @pytest.fixture(autouse=True)
+    def setup_seed(self):
+        """Set seed before each test function."""
+        set_seed()
+
+    @pytest.fixture
+    def device(self):
+        """Get test device."""
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def create_graph_data(
+        self, batch_size=2, num_nodes=20, in_channels=16, device="cpu", edge_dim=32
+    ):
+        """Create random batched graph data: (x, edge_index, edge_attr, batch)."""
+        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
+        # Create batch assignment
+        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])
+
+        # Create edge indices for each graph
+        edge_list = []
+        offset = 0
+        for _ in range(batch_size):
+            # Create random edges within each graph
+            src = torch.randint(0, num_nodes, (50,), device=device) + offset
+            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
+            edge_list.append(torch.stack([src, dst]))
+            offset += num_nodes
+
+        edge_index = torch.cat(edge_list, dim=1)
+        edge_attr = torch.randn(edge_index.size(1), edge_dim, device=device)
+
+        return x, edge_index, edge_attr, batch
+
+    def test_gat_conv_quantization(self, device):
+        """Test GATConv layer quantization."""
+
+        class GATModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
+                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)
+
+            def forward(self, x, edge_index, edge_attr):
+                x = torch.relu(self.gat1(x, edge_index, edge_attr))
+                return self.gat2(x, edge_index, edge_attr)
+
+        model = GATModel().to(device)
+
+        # Calibration function
+        def calibrate(m):
+            m.eval()
+            with torch.no_grad():
+                for _ in range(5):
+                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
+                    _ = m(x, edge_index, edge_attr)
+
+        # Quantize model
+        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+
+        # Verify quantization
+        quantizer_count = sum(
+            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
+        )
+        assert quantizer_count > 0, "No quantizers were inserted"
+
+        # Test forward pass: gat2 has heads=1, so expect one 32-dim embedding per input node
+        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
+        with torch.no_grad():
+            output = quantized(x, edge_index, edge_attr)
+        assert output.shape == (x.size(0), 32)
+        assert torch.isfinite(output).all()
+
+    def test_multiple_layer_types(self, device):
+        """Test quantization of multiple PyG layer types."""
+
+        class MultiLayerGNN(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gcn = GCNConv(16, 32)
+                self.sage = SAGEConv(32, 64)
+                self.transformer = TransformerConv(64, 32, heads=2)
+
+            def forward(self, x, edge_index):
+                x = torch.relu(self.gcn(x, edge_index))
+                x = torch.relu(self.sage(x, edge_index))
+                return self.transformer(x, edge_index)
+
+        model = MultiLayerGNN().to(device)
+
+        # Calibration
+        def calibrate(m):
+            m.eval()
+            with torch.no_grad():
+                for _ in range(3):
+                    x = torch.randn(50, 16, device=device)
+                    edge_index = torch.randint(0, 50, (2, 100), device=device)
+                    _ = m(x, edge_index)
+
+        # Count PyG Linear layers by type before conversion. Matching on a ``lin`` attribute
+        # only finds layers that expose their projection under exactly that name, and the
+        # type string is not guaranteed to mention ``torch_geometric`` once converted.
+        pyg_linear_count = sum(isinstance(m, PyGLinear) for m in model.modules())
+        assert pyg_linear_count > 0, "Model contains no PyG Linear layers"
+
+        # Quantize
+        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+
+        quantizer_count = sum(
+            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
+        )
+
+        # Each PyG linear should have at least 2 quantizers (input, weight)
+        assert quantizer_count >= pyg_linear_count * 2, (
+            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
+        )
+
+    def test_quantization_accuracy(self, device):
+        """Test that quantization maintains reasonable accuracy."""
+
+        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)
+
+        # Create test data
+        x, edge_index, edge_attr, _ = self.create_graph_data(
+            batch_size=1, in_channels=16, edge_dim=16, device=device
+        )
+
+        # Get original output
+        model.eval()
+        with torch.no_grad():
+            original_output = model(x, edge_index, edge_attr)
+
+        # Calibration with multiple samples for more stable quantization
+        def calibrate(m):
+            m.eval()
+            with torch.no_grad():
+                # Use multiple calibration samples for better stability
+                for _ in range(5):
+                    x_cal, edge_index_cal, edge_attr_cal, _ = self.create_graph_data(
+                        batch_size=1, in_channels=16, edge_dim=16, device=device
+                    )
+                    _ = m(x_cal, edge_index_cal, edge_attr_cal)
+
+        # Quantize
+        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
+
+        # Get quantized output
+        with torch.no_grad():
+            quantized_output = quantized(x, edge_index, edge_attr)
+
+        # Check relative error on norms: averaging element-wise ratios is dominated by
+        # entries whose reference value is close to zero, which makes the metric unstable.
+        diff_norm = torch.linalg.norm(original_output - quantized_output)
+        relative_error = (diff_norm / torch.linalg.norm(original_output)).item()
+
+        assert relative_error < 0.1, f"Quantization error too large: {relative_error:.2%}"