@@ -142,7 +142,7 @@ int main(int argc, const char* argv[]) {
142142 torch::Tensor real_images = batch.data .to (device);
143143 torch::Tensor real_labels =
144144 torch::empty (batch.data .size (0 ), device).uniform_ (0.8 , 1.0 );
145- torch::Tensor real_output = discriminator->forward (real_images);
145+ torch::Tensor real_output = discriminator->forward (real_images). reshape (real_labels. sizes ()) ;
146146 torch::Tensor d_loss_real =
147147 torch::binary_cross_entropy (real_output, real_labels);
148148 d_loss_real.backward ();
@@ -152,7 +152,7 @@ int main(int argc, const char* argv[]) {
152152 torch::randn ({batch.data .size (0 ), kNoiseSize , 1 , 1 }, device);
153153 torch::Tensor fake_images = generator->forward (noise);
154154 torch::Tensor fake_labels = torch::zeros (batch.data .size (0 ), device);
155- torch::Tensor fake_output = discriminator->forward (fake_images.detach ());
155+ torch::Tensor fake_output = discriminator->forward (fake_images.detach ()). reshape (fake_labels. sizes ()) ;
156156 torch::Tensor d_loss_fake =
157157 torch::binary_cross_entropy (fake_output, fake_labels);
158158 d_loss_fake.backward ();
@@ -163,7 +163,7 @@ int main(int argc, const char* argv[]) {
163163 // Train generator.
164164 generator->zero_grad ();
165165 fake_labels.fill_ (1 );
166- fake_output = discriminator->forward (fake_images);
166+ fake_output = discriminator->forward (fake_images). reshape (fake_labels. sizes ()) ;
167167 torch::Tensor g_loss =
168168 torch::binary_cross_entropy (fake_output, fake_labels);
169169 g_loss.backward ();
0 commit comments