@@ -142,7 +142,7 @@ int main(int argc, const char* argv[]) {
142
142
torch::Tensor real_images = batch.data .to (device);
143
143
torch::Tensor real_labels =
144
144
torch::empty (batch.data .size (0 ), device).uniform_ (0.8 , 1.0 );
145
- torch::Tensor real_output = discriminator->forward (real_images);
145
+ torch::Tensor real_output = discriminator->forward (real_images). reshape (real_labels. sizes ()) ;
146
146
torch::Tensor d_loss_real =
147
147
torch::binary_cross_entropy (real_output, real_labels);
148
148
d_loss_real.backward ();
@@ -152,7 +152,7 @@ int main(int argc, const char* argv[]) {
152
152
torch::randn ({batch.data .size (0 ), kNoiseSize , 1 , 1 }, device);
153
153
torch::Tensor fake_images = generator->forward (noise);
154
154
torch::Tensor fake_labels = torch::zeros (batch.data .size (0 ), device);
155
- torch::Tensor fake_output = discriminator->forward (fake_images.detach ());
155
+ torch::Tensor fake_output = discriminator->forward (fake_images.detach ()). reshape (fake_labels. sizes ()) ;
156
156
torch::Tensor d_loss_fake =
157
157
torch::binary_cross_entropy (fake_output, fake_labels);
158
158
d_loss_fake.backward ();
@@ -163,7 +163,7 @@ int main(int argc, const char* argv[]) {
163
163
// Train generator.
164
164
generator->zero_grad ();
165
165
fake_labels.fill_ (1 );
166
- fake_output = discriminator->forward (fake_images);
166
+ fake_output = discriminator->forward (fake_images). reshape (fake_labels. sizes ()) ;
167
167
torch::Tensor g_loss =
168
168
torch::binary_cross_entropy (fake_output, fake_labels);
169
169
g_loss.backward ();
0 commit comments