
Commit 4d7fea0

MechCoder authored and Copybara-Service committed

Set filter_size to 1 when the number of time-steps equals 1 when using 3-D convolutions. This avoids unnecessary padding across time for this special case.

PiperOrigin-RevId: 228946248

1 parent: 3b34470

2 files changed: +18 -6 lines

tensor2tensor/models/research/glow_ops.py (+13 -3)

@@ -421,6 +421,7 @@ def conv(name, x, output_channels, filter_size=None, stride=None,
 
   x_shape = common_layers.shape_list(x)
   is_2d = len(x_shape) == 4
+  num_steps = x_shape[1]
 
   # set filter_size, stride and in_channels
   if is_2d:
@@ -435,7 +436,10 @@ def conv(name, x, output_channels, filter_size=None, stride=None,
     conv_filter = tf.nn.conv2d
   else:
     if filter_size is None:
-      filter_size = [2, 3, 3]
+      if num_steps == 1:
+        filter_size = [1, 3, 3]
+      else:
+        filter_size = [2, 3, 3]
     if stride is None:
       stride = [1, 1, 1]
     if dilations is None:
@@ -489,11 +493,17 @@ def conv_block(name, x, mid_channels, dilations=None, activation="relu",
 
   x_shape = common_layers.shape_list(x)
   is_2d = len(x_shape) == 4
+  num_steps = x_shape[1]
   if is_2d:
     first_filter = [3, 3]
     second_filter = [1, 1]
   else:
-    first_filter = [2, 3, 3]
+    # special case when number of steps equal 1 to avoid
+    # padding.
+    if num_steps == 1:
+      first_filter = [1, 3, 3]
+    else:
+      first_filter = [2, 3, 3]
     second_filter = [1, 1, 1]
 
   # Edge Padding + conv2d + actnorm + relu:
@@ -1025,7 +1035,7 @@ def split(name, x, reverse=False, eps=None, eps_std=None, cond_latents=None,
     eps_std: Sample x2 with the provided eps_std.
     cond_latents: optionally condition x2 on cond_latents.
     hparams: next_frame_glow hparams.
-    state: tf.nn.rnn_cell.LSTMStateTuple. Current state of the LSTM over z_2.
+    state: tf.nn.rnn_cell.LSTMStateTuple.. Current state of the LSTM over z_2.
      Used only when hparams.latent_dist_encoder == "conv_lstm"
     condition: bool, Whether or not to condition the distribution on
       cond_latents.
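
As a rough illustration of the padding the commit message refers to: with only one time-step, a 3-D convolution whose filter spans two time-steps can only be applied by padding an extra frame along the time axis, whereas a time-dimension filter size of 1 needs no padding across time. The sketch below is standalone and illustrative, not code from glow_ops.py: the variable names (w_two, w_one) are made up, and it uses TensorFlow's "SAME" padding as a stand-in for the explicit edge padding that glow_ops performs before its convolutions.

# Standalone sketch (assumed names, TF1-style API as used in this repo).
import tensorflow as tf

with tf.Graph().as_default():
  # A single-frame 3-D input: (batch, time, height, width, channels).
  x = tf.random_normal((16, 1, 32, 32, 48))

  # filter_size = [2, 3, 3]: a depth-2 kernel cannot fit a single frame, so
  # "SAME" padding has to invent an extra all-zero frame along the time axis.
  w_two = tf.get_variable("w_two", [2, 3, 3, 48, 48])
  y_two = tf.nn.conv3d(x, w_two, strides=[1, 1, 1, 1, 1], padding="SAME")

  # filter_size = [1, 3, 3]: the kernel already fits one frame, so no
  # padding across time is needed.
  w_one = tf.get_variable("w_one", [1, 3, 3, 48, 48])
  y_one = tf.nn.conv3d(x, w_one, strides=[1, 1, 1, 1, 1], padding="SAME")

  # Both outputs keep the input shape, but y_two was computed against a
  # padded frame that carries no information.
  print(y_two.shape, y_one.shape)  # (16, 1, 32, 32, 48) in both cases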

tensor2tensor/models/research/glow_ops_test.py (+5 -3)

@@ -430,16 +430,18 @@ def test_actnorm_3d(self):
       ("dil_gatu", True, "gatu"), ("no_dil_gatu", False, "gatu"),
       ("dil_relu_drop", True, "relu", 0.1),
       ("dil_gatu_drop", True, "gatu", 0.1),
-      ("dil_gatu_drop_noise", True, "gatu", 0.1, 0.1))
+      ("dil_gatu_drop_noise", True, "gatu", 0.1, 0.1),
+      ("gatu_drop_single_step", False, "gatu", 0.1, 0.1, 1),
+      ("dil_gatu_drop_single_step", True, "gatu", 0.1, 0.1, 1),)
   def test_temporal_latent_to_dist(self, apply_dilation, activation,
-                                   dropout=0.0, noise=0.1):
+                                   dropout=0.0, noise=0.1, num_steps=5):
     with tf.Graph().as_default():
       hparams = self.get_glow_hparams()
       hparams.latent_apply_dilations = apply_dilation
       hparams.latent_activation = activation
       hparams.latent_dropout = dropout
       hparams.latent_noise = noise
-      latent_shape = (16, 5, 32, 32, 48)
+      latent_shape = (16, num_steps, 32, 32, 48)
       latents = tf.random_normal(latent_shape)
       dist = glow_ops.temporal_latent_to_dist(
           "tensor_to_dist", latents, hparams)
