From 3b3722fe4d4016948ed519e136dd5f52b5a83f80 Mon Sep 17 00:00:00 2001 From: Parth Shandilya <24358501+ParthS007@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:55:34 +0100 Subject: [PATCH 1/4] Add automatic clipping optimizers for distributed and non-distributed training --- opacus/optimizers/__init__.py | 20 ++ .../ddpoptimizer_automatic_clipping.py | 221 ++++++++++++++++++ .../optimizer_automatic_clipping.py | 150 ++++++++++++ 3 files changed, 391 insertions(+) create mode 100644 opacus/optimizers/ddpoptimizer_automatic_clipping.py create mode 100644 opacus/optimizers/optimizer_automatic_clipping.py diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py index 00e27521..707fa06b 100644 --- a/opacus/optimizers/__init__.py +++ b/opacus/optimizers/__init__.py @@ -22,6 +22,14 @@ from .optimizer import DPOptimizer from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping from .perlayeroptimizer import DPPerLayerOptimizer +from .optimizer_automatic_clipping import ( + DPAutomaticClippingOptimizer, + DPPerLayerAutomaticClippingOptimizer, +) +from .ddpoptimizer_automatic_clipping import ( + DistributedDPAutomaticClippingOptimizer, + DistributedDPPerLayerAutomaticClippingOptimizer, +) __all__ = [ @@ -33,6 +41,10 @@ "FSDPOptimizerFastGradientClipping", "DPPerLayerOptimizer", "SimpleDistributedPerLayerOptimizer", + "DPAutomaticClippingOptimizer", + "DPPerLayerAutomaticClippingOptimizer", + "DistributedDPAutomaticClippingOptimizer", + "DistributedDPPerLayerAutomaticClippingOptimizer", ] @@ -64,6 +76,14 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str return SimpleDistributedPerLayerOptimizer else: raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}") + elif clipping == "automatic" and distributed is False: + return DPAutomaticClippingOptimizer + elif clipping == "automatic" and distributed is True: + return DistributedDPAutomaticClippingOptimizer + elif clipping == "automatic_per_layer" and distributed is False: + return DPPerLayerAutomaticClippingOptimizer + elif clipping == "automatic_per_layer" and distributed is True: + return DistributedDPPerLayerAutomaticClippingOptimizer elif clipping == "adaptive" and distributed is False: return AdaClipDPOptimizer raise ValueError( diff --git a/opacus/optimizers/ddpoptimizer_automatic_clipping.py b/opacus/optimizers/ddpoptimizer_automatic_clipping.py new file mode 100644 index 00000000..6548fc12 --- /dev/null +++ b/opacus/optimizers/ddpoptimizer_automatic_clipping.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
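A minimal usage sketch of the two clipping modes registered in opacus/optimizers/__init__.py above. It mirrors the make_private calls exercised by the tests added later in this patch; the model, data, and hyperparameter values are placeholders for illustration, not part of the diff.

# Illustrative sketch: selecting the automatic clipping optimizers via PrivacyEngine.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
data_loader = DataLoader(dataset, batch_size=16)

privacy_engine = PrivacyEngine()

# clipping="automatic" -> DPAutomaticClippingOptimizer (one bound for the whole model)
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    clipping="automatic",
)

# clipping="automatic_per_layer" expects one bound per parameter instead, e.g.
#     max_grad_norm=[1.0] * len(list(model.parameters()))
# and returns DPPerLayerAutomaticClippingOptimizer. Under distributed (DDP) training
# the factory returns the Distributed* variants of both classes.

Unlike the default "flat" mode, the clip factor max_grad_norm / (per_sample_norms + 0.01) is not clamped at 1.0, so every per-sample gradient is rescaled rather than only those exceeding the bound.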
+ +from __future__ import annotations + +from typing import Callable, List, Optional + +import torch +from torch.optim import Optimizer + +from opacus.optimizers.optimizer_automatic_clipping import ( + DPAutomaticClippingOptimizer, + DPPerLayerAutomaticClippingOptimizer, +) + + +class DistributedDPAutomaticClippingOptimizer(DPAutomaticClippingOptimizer): + """ + Distributed version of DPAutomaticClippingOptimizer for multi-GPU training. + + This optimizer extends :class:`~opacus.optimizers.optimizer_automatic_clipping.DPAutomaticClippingOptimizer` + to work with PyTorch's distributed data parallel (DDP) training. It handles: + + - **Gradient Synchronization**: Uses ``all_reduce`` to sum gradients across all workers + - **Coordinated Noise**: Only rank 0 generates noise to ensure consistency + - **Proper Reduction**: Handles mean/sum loss reduction across workers + + The automatic clipping formula remains: + ``per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01)`` + + As described in: + *"Automatic Clipping: Differentially Private Deep Learning Made Easier and Stronger"* + https://arxiv.org/pdf/2206.07136 + + Args: + optimizer: Wrapped optimizer instance + noise_multiplier: Noise multiplier for differential privacy + max_grad_norm: Maximum gradient norm for clipping + expected_batch_size: Expected batch size (per worker) + loss_reduction: How to reduce loss across workers ("mean" or "sum") + generator: Random number generator for noise + secure_mode: Whether to use secure random number generation + + Note: + This optimizer is automatically instantiated when using + ``PrivacyEngine.make_private()`` with ``distributed=True`` and ``clipping="automatic"``. + + See Also: + - :class:`~opacus.optimizers.optimizer_automatic_clipping.DPAutomaticClippingOptimizer`: Non-distributed version + - :class:`~DistributedDPPerLayerAutomaticClippingOptimizer`: Per-layer distributed variant + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + **kwargs, + ): + super().__init__( + optimizer=optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + **kwargs, + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_noise(self): + """Add noise only on rank 0, then broadcast to other workers.""" + if self.rank == 0: + super().add_noise() + else: + for p in self.params: + p.grad = p.summed_grad.view_as(p) + + def reduce_gradients(self): + """Reduce gradients across all workers.""" + for p in self.params: + if not p.requires_grad: + continue + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM) + if self.loss_reduction == "mean": + p.grad /= self.world_size + + def step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[torch.Tensor]: + """Perform optimization step with distributed gradient synchronization.""" + if closure is not None: + with torch.enable_grad(): + closure() + + if self.pre_step(): + self.reduce_gradients() + return self.original_optimizer.step() + else: + return None + + +class DistributedDPPerLayerAutomaticClippingOptimizer( + DPPerLayerAutomaticClippingOptimizer +): + """ + Distributed per-layer automatic clipping optimizer for multi-GPU training. 
+ + This optimizer extends :class:`~opacus.optimizers.optimizer_automatic_clipping.DPPerLayerAutomaticClippingOptimizer` + to work with PyTorch's distributed data parallel (DDP) training. + + Combines the benefits of: + - **Per-layer clipping**: Each layer has its own ``max_grad_norm`` + - **Automatic clipping**: Smooth scaling without hard clamping + - **Distributed training**: Gradient synchronization across workers + + The per-layer automatic clipping formula is: + ``per_sample_clip_factor = max_grad_norm[layer] / (per_sample_norms[layer] + 0.01)`` + + As described in: + *"Automatic Clipping: Differentially Private Deep Learning Made Easier and Stronger"* + https://arxiv.org/pdf/2206.07136 + + Args: + optimizer: Wrapped optimizer instance + noise_multiplier: Noise multiplier for differential privacy + max_grad_norm: List of maximum gradient norms (one per parameter) + expected_batch_size: Expected batch size (per worker) + loss_reduction: How to reduce loss across workers ("mean" or "sum") + generator: Random number generator for noise + secure_mode: Whether to use secure random number generation + + Note: + This optimizer is automatically instantiated when using + ``PrivacyEngine.make_private()`` with ``distributed=True`` and + ``clipping="automatic_per_layer"``. + + See Also: + - :class:`~opacus.optimizers.optimizer_automatic_clipping.DPPerLayerAutomaticClippingOptimizer`: Non-distributed version + - :class:`~DistributedDPAutomaticClippingOptimizer`: All-layer distributed variant + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: List[float], + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + **kwargs, + ): + super().__init__( + optimizer=optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + **kwargs, + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_noise(self): + """Add noise only on rank 0, then broadcast to other workers.""" + if self.rank == 0: + super().add_noise() + else: + for p in self.params: + p.grad = p.summed_grad.view_as(p) + + def reduce_gradients(self): + """Reduce gradients across all workers.""" + for p in self.params: + if not p.requires_grad: + continue + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM) + if self.loss_reduction == "mean": + p.grad /= self.world_size + + def step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[torch.Tensor]: + """Perform optimization step with distributed gradient synchronization.""" + if closure is not None: + with torch.enable_grad(): + closure() + + if self.pre_step(): + self.reduce_gradients() + return self.original_optimizer.step() + else: + return None + + +__all__ = [ + "DistributedDPAutomaticClippingOptimizer", + "DistributedDPPerLayerAutomaticClippingOptimizer", +] diff --git a/opacus/optimizers/optimizer_automatic_clipping.py b/opacus/optimizers/optimizer_automatic_clipping.py new file mode 100644 index 00000000..5b0c4e8a --- /dev/null +++ b/opacus/optimizers/optimizer_automatic_clipping.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import List + +import torch +from opacus.optimizers.optimizer import _check_processed_flag, _mark_as_processed +from opacus.optimizers.perlayeroptimizer import DPPerLayerOptimizer +from opacus.optimizers.optimizer import DPOptimizer + + +class DPAutomaticClippingOptimizer(DPOptimizer): + """ + DPOptimizer variant that uses automatic clipping across all layers. + + Automatic clipping computes per-sample clip factors using the formula: + ``per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01)`` + + This differs from the default behavior by using automatic scaling + (without clamping to 1.0) as described in: + *"Automatic Clipping: Differentially Private Deep Learning Made Easier and Stronger"* + https://arxiv.org/pdf/2206.07136 + + The stabilization constant (0.01) prevents division by zero for samples with + very small gradients, ensuring numerical stability during training. + + Note: + This optimizer is automatically instantiated when using + ``PrivacyEngine.make_private()`` with ``clipping="automatic"``. + + See Also: + - :class:`~opacus.optimizers.optimizer.DPOptimizer`: Base DP optimizer with standard clipping + - :class:`~DPPerLayerAutomaticClippingOptimizer`: Per-layer variant of automatic clipping + - :class:`~opacus.optimizers.ddpoptimizer_automatic_clipping.DistributedDPAutomaticClippingOptimizer`: Distributed version + """ + + def clip_and_accumulate(self): + """Perform automatic clipping and accumulate clipped gradients. + + This method mirrors :meth:`DPOptimizer.clip_and_accumulate` but uses the + automatic clipping formula for per-sample clip factors. + """ + + # Compute per-parameter norms (shape: [batch_size] for each parameter) + per_param_norms: List[torch.Tensor] = [ + g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples + ] + + if per_param_norms: + target_device = per_param_norms[0].device + per_param_norms = [norm.to(target_device) for norm in per_param_norms] + + # Combine per-parameter norms to per-sample norms + per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1) + + # Automatic clipping factor (no clamp) + per_sample_clip_factor = self.max_grad_norm / (per_sample_norms + 0.01) + else: + # Empty case: produce an empty tensor on a sensible device + device = ( + self.params[0].device if len(self.params) > 0 else torch.device("cpu") + ) + per_sample_clip_factor = torch.tensor([], device=device) + + for p in self.params: + _check_processed_flag(p.grad_sample) + + grad_sample = self._get_flat_grad_sample(p) + + # cast per-sample gradients to optimizer parameter dtype (e.g., fp32) + grad_sample = grad_sample.to(p.dtype) + + clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device).to( + p.dtype + ) + grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample) + + if p.summed_grad is not None: + p.summed_grad += grad + else: + p.summed_grad = grad + + _mark_as_processed(p.grad_sample) + + +class DPPerLayerAutomaticClippingOptimizer(DPPerLayerOptimizer): + """ + Per-layer variant of automatic clipping. 
+ + For each parameter (layer), we compute the per-sample clip factor using the + corresponding per-layer ``max_grad_norm``: + ``per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01)`` + + This allows each layer to have different clipping behavior based on its own + gradient magnitude distribution, which can improve training stability and + utility compared to global clipping. + + This approach is described in: + *"Automatic Clipping: Differentially Private Deep Learning Made Easier and Stronger"* + https://arxiv.org/pdf/2206.07136 + + Note: + This optimizer is automatically instantiated when using + ``PrivacyEngine.make_private()`` with ``clipping="automatic_per_layer"``. + + See Also: + - :class:`~opacus.optimizers.perlayeroptimizer.DPPerLayerOptimizer`: Base per-layer DP optimizer + - :class:`~DPAutomaticClippingOptimizer`: All-layer variant of automatic clipping + - :class:`~opacus.optimizers.ddpoptimizer_automatic_clipping.DistributedDPPerLayerAutomaticClippingOptimizer`: Distributed version + """ + + def clip_and_accumulate(self): + for p, max_grad_norm in zip(self.params, self.max_grad_norms): + _check_processed_flag(p.grad_sample) + + grad_sample = self._get_flat_grad_sample(p) + # per-sample norms for this parameter (collapse parameter dims) + per_sample_norms = grad_sample.norm( + 2, dim=tuple(range(1, grad_sample.ndim)) + ) + + per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01) + grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample) + + if p.summed_grad is not None: + p.summed_grad += grad + else: + p.summed_grad = grad + + _mark_as_processed(p.grad_sample) + + +__all__ = [ + "DPAutomaticClippingOptimizer", + "DPPerLayerAutomaticClippingOptimizer", +] From 1ce4db0dbe8251ab15ddb9e6d67323248a98ea2a Mon Sep 17 00:00:00 2001 From: Parth Shandilya <24358501+ParthS007@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:55:44 +0100 Subject: [PATCH 2/4] Add tests for automatic clipping optimizers in single-GPU and multi-GPU settings --- opacus/tests/automatic_clipping.py | 161 ++++++++++++ .../tests/multigpu_automatic_clipping_test.py | 238 ++++++++++++++++++ 2 files changed, 399 insertions(+) create mode 100644 opacus/tests/automatic_clipping.py create mode 100644 opacus/tests/multigpu_automatic_clipping_test.py diff --git a/opacus/tests/automatic_clipping.py b/opacus/tests/automatic_clipping.py new file mode 100644 index 00000000..a3b50bce --- /dev/null +++ b/opacus/tests/automatic_clipping.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for automatic clipping optimizers in single-GPU/non-distributed setting. + +For distributed (multi-GPU) tests, see multigpu_automatic_clipping_test.py. 
+""" + +import unittest + +import torch +import torch.nn as nn +from opacus import PrivacyEngine +from opacus.optimizers.optimizer_automatic_clipping import ( + DPAutomaticClippingOptimizer, + DPPerLayerAutomaticClippingOptimizer, +) +from torch.utils.data import DataLoader, TensorDataset + + +class SampleConvNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = x.view(x.shape[0], -1) + x = self.fc1(x) + x = nn.functional.relu(x) + x = self.fc2(x) + return x + + +class AutomaticClippingTest(unittest.TestCase): + def setUp(self): + self.DATA_SIZE = 64 + self.BATCH_SIZE = 16 + self.LR = 0.5 + self.NOISE_MULTIPLIER = 0.0 # No noise for deterministic tests + self.MAX_GRAD_NORM = 1.0 + + def _init_data(self): + data = torch.randn(self.DATA_SIZE, 1, 28, 28) + labels = torch.randint(0, 10, (self.DATA_SIZE,)) + dataset = TensorDataset(data, labels) + return DataLoader(dataset, batch_size=self.BATCH_SIZE) + + def test_automatic_clipping_basic(self): + """Test that automatic clipping mode works end-to-end""" + torch.manual_seed(42) + model = SampleConvNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + criterion = nn.CrossEntropyLoss() + data_loader = self._init_data() + + privacy_engine = PrivacyEngine() + model, optimizer, data_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=data_loader, + noise_multiplier=self.NOISE_MULTIPLIER, + max_grad_norm=self.MAX_GRAD_NORM, + poisson_sampling=False, + clipping="automatic", + ) + + # Run one training step + for x, y in data_loader: + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + break # Just one batch + + # Verify optimizer is correct type + self.assertIsInstance(optimizer, DPAutomaticClippingOptimizer) + + def test_automatic_per_layer_clipping_basic(self): + """Test that automatic per-layer clipping mode works end-to-end""" + torch.manual_seed(42) + model = SampleConvNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + criterion = nn.CrossEntropyLoss() + data_loader = self._init_data() + + # Get number of parameters for per-layer norms + num_params = len(list(model.parameters())) + max_grad_norm = [self.MAX_GRAD_NORM] * num_params + + privacy_engine = PrivacyEngine() + model, optimizer, data_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=data_loader, + noise_multiplier=self.NOISE_MULTIPLIER, + max_grad_norm=max_grad_norm, + poisson_sampling=False, + clipping="automatic_per_layer", + ) + + # Run one training step + for x, y in data_loader: + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + break # Just one batch + + # Verify optimizer is correct type + self.assertIsInstance(optimizer, DPPerLayerAutomaticClippingOptimizer) + + def test_automatic_clipping_convergence(self): + """Test that automatic clipping allows model to learn (loss decreases)""" + torch.manual_seed(42) + model = SampleConvNet() + optimizer = torch.optim.SGD(model.parameters(), lr=self.LR) + criterion = nn.CrossEntropyLoss() + data_loader = self._init_data() + + privacy_engine = PrivacyEngine() + model, optimizer, data_loader = privacy_engine.make_private( + module=model, + optimizer=optimizer, + data_loader=data_loader, + noise_multiplier=self.NOISE_MULTIPLIER, + max_grad_norm=self.MAX_GRAD_NORM, + poisson_sampling=False, + 
clipping="automatic", + ) + + losses = [] + for x, y in data_loader: + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + losses.append(loss.item()) + + # Check that loss decreased from first to last batch + self.assertLess(losses[-1], losses[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/opacus/tests/multigpu_automatic_clipping_test.py b/opacus/tests/multigpu_automatic_clipping_test.py new file mode 100644 index 00000000..c57ba53a --- /dev/null +++ b/opacus/tests/multigpu_automatic_clipping_test.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from opacus import PrivacyEngine +from opacus.optimizers.ddpoptimizer_automatic_clipping import ( + DistributedDPAutomaticClippingOptimizer, + DistributedDPPerLayerAutomaticClippingOptimizer, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.distributed import DistributedSampler + + +def setup(rank, world_size): + if sys.platform == "win32": + raise ValueError("Windows platform is not supported for this test") + else: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12356" + + # initialize the process group + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + torch.distributed.init_process_group( + init_method="env://", + backend="nccl", + ) + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def demo_basic(rank, weight, world_size, dp, clipping): + """ + Test automatic clipping in distributed setting. 
+ + Args: + rank: Process rank + weight: Shared tensor to store model weights for comparison + world_size: Number of processes + dp: Whether to use DP training + clipping: Clipping mode ("automatic" or "automatic_per_layer") + """ + torch.manual_seed(world_size) + batch_size = 2 + setup(rank, world_size) + + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + model.net1.weight.data.zero_() + optimizer = optim.SGD(model.parameters(), lr=1) + + # create dataset + labels = torch.randn(2 * batch_size, 5).to(rank) + data = torch.randn(2 * batch_size, 10) + dataset = TensorDataset(data, labels) + + loss_fn = nn.CrossEntropyLoss() + + max_grad_norm = 1e8 + + ddp_model = DDP(model, device_ids=[rank]) + + privacy_engine = PrivacyEngine() + + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler) + + if dp: + ddp_model, optimizer, data_loader = privacy_engine.make_private( # type: ignore + module=ddp_model, + optimizer=optimizer, + data_loader=data_loader, + noise_multiplier=0, + max_grad_norm=max_grad_norm, + poisson_sampling=False, + clipping=clipping, + grad_sample_mode="hooks", + ) + if clipping == "automatic": + assert isinstance(optimizer, DistributedDPAutomaticClippingOptimizer) + elif clipping == "automatic_per_layer": + assert isinstance(optimizer, DistributedDPPerLayerAutomaticClippingOptimizer) + + for x, y in data_loader: + outputs = ddp_model(x.to(rank)) + loss = loss_fn(outputs, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + break + + weight.copy_(model.net1.weight.data.cpu()) + cleanup() + + +def run_demo(demo_fn, weight, world_size, dp, clipping): + mp.spawn( + demo_fn, + args=(weight, world_size, dp, clipping), + nprocs=world_size, + join=True, + ) + + +class GradientComputationTestAutomaticClipping(unittest.TestCase): + @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs") + def test_automatic_clipping_gradient_correct(self) -> None: + """ + Tests that gradient is the same with DP automatic clipping or without DP + in the distributed setting (with noise_multiplier=0). + """ + n_gpus = torch.cuda.device_count() + self.assertTrue( + n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}." + ) + + weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10) + + run_demo( + demo_basic, + weight_dp, + 2, + dp=True, + clipping="automatic", + ) + run_demo( + demo_basic, + weight_nodp, + 2, + dp=False, + clipping=None, + ) + + self.assertTrue( + torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3) + ) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs") + def test_automatic_per_layer_clipping_gradient_correct(self) -> None: + """ + Tests that gradient is the same with DP automatic per-layer clipping or without DP + in the distributed setting (with noise_multiplier=0). + """ + n_gpus = torch.cuda.device_count() + self.assertTrue( + n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}." 
+ ) + + weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10) + + run_demo( + demo_basic, + weight_dp, + 2, + dp=True, + clipping="automatic_per_layer", + ) + run_demo( + demo_basic, + weight_nodp, + 2, + dp=False, + clipping=None, + ) + + self.assertTrue( + torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3) + ) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs") + def test_automatic_clipping_optimizer_type(self) -> None: + """ + Tests that the correct optimizer type is returned when using automatic clipping. + """ + n_gpus = torch.cuda.device_count() + self.assertTrue( + n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}." + ) + + weight_dp = torch.zeros(10, 10) + + # Test automatic clipping + run_demo( + demo_basic, + weight_dp, + 2, + dp=True, + clipping="automatic", + ) + + # Test automatic per-layer clipping + run_demo( + demo_basic, + weight_dp, + 2, + dp=True, + clipping="automatic_per_layer", + ) + + +if __name__ == "__main__": + unittest.main() From a4365ae2c81487663a778f9ccaa8f6ef86d21334 Mon Sep 17 00:00:00 2001 From: Parth Shandilya <24358501+ParthS007@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:55:51 +0100 Subject: [PATCH 3/4] Add documentation for DistributedDPAutomaticClippingOptimizer and DPAutomaticClippingOptimizer --- .../source/optim/dp_ddp_optimizer_automatic_clipping.rst | 8 ++++++++ .../source/optim/dp_optimizer_automatic_clipping.rst | 8 ++++++++ website/sphinx/source/optim/optimizers.rst | 2 ++ 3 files changed, 18 insertions(+) create mode 100644 website/sphinx/source/optim/dp_ddp_optimizer_automatic_clipping.rst create mode 100644 website/sphinx/source/optim/dp_optimizer_automatic_clipping.rst diff --git a/website/sphinx/source/optim/dp_ddp_optimizer_automatic_clipping.rst b/website/sphinx/source/optim/dp_ddp_optimizer_automatic_clipping.rst new file mode 100644 index 00000000..7a90a994 --- /dev/null +++ b/website/sphinx/source/optim/dp_ddp_optimizer_automatic_clipping.rst @@ -0,0 +1,8 @@ +DistributedDPAutomaticClippingOptimizer +========================================= + +.. automodule:: opacus.optimizers.ddpoptimizer_automatic_clipping + :members: + :undoc-members: + :show-inheritance: + diff --git a/website/sphinx/source/optim/dp_optimizer_automatic_clipping.rst b/website/sphinx/source/optim/dp_optimizer_automatic_clipping.rst new file mode 100644 index 00000000..2749b5d7 --- /dev/null +++ b/website/sphinx/source/optim/dp_optimizer_automatic_clipping.rst @@ -0,0 +1,8 @@ +DPAutomaticClippingOptimizer +============================== + +.. 
automodule:: opacus.optimizers.optimizer_automatic_clipping + :members: + :undoc-members: + :show-inheritance: + diff --git a/website/sphinx/source/optim/optimizers.rst b/website/sphinx/source/optim/optimizers.rst index 7bda1032..33780343 100644 --- a/website/sphinx/source/optim/optimizers.rst +++ b/website/sphinx/source/optim/optimizers.rst @@ -4,7 +4,9 @@ Optimizers dp_optimizer dp_optimizer_fast_gradient_clipping + dp_optimizer_automatic_clipping dp_per_layer_optimizer dp_ddp_optimizer dp_ddp_optimizer_fast_gradient_clipping + dp_ddp_optimizer_automatic_clipping dp_ddp_per_layer_optimizer From 89b625504765bac562f7d0cfbcbf24f897aea2c1 Mon Sep 17 00:00:00 2001 From: Parth Shandilya <24358501+ParthS007@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:06:50 +0100 Subject: [PATCH 4/4] Lint and format changes --- opacus/optimizers/__init__.py | 12 ++++++------ opacus/optimizers/ddpoptimizer_automatic_clipping.py | 3 +-- opacus/optimizers/optimizer_automatic_clipping.py | 12 ++++++++---- opacus/tests/multigpu_automatic_clipping_test.py | 12 +++++------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py index 707fa06b..b2d0c819 100644 --- a/opacus/optimizers/__init__.py +++ b/opacus/optimizers/__init__.py @@ -15,21 +15,21 @@ from .adaclipoptimizer import AdaClipDPOptimizer from .ddp_perlayeroptimizer import SimpleDistributedPerLayerOptimizer from .ddpoptimizer import DistributedDPOptimizer +from .ddpoptimizer_automatic_clipping import ( + DistributedDPAutomaticClippingOptimizer, + DistributedDPPerLayerAutomaticClippingOptimizer, +) from .ddpoptimizer_fast_gradient_clipping import ( DistributedDPOptimizerFastGradientClipping, ) from .fsdpoptimizer_fast_gradient_clipping import FSDPOptimizerFastGradientClipping from .optimizer import DPOptimizer -from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping -from .perlayeroptimizer import DPPerLayerOptimizer from .optimizer_automatic_clipping import ( DPAutomaticClippingOptimizer, DPPerLayerAutomaticClippingOptimizer, ) -from .ddpoptimizer_automatic_clipping import ( - DistributedDPAutomaticClippingOptimizer, - DistributedDPPerLayerAutomaticClippingOptimizer, -) +from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping +from .perlayeroptimizer import DPPerLayerOptimizer __all__ = [ diff --git a/opacus/optimizers/ddpoptimizer_automatic_clipping.py b/opacus/optimizers/ddpoptimizer_automatic_clipping.py index 6548fc12..2fcc80f5 100644 --- a/opacus/optimizers/ddpoptimizer_automatic_clipping.py +++ b/opacus/optimizers/ddpoptimizer_automatic_clipping.py @@ -17,12 +17,11 @@ from typing import Callable, List, Optional import torch -from torch.optim import Optimizer - from opacus.optimizers.optimizer_automatic_clipping import ( DPAutomaticClippingOptimizer, DPPerLayerAutomaticClippingOptimizer, ) +from torch.optim import Optimizer class DistributedDPAutomaticClippingOptimizer(DPAutomaticClippingOptimizer): diff --git a/opacus/optimizers/optimizer_automatic_clipping.py b/opacus/optimizers/optimizer_automatic_clipping.py index 5b0c4e8a..2bdbe24f 100644 --- a/opacus/optimizers/optimizer_automatic_clipping.py +++ b/opacus/optimizers/optimizer_automatic_clipping.py @@ -18,9 +18,12 @@ from typing import List import torch -from opacus.optimizers.optimizer import _check_processed_flag, _mark_as_processed +from opacus.optimizers.optimizer import ( + DPOptimizer, + _check_processed_flag, + _mark_as_processed, +) from 
opacus.optimizers.perlayeroptimizer import DPPerLayerOptimizer -from opacus.optimizers.optimizer import DPOptimizer class DPAutomaticClippingOptimizer(DPOptimizer): @@ -102,8 +105,9 @@ class DPPerLayerAutomaticClippingOptimizer(DPPerLayerOptimizer): Per-layer variant of automatic clipping. For each parameter (layer), we compute the per-sample clip factor using the - corresponding per-layer ``max_grad_norm``: - ``per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01)`` + corresponding per-layer ``max_grad_norm``:: + + per_sample_clip_factor = max_grad_norm / (per_sample_norms + 0.01) This allows each layer to have different clipping behavior based on its own gradient magnitude distribution, which can improve training stability and diff --git a/opacus/tests/multigpu_automatic_clipping_test.py b/opacus/tests/multigpu_automatic_clipping_test.py index c57ba53a..f1b49b77 100644 --- a/opacus/tests/multigpu_automatic_clipping_test.py +++ b/opacus/tests/multigpu_automatic_clipping_test.py @@ -115,7 +115,9 @@ def demo_basic(rank, weight, world_size, dp, clipping): if clipping == "automatic": assert isinstance(optimizer, DistributedDPAutomaticClippingOptimizer) elif clipping == "automatic_per_layer": - assert isinstance(optimizer, DistributedDPPerLayerAutomaticClippingOptimizer) + assert isinstance( + optimizer, DistributedDPPerLayerAutomaticClippingOptimizer + ) for x, y in data_loader: outputs = ddp_model(x.to(rank)) @@ -167,9 +169,7 @@ def test_automatic_clipping_gradient_correct(self) -> None: clipping=None, ) - self.assertTrue( - torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3) - ) + self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)) @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs") def test_automatic_per_layer_clipping_gradient_correct(self) -> None: @@ -199,9 +199,7 @@ def test_automatic_per_layer_clipping_gradient_correct(self) -> None: clipping=None, ) - self.assertTrue( - torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3) - ) + self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)) @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs") def test_automatic_clipping_optimizer_type(self) -> None: