Skip to content

Commit 3f438f2

Browse files
committed
Changes related to sleeper agent
Signed-off-by: Shriti Priya <[email protected]>
1 parent 7f625e5 commit 3f438f2

File tree

3 files changed

+192
-1404
lines changed

3 files changed

+192
-1404
lines changed

art/attacks/poisoning/sleeper_agent_attack.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,46 @@
4141

4242

4343
class SleeperAgentAttack(GradientMatchingAttack):
44-
# def __init__(self,num):
45-
# GradientMatchingAttack.__init__(self)
46-
# self.variable = num*2
44+
def __init__(
    self,
    classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
    percent_poison: float,
    epsilon: float = 0.1,
    max_trials: int = 8,
    max_epochs: int = 250,
    learning_rate_schedule: Tuple[List[float], List[int]] = ([1e-1, 1e-2, 1e-3, 1e-4], [100, 150, 200, 220]),
    batch_size: int = 128,
    clip_values: Tuple[float, float] = (0, 1.0),
    verbose: int = 1,
    indices_target=None,
    patching_strategy: str = "random",
    selection_strategy: str = "random",
    retraining_factor: int = 1,
    model_retraining: bool = False,
    model_retraining_epoch: int = 1,
):
    """
    Initialize a Sleeper Agent poisoning attack.

    The first nine arguments are forwarded unchanged, and positionally, to the parent
    `GradientMatchingAttack` constructor; see that class for their exact meaning.

    :param classifier: The model handed to the parent attack.
    :param percent_poison: Forwarded to the parent attack.
    :param epsilon: Forwarded to the parent attack.
    :param max_trials: Forwarded to the parent attack.
    :param max_epochs: Forwarded to the parent attack.
    :param learning_rate_schedule: Forwarded to the parent attack.
    :param batch_size: Forwarded to the parent attack.
    :param clip_values: Forwarded to the parent attack.
    :param verbose: Forwarded to the parent attack.
    :param indices_target: Stored on the instance; `model_retraining` uses it to index the training set
                           (presumably an integer index array - TODO confirm against callers).
    :param patching_strategy: Name of the patching strategy. Stored on the instance; no use of it is
                              visible in this part of the file.
    :param selection_strategy: `"random"` makes `poison` draw the poison indices from a random permutation;
                               any other value routes through `select_poison_indices`.
    :param retraining_factor: Read by `poison` to size the retraining loop when `model_retraining` is set.
    :param model_retraining: Whether `poison` takes the retraining branch.
    :param model_retraining_epoch: Number of epochs used when the substitute model is retrained.
    """
    # Positional forwarding mirrors the parent signature order; keep it in sync with GradientMatchingAttack.
    super().__init__(
        classifier,
        percent_poison,
        epsilon,
        max_trials,
        max_epochs,
        learning_rate_schedule,
        batch_size,
        clip_values,
        verbose,
    )
    self.indices_target = indices_target
    # Previously accepted but silently dropped; keep it so the configured value is not lost.
    self.patching_strategy = patching_strategy
    self.selection_strategy = selection_strategy
    self.retraining_factor = retraining_factor
    # NOTE(review): this instance attribute has the same name as the `model_retraining` method, so the
    # boolean hides the method on instances. Renaming one of them has to be done together with `poison`.
    self.model_retraining = model_retraining
    self.model_retraining_epoch = model_retraining_epoch
    # `model_retraining` reads the plural spelling; expose both so that lookup does not raise AttributeError.
    self.model_retraining_epochs = model_retraining_epoch
    # Filled in by `poison` with the indices selected for poisoning.
    self.indices_poison = []
