Commit cb4deda

Merge pull request #2218 from Ethos-lab/randomized-smoothing-additions
Randomized Smoothing Variations Implementation
2 parents 82f8fa2 + 63612d2 commit cb4deda

File tree

18 files changed: +2002 -253 lines changed


art/estimators/certification/randomized_smoothing/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -4,5 +4,10 @@
 from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin

 from art.estimators.certification.randomized_smoothing.numpy import NumpyRandomizedSmoothing
-from art.estimators.certification.randomized_smoothing.tensorflow import TensorFlowV2RandomizedSmoothing
 from art.estimators.certification.randomized_smoothing.pytorch import PyTorchRandomizedSmoothing
+from art.estimators.certification.randomized_smoothing.tensorflow import TensorFlowV2RandomizedSmoothing
+from art.estimators.certification.randomized_smoothing.smooth_mix.pytorch import PyTorchSmoothMix
+from art.estimators.certification.randomized_smoothing.macer.pytorch import PyTorchMACER
+from art.estimators.certification.randomized_smoothing.macer.tensorflow import TensorFlowV2MACER
+from art.estimators.certification.randomized_smoothing.smooth_adv.pytorch import PyTorchSmoothAdv
+from art.estimators.certification.randomized_smoothing.smooth_adv.tensorflow import TensorFlowV2SmoothAdv

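Note: with the re-exports above, the new estimators can be imported directly from the package namespace. A minimal sketch, assuming an ART installation that contains this commit:

from art.estimators.certification.randomized_smoothing import (
    PyTorchMACER,
    TensorFlowV2MACER,
    PyTorchSmoothAdv,
    TensorFlowV2SmoothAdv,
    PyTorchSmoothMix,
)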
art/estimators/certification/randomized_smoothing/macer/__init__.py

Whitespace-only changes.
art/estimators/certification/randomized_smoothing/macer/pytorch.py

Lines changed: 238 additions & 0 deletions

@@ -0,0 +1,238 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements MACER applied to classifier predictions.

| Paper link: https://arxiv.org/abs/2001.02378
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

from tqdm.auto import trange
import numpy as np

from art.estimators.certification.randomized_smoothing.pytorch import PyTorchRandomizedSmoothing
from art.utils import check_and_transform_label_format

if TYPE_CHECKING:
    # pylint: disable=C0412
    import torch
    from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
    from art.defences.preprocessor import Preprocessor
    from art.defences.postprocessor import Postprocessor

logger = logging.getLogger(__name__)


class PyTorchMACER(PyTorchRandomizedSmoothing):
    """
    Implementation of MACER training, as introduced in Zhai et al. (2020).

    | Paper link: https://arxiv.org/abs/2001.02378
    """

    estimator_params = PyTorchRandomizedSmoothing.estimator_params + [
        "beta",
        "gamma",
        "lmbda",
        "gauss_num",
    ]

    def __init__(
        self,
        model: "torch.nn.Module",
        loss: "torch.nn.modules.loss._Loss",
        input_shape: Tuple[int, ...],
        nb_classes: int,
        optimizer: Optional["torch.optim.Optimizer"] = None,
        channels_first: bool = True,
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
        postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
        preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
        device_type: str = "gpu",
        sample_size: int = 32,
        scale: float = 0.1,
        alpha: float = 0.001,
        beta: float = 16.0,
        gamma: float = 8.0,
        lmbda: float = 12.0,
        gaussian_samples: int = 16,
        verbose: bool = False,
    ) -> None:
        """
        Create a MACER classifier.

        :param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
               output should be preferred where possible to ensure attack efficiency.
        :param loss: The loss function for which to compute gradients for training. The target label must be raw
               categorical, i.e. not converted to one-hot encoding.
        :param input_shape: The shape of one input instance.
        :param nb_classes: The number of classes of the model.
        :param optimizer: The optimizer used to train the classifier.
        :param channels_first: Set channels first or last.
        :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
               maximum values allowed for features. If floats are provided, these will be used as the range of all
               features. If arrays are provided, each value will be considered the bound for a feature, thus
               the shape of clip values needs to match the total number of features.
        :param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
        :param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
        :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
               used for data preprocessing. The first value will be subtracted from the input. The input will then
               be divided by the second one.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        :param sample_size: Number of samples for smoothing.
        :param scale: Standard deviation of Gaussian noise added.
        :param alpha: The failure probability of smoothing.
        :param beta: The inverse temperature.
        :param gamma: The hinge factor.
        :param lmbda: The trade-off factor.
        :param gaussian_samples: The number of Gaussian samples per input.
        :param verbose: Show progress bars.
        """
        super().__init__(
            model=model,
            loss=loss,
            input_shape=input_shape,
            nb_classes=nb_classes,
            optimizer=optimizer,
            channels_first=channels_first,
            clip_values=clip_values,
            preprocessing_defences=preprocessing_defences,
            postprocessing_defences=postprocessing_defences,
            preprocessing=preprocessing,
            device_type=device_type,
            sample_size=sample_size,
            scale=scale,
            alpha=alpha,
            verbose=verbose,
        )
        self.beta = beta
        self.gamma = gamma
        self.lmbda = lmbda
        self.gaussian_samples = gaussian_samples

    def fit(  # pylint: disable=W0221
        self,
        x: np.ndarray,
        y: np.ndarray,
        batch_size: int = 128,
        nb_epochs: int = 10,
        training_mode: bool = True,
        drop_last: bool = False,
        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
        **kwargs,
    ) -> None:
        """
        Fit the classifier on the training set `(x, y)`.

        :param x: Training data.
        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
               shape (nb_samples,).
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
               the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
               the last batch will be smaller. (default: ``False``)
        :param scheduler: Learning rate scheduler to run at the start of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
               PyTorch and providing it has no effect.
        """
        import torch
        import torch.nn.functional as F
        from torch.utils.data import TensorDataset, DataLoader

        # Set model mode
        self._model.train(mode=training_mode)

        if self._optimizer is None:  # pragma: no cover
            raise ValueError("An optimizer is needed to train the model, but none is provided.")

        y = check_and_transform_label_format(y, nb_classes=self.nb_classes)

        # Apply preprocessing
        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)

        # Check label shape
        y_preprocessed = self.reduce_labels(y_preprocessed)

        # Create dataloader
        x_tensor = torch.from_numpy(x_preprocessed)
        y_tensor = torch.from_numpy(y_preprocessed)
        dataset = TensorDataset(x_tensor, y_tensor)
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

        m = torch.distributions.normal.Normal(
            torch.tensor([0.0], device=self.device), torch.tensor([1.0], device=self.device)
        )

        # Start training
        for _ in trange(nb_epochs, disable=not self.verbose):
            for x_batch, y_batch in dataloader:
                # Move inputs to GPU
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                input_size = len(x_batch)

                # Tile samples for Gaussian augmentation
                new_shape = [input_size * self.gaussian_samples]
                new_shape.extend(x_batch[0].shape)
                x_batch = x_batch.repeat((1, self.gaussian_samples, 1, 1)).view(new_shape)

                # Add random noise for randomized smoothing
                noise = torch.randn_like(x_batch, device=self.device) * self.scale
                noisy_inputs = x_batch + noise

                # Get model outputs
                outputs = self.model(noisy_inputs)
                outputs = outputs.reshape((input_size, self.gaussian_samples, self.nb_classes))

                # Classification loss
                outputs_softmax = F.softmax(outputs, dim=2).mean(dim=1)
                outputs_log_softmax = torch.log(outputs_softmax + 1e-10)
                classification_loss = F.nll_loss(outputs_log_softmax, y_batch, reduction="sum")

                # Robustness loss
                beta_outputs = outputs * self.beta
                beta_outputs_softmax = F.softmax(beta_outputs, dim=2).mean(dim=1)
                top2_score, top2_idx = torch.topk(beta_outputs_softmax, 2)
                indices_correct = top2_idx[:, 0] == y_batch
                out0, out1 = top2_score[indices_correct, 0], top2_score[indices_correct, 1]
                robustness_loss = m.icdf(out1) - m.icdf(out0)
                indices = (
                    ~torch.isnan(robustness_loss)
                    & ~torch.isinf(robustness_loss)
                    & (torch.abs(robustness_loss) <= self.gamma)
                )
                out0, out1 = out0[indices], out1[indices]
                robustness_loss = m.icdf(out1) - m.icdf(out0) + self.gamma
                robustness_loss = torch.sum(robustness_loss) * self.scale / 2

                # Final objective function
                loss = classification_loss + self.lmbda * robustness_loss
                loss /= input_size
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()

            if scheduler is not None:
                scheduler.step()

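For context on the robustness term computed in the training loop above: MACER builds a hinge loss on the l2 certified radius of a Gaussian-smoothed classifier. The sketch below uses notation assumed here rather than taken from the diff, with p_A and p_B the soft-smoothed probabilities of the true class and the runner-up class, and \Phi^{-1} the inverse standard normal CDF:

R(x) = \frac{\sigma}{2} \left( \Phi^{-1}(p_A) - \Phi^{-1}(p_B) \right)

L_{\mathrm{robust}} = \frac{\sigma}{2} \max\left( 0,\; \gamma - \left( \Phi^{-1}(p_A) - \Phi^{-1}(p_B) \right) \right), \qquad L = L_{\mathrm{cls}} + \lambda \, L_{\mathrm{robust}}

In the code, scale plays the role of \sigma and lmbda the role of \lambda; the hinge is realized by keeping only the correctly classified samples for which |\Phi^{-1}(p_A) - \Phi^{-1}(p_B)| does not exceed gamma, adding gamma, and summing.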
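For readers who want to try the new estimator, here is a minimal usage sketch. The toy convolutional model, the MNIST-like input shape, the random stand-in data, and all hyperparameter values are illustrative assumptions; only the constructor and fit/predict signatures come from the diff above.

import numpy as np
import torch

from art.estimators.certification.randomized_smoothing import PyTorchMACER

# Hypothetical toy model, chosen only for illustration (not part of the commit).
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(16 * 28 * 28, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

classifier = PyTorchMACER(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    input_shape=(1, 28, 28),
    nb_classes=10,
    optimizer=optimizer,
    clip_values=(0.0, 1.0),
    scale=0.25,           # sigma of the Gaussian smoothing noise
    gaussian_samples=16,  # noisy copies per training example
    verbose=True,
)

# Random stand-in data with index labels; real training would use an actual dataset.
x_train = np.random.rand(64, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=64)

classifier.fit(x_train, y_train, batch_size=32, nb_epochs=1)
predictions = classifier.predict(x_train[:8])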