Skip to content

Commit 0e76bc5

Browse files
committed
Fixes
Signed-off-by: Kevin Eykholt <[email protected]>
1 parent bb58ab8 commit 0e76bc5

10 files changed

+293
-51
lines changed

art/attacks/poisoning/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from art.attacks.poisoning.adversarial_embedding_attack import PoisoningAttackAdversarialEmbedding
88
from art.attacks.poisoning.clean_label_backdoor_attack import PoisoningAttackCleanLabelBackdoor
99
from art.attacks.poisoning.bullseye_polytope_attack import BullseyePolytopeAttackPyTorch
10+
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor import HiddenTriggerBackdoor
11+
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_pytorch import HiddenTriggerBackdoorPyTorch
12+
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_keras import HiddenTriggerBackdoorKeras
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor import HiddenTriggerBackdoor
2-
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_pytorch import HiddenTriggerBackdoorPyTorch
3-
from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_keras import HiddenTriggerBackdoorKeras

art/attacks/poisoning/hidden_trigger_backdoor/hidden_trigger_backdoor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _check_params(self) -> None:
195195
raise ValueError("Learning rate must be strictly positive")
196196

197197
if not isinstance(self.backdoor, PoisoningAttackBackdoor):
198-
raise ValueError("Backdoor must be of type PoisoningAttackBackdoor")
198+
raise TypeError("Backdoor must be of type PoisoningAttackBackdoor")
199199

200200
if self.eps < 0:
201201
raise ValueError("The perturbation size `eps` has to be non-negative.")

art/attacks/poisoning/hidden_trigger_backdoor/hidden_trigger_backdoor_keras.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@
2525
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
2626

2727
import numpy as np
28+
import tensorflow as tf
2829
from scipy.spatial import distance
2930
from tqdm.auto import trange
3031

3132
from art.attacks.attack import PoisoningAttackWhiteBox
3233
from art.attacks.poisoning.backdoor_attack import PoisoningAttackBackdoor
3334
from art.estimators import BaseEstimator, NeuralNetworkMixin
3435
from art.estimators.classification.classifier import ClassifierMixin
35-
from art.estimators.classification.keras import KerasClassifier
3636

3737
if TYPE_CHECKING:
38-
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
38+
from art.estimators.classification.keras import KerasClassifier
3939

4040
logger = logging.getLogger(__name__)
4141

4242

43-
class LossMeter():
43+
class LossMeter:
4444
"""
4545
Computes and stores the average and current loss value
4646
"""
@@ -82,7 +82,7 @@ class HiddenTriggerBackdoorKeras(PoisoningAttackWhiteBox):
8282

8383
attack_params = PoisoningAttackWhiteBox.attack_params + ["target"]
8484

85-
_estimator_requirements = (BaseEstimator, NeuralNetworkMixin, ClassifierMixin, KerasClassifier)
85+
_estimator_requirements = (BaseEstimator, NeuralNetworkMixin, ClassifierMixin)
8686

