@@ -106,25 +106,25 @@ def copyParam(self, daeLayers):
         every = 3
         # input layer
         # copy encoder weight
-        self.encoder[0].weight.copy_(daeLayers[l].weight)
-        self.encoder[0].bias.copy_(daeLayers[l].bias)
-        self._dec.weight.copy_(daeLayers[l].deweight)
-        self._dec.bias.copy_(daeLayers[l].vbias)
+        self.encoder[0].weight.data.copy_(daeLayers[0].weight.data)
+        self.encoder[0].bias.data.copy_(daeLayers[0].bias.data)
+        self._dec.weight.data.copy_(daeLayers[0].deweight.data)
+        self._dec.bias.data.copy_(daeLayers[0].vbias.data)

         for l in range(1, len(self.layers)-2):
             # copy encoder weight
-            self.encoder[l*every].weight.copy_(daeLayers[l].weight)
-            self.encoder[l*every].bias.copy_(daeLayers[l].bias)
+            self.encoder[l*every].weight.data.copy_(daeLayers[l].weight.data)
+            self.encoder[l*every].bias.data.copy_(daeLayers[l].bias.data)

             # copy decoder weight
-            self.decoder[-(l-1)*every-1].weight.copy_(daeLayers[l].deweight)
-            self.decoder[-(l-1)*every-1].bias.copy_(daeLayers[l].vbias)
+            self.decoder[-(l-1)*every-2].weight.data.copy_(daeLayers[l].deweight.data)
+            self.decoder[-(l-1)*every-2].bias.data.copy_(daeLayers[l].vbias.data)

         # z layer
-        self._enc_mu.weight.copy_(daeLayers[-1].weight)
-        self._enc_mu.bias.copy_(daeLayers[-1].bias)
-        self.decoder[0].weight.copy_(daeLayers[-1].deweight)
-        self.decoder[0].bias.copy_(daeLayers[-1].vbias)
+        self._enc_mu.weight.data.copy_(daeLayers[-1].weight.data)
+        self._enc_mu.bias.data.copy_(daeLayers[-1].bias.data)
+        self.decoder[0].weight.data.copy_(daeLayers[-1].deweight.data)
+        self.decoder[0].bias.data.copy_(daeLayers[-1].vbias.data)

     def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
             loss_type="mse"):
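This hunk fixes the stale `l` index for the input layer (it should be `daeLayers[0]`), shifts the decoder index by one, and routes every copy through `.data` so autograd does not track the in-place `copy_`. A minimal sketch of that copy pattern, using hypothetical `pretrained`/`stacked` layers in place of the real `daeLayers` and `self.encoder` modules:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: `pretrained` plays the role of daeLayers[0],
# `stacked` the role of self.encoder[0].
pretrained = nn.Linear(784, 500)
stacked = nn.Linear(784, 500)

# The pattern this commit adopts: copy through .data so the in-place
# copy_ is invisible to autograd.
stacked.weight.data.copy_(pretrained.weight.data)
stacked.bias.data.copy_(pretrained.bias.data)

# The modern equivalent of the same copy:
with torch.no_grad():
    stacked.weight.copy_(pretrained.weight)
    stacked.bias.copy_(pretrained.bias)
```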
@@ -135,7 +135,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
         use_cuda = torch.cuda.is_available()
         if use_cuda:
             self.cuda()
-        print("=====Denoising Autoencoding layer=======")
+        print("=====Stacked Denoising Autoencoding layer=======")
         optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
         if loss_type == "mse":
             criterion = MSELoss()
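As context for the unchanged optimizer line above: `filter(lambda p: p.requires_grad, self.parameters())` hands Adam only the parameters that are still trainable, so frozen layers are skipped. A self-contained sketch of that pattern (the toy model is illustrative):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(784, 500), nn.ReLU(), nn.Linear(500, 10))

# Freeze the first layer; the filter() below then excludes it from Adam.
for p in model[0].parameters():
    p.requires_grad = False

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
```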
@@ -150,11 +150,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
             if use_cuda:
                 inputs = inputs.cuda()
             inputs = Variable(inputs)
-            hidden = self.encode(inputs)
-            if loss_type == "cross-entropy":
-                outputs = self.decode(hidden, binary=True)
-            else:
-                outputs = self.decode(hidden)
+            z, outputs = self.forward(inputs)

             valid_recon_loss = criterion(outputs, inputs)
             total_loss += valid_recon_loss.data[0] * len(inputs)
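The refactor collapses the manual encode/decode branching into a single `forward` call that is assumed to return both the latent code and the reconstruction (with the `binary`/output-nonlinearity choice presumably handled inside `forward` now). A hypothetical model with that return shape:

```python
import torch
import torch.nn as nn

class TinyAE(nn.Module):
    """Toy autoencoder whose forward returns (z, reconstruction),
    matching the `z, outputs = self.forward(...)` call sites."""

    def __init__(self, d_in=784, d_hidden=500, d_z=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU())
        self._enc_mu = nn.Linear(d_hidden, d_z)
        self.decoder = nn.Sequential(nn.Linear(d_z, d_hidden), nn.ReLU(),
                                     nn.Linear(d_hidden, d_in))

    def forward(self, x):
        z = self._enc_mu(self.encoder(x))
        return z, self.decoder(z)

z, outputs = TinyAE()(torch.randn(8, 784))
```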
@@ -176,11 +172,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
             inputs = Variable(inputs)
             inputs_corr = Variable(inputs_corr)

-            hidden = self.encode(inputs_corr)
-            if loss_type == "cross-entropy":
-                outputs = self.decode(hidden, binary=True)
-            else:
-                outputs = self.decode(hidden)
+            z, outputs = self.forward(inputs_corr)
             recon_loss = criterion(outputs, inputs)
             train_loss += recon_loss.data[0]*len(inputs)
             recon_loss.backward()
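The training hunk keeps the denoising contract: the forward pass runs on the corrupted `inputs_corr`, while the loss is computed against the clean `inputs`. A self-contained sketch of one such step, with a placeholder model and masking-noise corruption (the real corruption used by this repo may differ):

```python
import torch
import torch.nn as nn

def masking_noise(x, corrupt=0.3):
    # Zero a random fraction of entries, one common DAE corruption.
    return x * (torch.rand_like(x) > corrupt).float()

model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 784))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

inputs = torch.randn(8, 784)
inputs_corr = masking_noise(inputs, corrupt=0.3)

optimizer.zero_grad()
outputs = model(inputs_corr)             # forward on the corrupted batch
recon_loss = criterion(outputs, inputs)  # score against the clean batch
recon_loss.backward()
optimizer.step()
```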
@@ -193,11 +185,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
             if use_cuda:
                 inputs = inputs.cuda()
             inputs = Variable(inputs)
-            hidden = self.encode(inputs, train=False)
-            if loss_type == "cross-entropy":
-                outputs = self.decode(hidden, binary=True)
-            else:
-                outputs = self.decode(hidden)
+            z, outputs = self.forward(inputs)

             valid_recon_loss = criterion(outputs, inputs)
             valid_loss += valid_recon_loss.data[0] * len(inputs)
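A side note on the surrounding code rather than this change: `Variable` and `loss.data[0]` are PyTorch <= 0.3 idioms; on PyTorch >= 0.4, `Variable` is unnecessary and `.data[0]` fails on 0-dim loss tensors, so the equivalent scalar accumulation would be:

```python
import torch
import torch.nn.functional as F

loss = F.mse_loss(torch.randn(8, 4), torch.randn(8, 4))
total_loss = 0.0
total_loss += loss.item() * 8  # replaces `loss.data[0] * len(inputs)`
```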