41
41
42
42
43
43
class SleeperAgentAttack (GradientMatchingAttack ):
44
- # def __init__(self,num):
45
- # GradientMatchingAttack.__init__(self)
46
- # self.variable = num*2
44
def __init__(
    self,
    classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
    percent_poison: float,
    epsilon: float = 0.1,
    max_trials: int = 8,
    max_epochs: int = 250,
    learning_rate_schedule: Tuple[List[float], List[int]] = ([1e-1, 1e-2, 1e-3, 1e-4], [100, 150, 200, 220]),
    batch_size: int = 128,
    clip_values: Tuple[float, float] = (0, 1.0),
    verbose: int = 1,
    indices_target=None,
    patching_strategy: str = "random",
    selection_strategy: str = "random",
    retraining_factor: int = 1,
    model_retraining: bool = False,
    model_retraining_epoch: int = 1,
):
    """
    Initialize a Sleeper Agent attack on top of the gradient-matching attack.

    The first nine arguments are forwarded unchanged, in order, to the parent
    `GradientMatchingAttack` constructor; see that class for their meaning.

    :param classifier: Forwarded to the parent constructor.
    :param percent_poison: Forwarded to the parent constructor.
    :param epsilon: Forwarded to the parent constructor.
    :param max_trials: Forwarded to the parent constructor.
    :param max_epochs: Forwarded to the parent constructor.
    :param learning_rate_schedule: Forwarded to the parent constructor.
    :param batch_size: Forwarded to the parent constructor.
    :param clip_values: Forwarded to the parent constructor.
    :param verbose: Forwarded to the parent constructor.
    :param indices_target: Indexed with `indices_poison` during model retraining, so it must support
                           array-style indexing (presumably a numpy index array - TODO confirm).
    :param patching_strategy: Stored on the instance as `patching_strategy`.
    :param selection_strategy: `"random"` makes `poison` draw a random permutation of candidate indices;
                               any other value makes it call `select_poison_indices`.
    :param retraining_factor: Controls how many poison/retrain rounds `poison` performs.
    :param model_retraining: Flag read by `poison` to decide whether retraining rounds are run.
    :param model_retraining_epoch: Number of epochs handed to `create_model` when retraining.
    """
    super().__init__(
        classifier,
        percent_poison,
        epsilon,
        max_trials,
        max_epochs,
        learning_rate_schedule,
        batch_size,
        clip_values,
        verbose,
    )
    self.indices_target = indices_target
    # Previously accepted but silently discarded; keep it so the caller's choice is not lost.
    self.patching_strategy = patching_strategy
    self.selection_strategy = selection_strategy
    self.retraining_factor = retraining_factor
    # NOTE(review): this instance attribute shadows the `model_retraining` method of this class.
    # `poison` both tests `self.model_retraining` as a flag and calls it as a method, so one of
    # the two needs to be renamed in a coordinated change.
    self.model_retraining = model_retraining
    self.model_retraining_epoch = model_retraining_epoch
    # The retraining code reads the plural spelling; expose both names so that lookup succeeds.
    self.model_retraining_epochs = model_retraining_epoch
    # Filled in by `poison` with the indices chosen for poisoning.
    self.indices_poison = []
