@@ -421,6 +421,7 @@ def conv(name, x, output_channels, filter_size=None, stride=None,
421
421
422
422
x_shape = common_layers .shape_list (x )
423
423
is_2d = len (x_shape ) == 4
424
+ num_steps = x_shape [1 ]
424
425
425
426
# set filter_size, stride and in_channels
426
427
if is_2d :
@@ -435,7 +436,10 @@ def conv(name, x, output_channels, filter_size=None, stride=None,
435
436
conv_filter = tf .nn .conv2d
436
437
else :
437
438
if filter_size is None :
438
- filter_size = [2 , 3 , 3 ]
439
+ if num_steps == 1 :
440
+ filter_size = [1 , 3 , 3 ]
441
+ else :
442
+ filter_size = [2 , 3 , 3 ]
439
443
if stride is None :
440
444
stride = [1 , 1 , 1 ]
441
445
if dilations is None :
@@ -489,11 +493,17 @@ def conv_block(name, x, mid_channels, dilations=None, activation="relu",
489
493
490
494
x_shape = common_layers .shape_list (x )
491
495
is_2d = len (x_shape ) == 4
496
+ num_steps = x_shape [1 ]
492
497
if is_2d :
493
498
first_filter = [3 , 3 ]
494
499
second_filter = [1 , 1 ]
495
500
else :
496
- first_filter = [2 , 3 , 3 ]
501
+ # special case when number of steps equal 1 to avoid
502
+ # padding.
503
+ if num_steps == 1 :
504
+ first_filter = [1 , 3 , 3 ]
505
+ else :
506
+ first_filter = [2 , 3 , 3 ]
497
507
second_filter = [1 , 1 , 1 ]
498
508
499
509
# Edge Padding + conv2d + actnorm + relu:
@@ -1025,7 +1035,7 @@ def split(name, x, reverse=False, eps=None, eps_std=None, cond_latents=None,
1025
1035
eps_std: Sample x2 with the provided eps_std.
1026
1036
cond_latents: optionally condition x2 on cond_latents.
1027
1037
hparams: next_frame_glow hparams.
1028
- state: tf.nn.rnn_cell.LSTMStateTuple. Current state of the LSTM over z_2.
1038
+ state: tf.nn.rnn_cell.LSTMStateTuple.. Current state of the LSTM over z_2.
1029
1039
Used only when hparams.latent_dist_encoder == "conv_lstm"
1030
1040
condition: bool, Whether or not to condition the distribution on
1031
1041
cond_latents.
0 commit comments