77+
78+
4779
"""
4880
Implementation of Sleeper Agent Attack"""
4981
def poison(
5082
self, x_trigger: np.ndarray, y_trigger: np.ndarray, x_train: np.ndarray, y_train: np.ndarray
51-
,index_target,s_type="random") -> Tuple[np.ndarray, np.ndarray]:
83+
) -> Tuple[np.ndarray, np.ndarray]:
5284
"""
5385
Optimizes a portion of poisoned samples from x_train to make a model classify x_target
5486
as y_target by matching the gradients.
@@ -70,7 +102,7 @@ def poison(
70102
finish_poisoning = self.__finish_poison_pytorch
71103
else:
72104
raise NotImplementedError(
73-
"GradientMatchingAttack is currently implemented only for Tensorflow V2 and Pytorch."
105+
"SleeperAgentAttack is currently implemented only for Tensorflow V2 and Pytorch."
74106
)
75107

76108
# Choose samples to poison.
@@ -92,21 +124,21 @@ def poison(
92124
else:
93125
y_train_classes = y_train
94126
for _ in trange(self.max_trials):
95-
if s_type == "random":
96-
# pdb.set_trace()
97-
indices_poison = np.random.permutation(np.where([y in classes_target for y in y_train_classes])[0])[
98-
:num_poison_samples
99-
]
127+
if self.selection_strategy == "random":
128+
self.indices_poison = np.random.permutation(np.where([y in classes_target for y in y_train_classes])[0])[:num_poison_samples]
100129
else:
101-
indices_poison = self.select_poison_indices(self.substitute_classifier,x_train,y_train,num_poison_samples)
130+
self.indices_poison = self.select_poison_indices(self.substitute_classifier,x_train,y_train,num_poison_samples)
102131
x_poison = x_train[indices_poison]
103132
y_poison = y_train[indices_poison]
104133
self.__initialize_poison(x_trigger, y_trigger, x_poison, y_poison)
105-
for i in [80,80,90]:
106-
self.max_epochs = i
107-
x_poisoned, B_ = poisoner(x_poison, y_poison,index_target,indices_poison)
108-
self.model_retraining(x_poisoned,index_target,indices_poison,40)
109-
x_poisoned, B_ = poisoner(x_poison, y_poison,index_target,indices_poison) # pylint: disable=C0103
134+
if self.model_retraining:
135+
retrain_epochs = self.retraining_factor//self.max_epochs
136+
for i in range(self.retraining_factor-1):
137+
self.max_epochs = retrain_epochs
138+
x_poisoned, B_ = poisoner(x_poison, y_poison)
139+
self.model_retraining(x_poisoned)
140+
else:
141+
x_poisoned, B_ = poisoner(x_poison, y_poison,index_target,indices_poison) # pylint: disable=C0103
110142
finish_poisoning()
111143
B_ = np.mean(B_) # Averaging B losses from multiple batches. # pylint: disable=C0103
112144
if B_ < best_B:
@@ -119,7 +151,7 @@ def poison(
119151
x_train[best_indices_poison] = best_x_poisoned
120152
return x_train, y_train, best_indices_poison
121153

122-
def model_retraining(self,poisoned_samples,index_target,indices_poison,epochs):
154+
def model_retraining(self,poisoned_samples):
123155
import torch
124156
from art.utils import load_cifar10
125157
(x_train, y_train), (x_test, y_test), min_, max_ = load_cifar10()
@@ -131,9 +163,9 @@ def model_retraining(self,poisoned_samples,index_target,indices_poison,epochs):
131163
max_ = (max_-mean)/(std+1e-7)
132164
x_train = np.transpose(x_train, [0, 3,1,2])
133165
poisoned_samples = np.asarray(poisoned_samples)
134-
x_train[index_target[indices_poison]] = poisoned_samples
166+
x_train[self.indices_target[self.indices_poison]] = poisoned_samples
135167
model,loss_fn,optimizer = create_model(x_train, y_train, x_test=x_test, y_test=y_test,
136-
num_classes=10, batch_size=128, epochs=80)
168+
num_classes=10, batch_size=128, epochs=self.model_retraining_epochs)
137169
model_ = PyTorchClassifier(model, input_shape=x_train.shape[1:], loss=loss_fn,
138170
optimizer=optimizer, nb_classes=10)
139171
check_train = self.substitute_classifier.model.training
@@ -194,7 +226,7 @@ def create_model(x_train, y_train, x_test=None, y_test=None, num_classes=10, bat
194226
print("Epoch %d train accuracy: %f" % (epoch, train_accuracy))
195227
test_accuracy = testAccuracy(model, dataloader_test)
196228
print("Final test accuracy: %f" % test_accuracy)
197-
return model,loss_fn,optimizer
229+
return model,loss_fn,optimizer
198230

199231

200232
def select_poison_indices(self,classifier,x_samples,y_samples,num_poison):

0 commit comments

Comments
 (0)