9 changes: 8 additions & 1 deletion opacus/__init__.py
@@ -14,13 +14,20 @@
# limitations under the License.

from . import utils
from .grad_sample import (
    GradSampleController,
    GradSampleModule,
    GradSampleModuleFastGradientClipping,
)
from .privacy_engine import PrivacyEngine
from .privacy_engine_gsc import PrivacyEngineGradSampleController
from .version import __version__


__all__ = [
    "PrivacyEngine",
    "PrivacyEngineGradSampleController",
    "GradSampleController",
    "GradSampleModule",
    "GradSampleModuleFastGradientClipping",
    "utils",
82 changes: 51 additions & 31 deletions opacus/grad_sample/README.md
@@ -3,25 +3,41 @@
Computing per sample gradients is an integral part of the Opacus framework. We strive to provide out-of-the-box support for
a wide range of models, while keeping computations efficient.
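
Concretely, the per-sample gradient of a parameter is the gradient of the loss computed on each sample in isolation. The naive sketch below makes this explicit with one backward pass per sample; Opacus produces the same tensors in a single backward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randn(8, 2)

per_sample_grads = []
for i in range(len(x)):  # one backward pass per sample: correct but slow
    model.zero_grad()
    F.mse_loss(model(x[i : i + 1]), y[i : i + 1]).backward()
    per_sample_grads.append(model.weight.grad.clone())

# Shape [batch_size, *weight.shape], i.e. what Opacus stores in p.grad_sample
print(torch.stack(per_sample_grads).shape)  # torch.Size([8, 2, 4])
```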

We currently provide three independent approaches for computing per sample gradients:

1. **Hooks-based `GradSampleModule`** (stable, wraps the model)
2. **`GradSampleController`** (stable, no model wrapping; recommended for transformers)
3. **`GradSampleModuleExpandedWeights`** (beta, based on PyTorch 1.12+ functionality)

Each implementation comes with its own set of limitations and benefits.

**TL;DR:**
- Use `GradSampleModule` (`grad_sample_mode="hooks"`) for stable implementation with standard models
- Use `GradSampleController` via `PrivacyEngineGradSampleController` for transformer models and when you need direct model access without wrapping
- Use `GradSampleModuleExpandedWeights` (`grad_sample_mode="ew"`) if you want to experiment with better performance
- Use `grad_sample_mode="functorch"` if your model has unsupported layers
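
To make the options above concrete, the mode is selected via the `grad_sample_mode` argument of `PrivacyEngine.make_private()`. A minimal sketch with a toy model:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngine

# Toy setup; any model built from supported layers works the same way
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 4), torch.randn(64, 2)), batch_size=8
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="hooks",  # or "ew" / "functorch"
)
```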

Please report any strange errors or unexpected behaviour to us!

## GradSampleController approach (No Model Wrapping)
- Controller class: ``opacus.grad_sample.GradSampleController``
- Privacy Engine: ``opacus.privacy_engine_gsc.PrivacyEngineGradSampleController``
- Usage: Use `PrivacyEngineGradSampleController` instead of `PrivacyEngine`

**Recommended for transformer models and when model wrapping causes issues.**

Computes per-sample gradients by attaching hooks directly to model parameters without wrapping the model in a
`GradSampleModule`. This approach:

- ✅ Preserves model type (e.g., `isinstance(model, BertModel)` remains `True`)
- ✅ No `_module.` prefix in state_dict
- ✅ Direct access to model attributes (no attribute forwarding needed)
- ✅ Better compatibility with HuggingFace transformers and models with custom `__getattr__`
- ✅ Same grad sampler methods as `GradSampleModule`

See [CONTROLLER_BASED_PRIVACY_ENGINE.md](../../docs/CONTROLLER_BASED_PRIVACY_ENGINE.md) for detailed documentation.
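
A minimal sketch of the intended usage, assuming the controller-based engine keeps the `make_private()` signature of `PrivacyEngine` (see the linked doc for the authoritative API):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngineGradSampleController

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 4), torch.randn(64, 2)), batch_size=8
)

# Assumed to mirror PrivacyEngine.make_private(); check the linked doc
privacy_engine = PrivacyEngineGradSampleController()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

# The model itself is returned un-wrapped:
print(isinstance(model, nn.Linear))     # True: type preserved
print(list(model.state_dict().keys()))  # ['weight', 'bias']: no '_module.' prefix
```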

## Hooks-based approach (Model Wrapping)
- Model wrapping class: ``opacus.grad_sample.grad_sample_module.GradSampleModule``
- Keyword argument for ``PrivacyEngine.make_private()``: `grad_sample_mode="hooks"`
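
A short sketch of the wrapping behaviour: after one backward pass, every trainable parameter carries a `grad_sample` tensor, and the wrapper is visible in the state dict keys.

```python
import torch
import torch.nn as nn

from opacus.grad_sample import GradSampleModule

model = nn.Linear(4, 2)
gsm = GradSampleModule(model)  # model is now wrapped

out = gsm(torch.randn(8, 4))   # batch of 8 samples
out.sum().backward()

# Each trainable parameter now holds a per-sample gradient:
print(model.weight.grad_sample.shape)  # torch.Size([8, 2, 4])
# The wrapping is visible in checkpoint keys:
print(list(gsm.state_dict().keys()))   # ['_module.weight', '_module.bias']
print(isinstance(gsm, nn.Linear))      # False: type is not preserved
```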

@@ -62,23 +78,27 @@ is roughly the same.
Please note that these are known limitations and we plan to improve Expanded Weights and bridge the gap in feature completeness


| Feature                      | GradSampleModule (Hooks)        | GradSampleController | Expanded Weights     | Functorch            |
|:----------------------------:|:-------------------------------:|:--------------------:|:--------------------:|:--------------------:|
| Required PyTorch version     | 1.8+                            | 1.8+                 | 1.13+                | 1.12 (to be updated) |
| Development status           | Underlying mechanism deprecated | ✅ Stable            | Beta                 | Beta                 |
| Model wrapping               | Wraps model                     | ✅ No wrapping       | Wraps model          | Wraps model          |
| Runtime Performance†         | baseline                        | baseline             | ✅ ~25% faster       | 🟨 0-50% slower      |
| Transformer compatibility    | 🟨 May have issues              | ✅ Excellent         | 🟨 May have issues   | 🟨 May have issues   |
| State dict compatibility     | 🟨 `_module.` prefix            | ✅ Clean keys        | 🟨 `_module.` prefix | 🟨 `_module.` prefix |
| Type preservation            | ❌ Model wrapped                | ✅ Model unchanged   | ❌ Model wrapped     | ❌ Model wrapped     |
| Any DP-allowed†† layers      | Not supported                   | Not supported        | Not supported        | ✅ Supported         |
| Most popular nn.* layers     | ✅ Supported                    | ✅ Supported         | ✅ Supported         | ✅ Supported         |
| torchscripted models         | Not supported                   | Not supported        | ✅ Supported         | Not supported        |
| Client-provided grad sampler | ✅ Supported                    | ✅ Supported         | Not supported        | ✅ Not needed        |
| `batch_first=False`          | ✅ Supported                    | ✅ Supported         | Not supported        | ✅ Supported         |
| Recurrent networks           | ✅ Supported                    | ✅ Supported         | Not supported        | ✅ Supported         |
| Padding `same` in Conv       | ✅ Supported                    | ✅ Supported         | Not supported        | ✅ Supported         |
| Empty poisson batches        | ✅ Supported                    | ✅ Supported         | Not supported        | Not supported        |

† Note that performance differences are unstable and can vary a lot depending on the exact model and batch size.
The numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
Also note that performance differences are only observed in GPU training; CPU performance seems to be almost identical
for all approaches.

†† Layers that produce joint computations on batch samples (e.g. BatchNorm) are not allowed under any approach
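
For the approaches that support a client-provided grad sampler, a custom per-sample gradient rule can be registered with `opacus.grad_sample.register_grad_sampler`. A sketch for a hypothetical `ScaledLinear` layer; note the exact sampler signature varies slightly across Opacus versions (recent versions pass activations as a list):

```python
from typing import Dict, List

import torch
import torch.nn as nn

from opacus.grad_sample import register_grad_sampler


class ScaledLinear(nn.Linear):
    """Hypothetical custom layer, used only for illustration."""

    def forward(self, x):
        return 0.5 * super().forward(x)


@register_grad_sampler(ScaledLinear)
def compute_scaled_linear_grad_sample(
    layer: ScaledLinear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    # backprops: dL/dy per sample; y = 0.5 * (x @ W.T + b), hence the 0.5 factor
    a = activations[0]
    ret = {}
    if layer.weight.requires_grad:
        ret[layer.weight] = 0.5 * torch.einsum("n...i,n...j->nij", backprops, a)
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = 0.5 * torch.einsum("n...k->nk", backprops)
    return ret
```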
2 changes: 2 additions & 0 deletions opacus/grad_sample/__init__.py
@@ -18,6 +18,7 @@
from .dp_rnn import compute_rnn_linear_grad_sample # noqa
from .embedding import compute_embedding_grad_sample # noqa
from .embedding_norm_sample import compute_embedding_norm_sample # noqa
from .grad_sample_controller import GradSampleController # noqa
from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
from .grad_sample_module_fast_gradient_clipping import (  # noqa
    GradSampleModuleFastGradientClipping,
@@ -45,6 +46,7 @@


__all__ = [
    "GradSampleController",
    "GradSampleModule",
    "GradSampleModuleFastGradientClipping",
    "GradSampleModuleFastGradientClippingFSDP",