8787
def __init__(
8888
self,
@@ -202,10 +202,10 @@ def poison( # pylint: disable=W0221
202202
"There must be at least as many images with the source label as the target. Maybe try reducing poison_percent or providing fewer target indices"
203203
)
204204

205-
logger.info("Number of poison inputs: %d", num_poison_img)
206-
logger.info("Number of trigger inputs: %d", num_trigger_img)
205+
logger.info("Number of poison inputs: %d", num_poison)
206+
logger.info("Number of trigger inputs: %d", num_trigger)
207207

208-
batches = int(np.ceil(num_poison_img / float(self.batch_size)))
208+
batches = int(np.ceil(num_poison / float(self.batch_size)))
209209

210210
losses = LossMeter()
211211
final_poison = np.copy(data[poison_indices])
@@ -236,22 +236,22 @@ def poison( # pylint: disable=W0221
236236
trigger_samples, self.feature_layer, 1, framework=True
237237
)
238238

239-
attack_loss = tf.norm(poison_features-trigger_features, ord=2)
239+
attack_loss = tf.norm(poison_features - trigger_features, ord=2)
240240

241241
trigger_features = self.estimator.get_activations(trigger_samples, self.feature_layer, 1)
242242

243243
for i in range(self.max_iter):
244244
learning_rate = self.learning_rate * (self.decay_coeff ** (i // self.decay_iter))
245245

246-
poison_features = self.estimator.get_activations(poison_samples + poison, self.feature_layer, 1)
246+
poison_features = self.estimator.get_activations(poison_samples, self.feature_layer, 1)
247247

248248
# Compute distance between features and match samples
249249
# We are swapping the samples and the features unlike in the original implementation because
250250
# we are computing the loss gradient using ART, which needs the inputs rather than the features
251251
trigger_samples_copy = np.copy(trigger_samples)
252252
trigger_features_copy = np.copy(trigger_features) # Assuming this is numpy array
253253
dist = distance.cdist(trigger_features, poison_features)
254-
for _ in range(len(source_features)):
254+
for _ in range(len(trigger_features)):
255255
min_index = np.squeeze((dist == np.min(dist)).nonzero())
256256
trigger_samples[min_index[1]] = trigger_samples_copy[min_index[0]]
257257
trigger_features[min_index[1]] = trigger_features_copy[min_index[0]]
@@ -260,7 +260,7 @@ def poison( # pylint: disable=W0221
260260
loss = np.linalg.norm(trigger_features - poison_features)
261261
print(loss)
262262
losses.update(loss, len(trigger_samples))
263-
263+
264264
(attack_grad,) = self.estimator.custom_loss_gradient(
265265
attack_loss,
266266
[poison_placeholder, trigger_placeholder],

art/attacks/poisoning/hidden_trigger_backdoor/hidden_trigger_backdoor_pytorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,21 @@
2424
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
2525

2626
import numpy as np
27-
import torch
2827
from tqdm.auto import trange
2928

3029
from art.attacks.attack import PoisoningAttackWhiteBox
3130
from art.attacks.poisoning.backdoor_attack import PoisoningAttackBackdoor
3231
from art.estimators import BaseEstimator, NeuralNetworkMixin
3332
from art.estimators.classification.classifier import ClassifierMixin
34-
from art.estimators.classification.pytorch import PyTorchClassifier
3533

3634
if TYPE_CHECKING:
37-
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
35+
# pylint: disable=C0412
36+
from art.estimators.classification.pytorch import PyTorchClassifier
3837

3938
logger = logging.getLogger(__name__)
4039

4140

42-
class LossMeter():
41+
class LossMeter:
4342
"""
4443
Computes and stores the average and current loss value
4544
"""
@@ -120,7 +119,7 @@ def __init__(
120119
:param batch_size: The number of samples to draw per batch.
121120
:param poison_percent: The percentage of the data to poison. This is ignored if indices are provided
122121
for the source parameter
123-
:param is_index: If true, the source and target params are assumed to represent indices rather than a class label.
122+
:param is_index: If true, the source and target params are assumed to represent indices rather than a class label.
124123
poison_percent is ignored if true
125124
:param verbose: Show progress bars.
126125
"""
@@ -151,6 +150,7 @@ def poison( # pylint: disable=W0221
151150
:param y: The labels of the provided samples. If none, we will use the classifier to label the data.
152151
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
153152
"""
153+
import torch # lgtm [py/repeated-import]
154154

155155
data = np.copy(x)
156156
estimated_labels = self.classifier.predict(data) if y is None else np.copy(y)

art/estimators/classification/pytorch.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,14 @@ def _predict_framework(self, x: "torch.Tensor") -> "torch.Tensor":
351351

352352
return output
353353

354-
def fit(
355-
self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, train=True, **kwargs
354+
def fit( # pylint: disable=W0221
355+
self,
356+
x: np.ndarray,
357+
y: np.ndarray,
358+
batch_size: int = 128,
359+
nb_epochs: int = 10,
360+
training_mode: bool = True,
361+
**kwargs
356362
) -> None:
357363
"""
358364
Fit the classifier on the training set `(x, y)`.
@@ -362,17 +368,14 @@ def fit(
362368
shape (nb_samples,).
363369
:param batch_size: Size of batches.
364370
:param nb_epochs: Number of epochs to use for training.
365-
:param train: Boolean indiciating if the model should be set to training model
371+
:param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
366372
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
367373
and providing it takes no effect.
368374
"""
369375
import torch # lgtm [py/repeated-import]
370376

371-
# Put the model in the training mode
372-
if train:
373-
self._model.train()
374-
else:
375-
self._model.eval()
377+
# Set model mode
378+
self._model.train(mode=training_mode)
376379

377380
if self._optimizer is None: # pragma: no cover
378381
raise ValueError("An optimizer is needed to train the model, but none for provided.")

notebooks/art-for-tensorflow-v2-keras.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,9 @@
388388
],
389389
"metadata": {
390390
"kernelspec": {
391-
"display_name": "py37_tf220",
391+
"display_name": "Python 3 (ipykernel)",
392392
"language": "python",
393-
"name": "py37_tf220"
393+
"name": "python3"
394394
},
395395
"language_info": {
396396
"codemirror_mode": {
@@ -402,7 +402,7 @@
402402
"name": "python",
403403
"nbconvert_exporter": "python",
404404
"pygments_lexer": "ipython3",
405-
"version": "3.7.6"
405+
"version": "3.9.7"
406406
}
407407
},
408408
"nbformat": 4,

notebooks/poisoning_attack_backdoor_image.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@
277277
],
278278
"metadata": {
279279
"kernelspec": {
280-
"display_name": "Python 3",
280+
"display_name": "Python 3 (ipykernel)",
281281
"language": "python",
282282
"name": "python3"
283283
},
@@ -291,7 +291,7 @@
291291
"name": "python",
292292
"nbconvert_exporter": "python",
293293
"pygments_lexer": "ipython3",
294-
"version": "3.6.10"
294+
"version": "3.9.7"
295295
}
296296
},
297297
"nbformat": 4,

0 commit comments

Comments
 (0)