77
+
78
+
47
79
"""
48
80
Implementation of Sleeper Agent Attack"""
49
81
def poison (
50
82
self , x_trigger : np .ndarray , y_trigger : np .ndarray , x_train : np .ndarray , y_train : np .ndarray
51
- , index_target , s_type = "random" ) -> Tuple [np .ndarray , np .ndarray ]:
83
+ ) -> Tuple [np .ndarray , np .ndarray ]:
52
84
"""
53
85
Optimizes a portion of poisoned samples from x_train to make a model classify x_target
54
86
as y_target by matching the gradients.
@@ -70,7 +102,7 @@ def poison(
70
102
finish_poisoning = self .__finish_poison_pytorch
71
103
else :
72
104
raise NotImplementedError (
73
- "GradientMatchingAttack is currently implemented only for Tensorflow V2 and Pytorch."
105
+ "SleeperAgentAttack is currently implemented only for Tensorflow V2 and Pytorch."
74
106
)
75
107
76
108
# Choose samples to poison.
@@ -92,21 +124,21 @@ def poison(
92
124
else :
93
125
y_train_classes = y_train
94
126
for _ in trange (self .max_trials ):
95
- if s_type == "random" :
96
- # pdb.set_trace()
97
- indices_poison = np .random .permutation (np .where ([y in classes_target for y in y_train_classes ])[0 ])[
98
- :num_poison_samples
99
- ]
127
+ if self .selection_strategy == "random" :
128
+ self .indices_poison = np .random .permutation (np .where ([y in classes_target for y in y_train_classes ])[0 ])[:num_poison_samples ]
100
129
else :
101
- indices_poison = self .select_poison_indices (self .substitute_classifier ,x_train ,y_train ,num_poison_samples )
130
+ self . indices_poison = self .select_poison_indices (self .substitute_classifier ,x_train ,y_train ,num_poison_samples )
102
131
x_poison = x_train [indices_poison ]
103
132
y_poison = y_train [indices_poison ]
104
133
self .__initialize_poison (x_trigger , y_trigger , x_poison , y_poison )
105
- for i in [80 ,80 ,90 ]:
106
- self .max_epochs = i
107
- x_poisoned , B_ = poisoner (x_poison , y_poison ,index_target ,indices_poison )
108
- self .model_retraining (x_poisoned ,index_target ,indices_poison ,40 )
109
- x_poisoned , B_ = poisoner (x_poison , y_poison ,index_target ,indices_poison ) # pylint: disable=C0103
134
+ if self .model_retraining :
135
+ retrain_epochs = self .retraining_factor // self .max_epochs
136
+ for i in range (self .retraining_factor - 1 ):
137
+ self .max_epochs = retrain_epochs
138
+ x_poisoned , B_ = poisoner (x_poison , y_poison )
139
+ self .model_retraining (x_poisoned )
140
+ else :
141
+ x_poisoned , B_ = poisoner (x_poison , y_poison ,index_target ,indices_poison ) # pylint: disable=C0103
110
142
finish_poisoning ()
111
143
B_ = np .mean (B_ ) # Averaging B losses from multiple batches. # pylint: disable=C0103
112
144
if B_ < best_B :
@@ -119,7 +151,7 @@ def poison(
119
151
x_train [best_indices_poison ] = best_x_poisoned
120
152
return x_train , y_train , best_indices_poison
121
153
122
- def model_retraining (self ,poisoned_samples , index_target , indices_poison , epochs ):
154
+ def model_retraining (self ,poisoned_samples ):
123
155
import torch
124
156
from art .utils import load_cifar10
125
157
(x_train , y_train ), (x_test , y_test ), min_ , max_ = load_cifar10 ()
@@ -131,9 +163,9 @@ def model_retraining(self,poisoned_samples,index_target,indices_poison,epochs):
131
163
max_ = (max_ - mean )/ (std + 1e-7 )
132
164
x_train = np .transpose (x_train , [0 , 3 ,1 ,2 ])
133
165
poisoned_samples = np .asarray (poisoned_samples )
134
- x_train [index_target [ indices_poison ]] = poisoned_samples
166
+ x_train [self . indices_target [ self . indices_poison ]] = poisoned_samples
135
167
model ,loss_fn ,optimizer = create_model (x_train , y_train , x_test = x_test , y_test = y_test ,
136
- num_classes = 10 , batch_size = 128 , epochs = 80 )
168
+ num_classes = 10 , batch_size = 128 , epochs = self . model_retraining_epochs )
137
169
model_ = PyTorchClassifier (model , input_shape = x_train .shape [1 :], loss = loss_fn ,
138
170
optimizer = optimizer , nb_classes = 10 )
139
171
check_train = self .substitute_classifier .model .training
@@ -194,7 +226,7 @@ def create_model(x_train, y_train, x_test=None, y_test=None, num_classes=10, bat
194
226
print ("Epoch %d train accuracy: %f" % (epoch , train_accuracy ))
195
227
test_accuracy = testAccuracy (model , dataloader_test )
196
228
print ("Final test accuracy: %f" % test_accuracy )
197
- return model ,loss_fn ,optimizer
229
+ return model ,loss_fn ,optimizer
198
230
199
231
200
232
def select_poison_indices (self ,classifier ,x_samples ,y_samples ,num_poison ):
0 commit comments