Skip to content

Commit bdb948c

Browse files
authored
Fix the DCGAN C++ shape warning (#1207)
fix the dcgan shape warning
1 parent 30b310a commit bdb948c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

cpp/dcgan/dcgan.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ int main(int argc, const char* argv[]) {
142142
torch::Tensor real_images = batch.data.to(device);
143143
torch::Tensor real_labels =
144144
torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
145-
torch::Tensor real_output = discriminator->forward(real_images);
145+
torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes());
146146
torch::Tensor d_loss_real =
147147
torch::binary_cross_entropy(real_output, real_labels);
148148
d_loss_real.backward();
@@ -152,7 +152,7 @@ int main(int argc, const char* argv[]) {
152152
torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
153153
torch::Tensor fake_images = generator->forward(noise);
154154
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
155-
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
155+
torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes());
156156
torch::Tensor d_loss_fake =
157157
torch::binary_cross_entropy(fake_output, fake_labels);
158158
d_loss_fake.backward();
@@ -163,7 +163,7 @@ int main(int argc, const char* argv[]) {
163163
// Train generator.
164164
generator->zero_grad();
165165
fake_labels.fill_(1);
166-
fake_output = discriminator->forward(fake_images);
166+
fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes());
167167
torch::Tensor g_loss =
168168
torch::binary_cross_entropy(fake_output, fake_labels);
169169
g_loss.backward();

0 commit comments

Comments
 (0)