@@ -332,7 +332,7 @@ def masked_conv_aff_coupling(input_, mask_in, dim, name,
332
332
residual_blocks = residual_blocks ,
333
333
bottleneck = bottleneck , skip = skip )
334
334
mask = tf .mod (mask_channel + mask , 2 )
335
- res = tf .split (axis = res , num_or_size_splits = 2 , value = 3 )
335
+ res = tf .split (axis = 3 , num_or_size_splits = 2 , value = res )
336
336
shift , log_rescaling = res [- 2 ], res [- 1 ]
337
337
scale = variable_on_cpu (
338
338
"rescaling_scale" , [],
@@ -486,9 +486,9 @@ def conv_ch_aff_coupling(input_, dim, name,
486
486
scope .reuse_variables ()
487
487
488
488
if change_bottom :
489
- input_ , canvas = tf .split (axis = input_ , num_or_size_splits = 2 , value = 3 )
489
+ input_ , canvas = tf .split (axis = 3 , num_or_size_splits = 2 , value = input_ )
490
490
else :
491
- canvas , input_ = tf .split (axis = input_ , num_or_size_splits = 2 , value = 3 )
491
+ canvas , input_ = tf .split (axis = 3 , num_or_size_splits = 2 , value = input_ )
492
492
shape = input_ .get_shape ().as_list ()
493
493
batch_size = shape [0 ]
494
494
height = shape [1 ]
@@ -509,7 +509,7 @@ def conv_ch_aff_coupling(input_, dim, name,
509
509
train = train , weight_norm = weight_norm ,
510
510
residual_blocks = residual_blocks ,
511
511
bottleneck = bottleneck , skip = skip )
512
- shift , log_rescaling = tf .split (axis = res , num_or_size_splits = 2 , value = 3 )
512
+ shift , log_rescaling = tf .split (axis = 3 , num_or_size_splits = 2 , value = res )
513
513
scale = variable_on_cpu (
514
514
"scale" , [],
515
515
tf .constant_initializer (1. ))
@@ -570,9 +570,9 @@ def conv_ch_add_coupling(input_, dim, name,
570
570
scope .reuse_variables ()
571
571
572
572
if change_bottom :
573
- input_ , canvas = tf .split (axis = input_ , num_or_size_splits = 2 , value = 3 )
573
+ input_ , canvas = tf .split (axis = 3 , num_or_size_splits = 2 , value = input_ )
574
574
else :
575
- canvas , input_ = tf .split (axis = input_ , num_or_size_splits = 2 , value = 3 )
575
+ canvas , input_ = tf .split (axis = 3 , num_or_size_splits = 2 , value = input_ )
576
576
shape = input_ .get_shape ().as_list ()
577
577
channels = shape [3 ]
578
578
res = input_
@@ -736,8 +736,8 @@ def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
736
736
log_diff_1 = log_diff [:, :, :, :channels ]
737
737
log_diff_2 = log_diff [:, :, :, channels :]
738
738
else :
739
- res_1 , res_2 = tf .split (axis = res , num_or_size_splits = 2 , value = 3 )
740
- log_diff_1 , log_diff_2 = tf .split (axis = log_diff , num_or_size_splits = 2 , value = 3 )
739
+ res_1 , res_2 = tf .split (axis = 3 , num_or_size_splits = 2 , value = res )
740
+ log_diff_1 , log_diff_2 = tf .split (axis = 3 , num_or_size_splits = 2 , value = log_diff )
741
741
res_1 , inc_log_diff = rec_masked_conv_coupling (
742
742
input_ = res_1 , hps = hps , scale_idx = scale_idx + 1 , n_scale = n_scale ,
743
743
use_batch_norm = use_batch_norm , weight_norm = weight_norm ,
@@ -798,8 +798,8 @@ def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
798
798
log_diff_1 = log_diff [:, :, :, :channels ]
799
799
log_diff_2 = log_diff [:, :, :, channels :]
800
800
else :
801
- res_1 , res_2 = tf .split (axis = res , num_or_size_splits = 2 , value = 3 )
802
- log_diff_1 , log_diff_2 = tf .split (axis = log_diff , num_or_size_splits = 2 , value = 3 )
801
+ res_1 , res_2 = tf .split (axis = 3 , num_or_size_splits = 2 , value = res )
802
+ log_diff_1 , log_diff_2 = tf .split (axis = 3 , num_or_size_splits = 2 , value = log_diff )
803
803
res_1 , log_diff_1 = rec_masked_deconv_coupling (
804
804
input_ = res_1 , hps = hps ,
805
805
scale_idx = scale_idx + 1 , n_scale = n_scale ,
@@ -1305,7 +1305,7 @@ def __init__(self, hps, sampling=False):
1305
1305
z_lost = z_complete
1306
1306
for scale_idx in xrange (hps .n_scale - 1 ):
1307
1307
z_lost = squeeze_2x2_ordered (z_lost )
1308
- z_lost , _ = tf .split (axis = z_lost , num_or_size_splits = 2 , value = 3 )
1308
+ z_lost , _ = tf .split (axis = 3 , num_or_size_splits = 2 , value = z_lost )
1309
1309
z_compressed = z_lost
1310
1310
z_noisy = z_lost
1311
1311
for _ in xrange (scale_idx + 1 ):
0 commit comments