Skip to content

Commit 5c53534

Browse files
committed
Manually fixed many occurrences of tf.split
1 parent fdc0c4a commit 5c53534

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

neural_gpu/neural_gpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def reorder_beam(beam_size, batch_size, beam_val, output, is_first,
211211
# beam_val is [batch_size x beam_size]; let b = batch_size * beam_size
212212
# decided is len x b x a x b
213213
# output is b x out_size; step is b x len x a x b;
214-
outputs = tf.split(axis=tf.nn.log_softmax(output), num_or_size_splits=beam_size, value=0)
214+
outputs = tf.split(axis=0, num_or_size_splits=beam_size, value=tf.nn.log_softmax(output))
215215
all_beam_vals, all_beam_idx = [], []
216216
beam_range = 1 if is_first else beam_size
217217
for i in xrange(beam_range):
@@ -266,9 +266,9 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout,
266266
self.input = tf.placeholder(tf.int32, name="inp")
267267
self.target = tf.placeholder(tf.int32, name="tgt")
268268
self.prev_step = tf.placeholder(tf.float32, name="prev_step")
269-
gpu_input = tf.split(axis=self.input, num_or_size_splits=num_gpus, value=0)
270-
gpu_target = tf.split(axis=self.target, num_or_size_splits=num_gpus, value=0)
271-
gpu_prev_step = tf.split(axis=self.prev_step, num_or_size_splits=num_gpus, value=0)
269+
gpu_input = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.input)
270+
gpu_target = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.target)
271+
gpu_prev_step = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.prev_step)
272272
batch_size = tf.shape(gpu_input[0])[0]
273273

274274
if backward:

real_nvp/real_nvp_multiscale_dataset.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def masked_conv_aff_coupling(input_, mask_in, dim, name,
332332
residual_blocks=residual_blocks,
333333
bottleneck=bottleneck, skip=skip)
334334
mask = tf.mod(mask_channel + mask, 2)
335-
res = tf.split(axis=res, num_or_size_splits=2, value=3)
335+
res = tf.split(axis=3, num_or_size_splits=2, value=res)
336336
shift, log_rescaling = res[-2], res[-1]
337337
scale = variable_on_cpu(
338338
"rescaling_scale", [],
@@ -486,9 +486,9 @@ def conv_ch_aff_coupling(input_, dim, name,
486486
scope.reuse_variables()
487487

488488
if change_bottom:
489-
input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
489+
input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
490490
else:
491-
canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
491+
canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
492492
shape = input_.get_shape().as_list()
493493
batch_size = shape[0]
494494
height = shape[1]
@@ -509,7 +509,7 @@ def conv_ch_aff_coupling(input_, dim, name,
509509
train=train, weight_norm=weight_norm,
510510
residual_blocks=residual_blocks,
511511
bottleneck=bottleneck, skip=skip)
512-
shift, log_rescaling = tf.split(axis=res, num_or_size_splits=2, value=3)
512+
shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
513513
scale = variable_on_cpu(
514514
"scale", [],
515515
tf.constant_initializer(1.))
@@ -570,9 +570,9 @@ def conv_ch_add_coupling(input_, dim, name,
570570
scope.reuse_variables()
571571

572572
if change_bottom:
573-
input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
573+
input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
574574
else:
575-
canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
575+
canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
576576
shape = input_.get_shape().as_list()
577577
channels = shape[3]
578578
res = input_
@@ -736,8 +736,8 @@ def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
736736
log_diff_1 = log_diff[:, :, :, :channels]
737737
log_diff_2 = log_diff[:, :, :, channels:]
738738
else:
739-
res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
740-
log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
739+
res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
740+
log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
741741
res_1, inc_log_diff = rec_masked_conv_coupling(
742742
input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
743743
use_batch_norm=use_batch_norm, weight_norm=weight_norm,
@@ -798,8 +798,8 @@ def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
798798
log_diff_1 = log_diff[:, :, :, :channels]
799799
log_diff_2 = log_diff[:, :, :, channels:]
800800
else:
801-
res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
802-
log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
801+
res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
802+
log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
803803
res_1, log_diff_1 = rec_masked_deconv_coupling(
804804
input_=res_1, hps=hps,
805805
scale_idx=scale_idx + 1, n_scale=n_scale,
@@ -1305,7 +1305,7 @@ def __init__(self, hps, sampling=False):
13051305
z_lost = z_complete
13061306
for scale_idx in xrange(hps.n_scale - 1):
13071307
z_lost = squeeze_2x2_ordered(z_lost)
1308-
z_lost, _ = tf.split(axis=z_lost, num_or_size_splits=2, value=3)
1308+
z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
13091309
z_compressed = z_lost
13101310
z_noisy = z_lost
13111311
for _ in xrange(scale_idx + 1):

0 commit comments

Comments
 (0)