Skip to content

Commit 833f2b6

Browse files
committed
test1
0 parents  commit 833f2b6

File tree

8 files changed

+568
-0
lines changed

8 files changed

+568
-0
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
echo # voxelGAN
2+
# voxelGAN

config.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Global configuration for the voxelGAN training/prediction scripts.

All values here are star-imported by model/train/predict, so the names
below are public interface and must not be renamed.
"""
import os

from keras import backend as K

# Every model in this project assumes channels-last tensors (D, H, W, C).
K.set_image_data_format('channels_last')

# general
data_dir = 'data/train/**/*t1.nii.gz'       # glob for training volumes
predict_dir = 'data/predict/**/*t1.nii.gz'  # glob for volumes to predict on
checkpoint_dir = 'checkpoints/'

# Create the checkpoint directory up front. exist_ok avoids the
# check-then-create race of the previous `if not os.path.exists(...)` guard.
directory = os.path.dirname(checkpoint_dir)
os.makedirs(directory, exist_ok=True)

# bdsscgan
# NOTE(review): `input`/`output` shadow Python builtins, but they are part of
# this module's star-imported public surface, so they keep their names.
input = 4          # number of input channels per voxel patch
output = 4         # number of output channels/classes (softmax)
size = 32          # cubic patch edge length
epochs = 50
kernel_depth = 32  # base number of convolution filters

checkpoint_gen_name = checkpoint_dir + 'gen.hdf5'
checkpoint_disc_name = checkpoint_dir + 'disc.hdf5'

model.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import tensorflow as tf
2+
#keras = tf.keras
3+
from keras.models import Model
4+
from keras.layers import Bidirectional, Input, Concatenate, Cropping3D, Dense, Flatten, TimeDistributed, ConvLSTM2D, LeakyReLU
5+
from keras.layers.core import Dropout, Activation, Reshape
6+
from keras.layers.convolutional import Conv3D, MaxPooling3D, UpSampling3D, Cropping3D
7+
from keras.layers.normalization import BatchNormalization
8+
from keras.layers.merge import concatenate
9+
import keras.backend as K
10+
11+
12+
def conv_layer(layer, depth, size):
    """Apply Conv3D -> LeakyReLU(0.2) -> BatchNormalization to `layer`.

    `padding='same'` keeps the spatial dimensions unchanged; only the
    channel count becomes `depth`.
    """
    convolved = Conv3D(depth, size, padding='same')(layer)
    activated = LeakyReLU(0.2)(convolved)
    return BatchNormalization()(activated)
16+
17+
def Generator(input_shape, output, kernel_depth, kernel_size=3):
    """Build a 3-level U-Net-style 3D generator.

    Encodes the input through two max-pool stages, decodes with upsampling
    plus skip connections, concatenates the raw input one last time, and
    classifies each voxel with a 1x1x1 softmax over `output` channels.
    """
    inputs = Input(shape=input_shape)  # e.g. 32x32x32x4

    # Encoder: two downsampling stages.
    enc_full = conv_layer(inputs, kernel_depth, kernel_size)
    down_half = MaxPooling3D()(enc_full)

    enc_half = conv_layer(down_half, 2 * kernel_depth, kernel_size)
    down_quarter = MaxPooling3D()(enc_half)

    # Bottleneck at 1/4 resolution.
    bottleneck = conv_layer(down_quarter, 4 * kernel_depth, kernel_size)

    # Decoder: upsample and fuse with the matching encoder features.
    dec_half = conv_layer(concatenate([UpSampling3D()(bottleneck), enc_half]),
                          2 * kernel_depth, kernel_size)
    dec_full = conv_layer(concatenate([UpSampling3D()(dec_half), enc_full]),
                          kernel_depth, kernel_size)

    # Final skip straight from the input, then the voxel-wise classifier.
    merged = concatenate([dec_full, inputs])
    probabilities = Conv3D(output, 1, activation='softmax')(merged)

    return Model(inputs, probabilities, name="Generator")
40+
41+
def Discriminator(input_shape, generator_shape, kernel_depth, kernel_size=5):
    """Build a 3D conv classifier over (real volume, candidate output) pairs.

    The two inputs are concatenated along channels, reduced through three
    conv+pool stages, and flattened into a 2-way softmax (the training
    script supplies the one-hot real/fake targets).
    """
    real_input = Input(shape=input_shape)
    generator_input = Input(shape=generator_shape)
    merged = Concatenate()([real_input, generator_input])

    # Three conv+pool stages halve each spatial axis: 32 -> 16 -> 8 -> 4.
    features = conv_layer(merged, kernel_depth, kernel_size)
    features = MaxPooling3D()(features)
    features = conv_layer(features, 2 * kernel_depth, kernel_size)
    features = MaxPooling3D()(features)
    features = conv_layer(features, 4 * kernel_depth, kernel_size)
    features = MaxPooling3D()(features)

    verdict = Dense(2, activation="softmax")(Flatten()(features))

    return Model([real_input, generator_input], verdict, name="Discriminator")
60+
61+
def Combine(gen, disc, input_shape, new_sequence):
    """Chain the generator into the discriminator for adversarial training.

    Returns a model mapping one input volume to
    [generator output, discriminator verdict on (input, generated)].
    The caller is responsible for freezing `disc` before compiling.
    """
    source = Input(shape=input_shape)
    generated_image = gen(source)

    # Reshape the generator output to the shape the discriminator expects.
    disc_view = Reshape(new_sequence)(generated_image)
    verdict = disc([source, disc_view])

    return Model(inputs=[source],
                 outputs=[generated_image, verdict],
                 name="Combined")

predict.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Run the trained generator over every volume matched by `predict_dir`."""
import keras  # explicit: `keras` itself was previously never imported, yet
              # keras.utils.Progbar is used below (model.py does not bind it)
from keras.optimizers import Adadelta

from model import *
from utils import *
from config import *

opt = Adadelta()

# Rebuild the generator architecture and load its trained weights.
# The loss/optimizer are unused at inference time.
gen = Generator((size, size, size, input), output, kernel_depth)
gen.compile(loss='mae', optimizer=opt)
gen.load_weights(checkpoint_gen_name)

# List sequences
sequences = prepare_data(predict_dir)
print(sequences)

progbar = keras.utils.Progbar(len(sequences))

for s in range(len(sequences)):
    progbar.add(1)
    sequence = sequences[s]
    # load2 appears to return the input patches plus index info used to
    # reassemble/store the prediction -- TODO confirm against utils.load2
    x, idx = load2(sequence, size)
    y = []

    # Predict one patch at a time and collect the outputs.
    for i in range(len(x)):
        fake = gen.predict(x[i])
        print(fake.shape)
        y.append(fake)

    store(sequence, y, idx)
    #save_image(x[i] / 2 + 0.5, y[i], re_shape(generated_y), prediction_dir + "test{}.png".format(s))

train.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Adversarial training loop for the voxelGAN generator/discriminator pair."""
import os
import random

import keras  # explicit: keras.utils.Progbar is used below but `keras` was
              # previously only available if a star-imported module bound it
import numpy as np
from keras.callbacks import TensorBoard
from keras.optimizers import Adadelta

from model import *
from utils import *
from config import *
# NOTE(review): os/random/np were previously only in scope via the star
# imports above; importing them explicitly removes that hidden dependency.

# Create optimizers
opt_dcgan = Adadelta()
opt_discriminator = Adadelta()

# Generator alone, checkpointed with an L1 (mae) reconstruction loss.
gen = Generator((size, size, size, input), output, kernel_depth)
gen.compile(loss='mae', optimizer=opt_discriminator)

disc = Discriminator((size, size, size, input), (size, size, size, output), kernel_depth)

# Freeze the discriminator while compiling the combined model, so the
# adversarial pass only updates the generator's weights...
disc.trainable = False

combined = Combine(gen, disc, (size, size, size, input), (size, size, size, output))
loss = [selective_crossentropy, 'binary_crossentropy']
loss_weights = [10, 1]  # voxel-wise loss dominates the adversarial term
combined.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

# ...then unfreeze and compile the discriminator for its own updates.
disc.trainable = True
disc.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

# Resume from existing checkpoints when present.
if os.path.isfile(checkpoint_gen_name):
    gen.load_weights(checkpoint_gen_name)
if os.path.isfile(checkpoint_disc_name):
    disc.load_weights(checkpoint_disc_name)

# List sequences
sequences = prepare_data(data_dir)
print(sequences)

# One-hot discriminator targets: real is [0, 1], fake is [1, 0].
real_y = np.reshape(np.array([0, 1]), (1, 2))
fake_y = np.reshape(np.array([1, 0]), (1, 2))

#log = open("train.log",'w')

tensorlog = TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=1, write_graph=True, write_grads=True, write_images=True)
tensorlog.set_model(gen)

for e in range(epochs):
    print("Epoch {}".format(e))
    random.shuffle(sequences)

    # select a fraction: first 90% train, the rest reserved for validation
    train_offset = int(len(sequences) * 0.9)
    train_sequence = sequences[:train_offset]

    progbar = keras.utils.Progbar(len(train_sequence))

    for s in range(len(train_sequence)):
        progbar.add(1)
        sequence = train_sequence[s]
        x, y, idx = load(sequence, size)

        for i in range(len(x)):
            # train disc on real
            disc.train_on_batch([x[i], y[i]], real_y)

            # gen fake
            fake = gen.predict(x[i])

            # train disc on fake
            disc.train_on_batch([x[i], fake], fake_y)

            # train combined (generator step: keep the discriminator frozen)
            disc.trainable = False
            combined.train_on_batch(x[i], [y[i], real_y])
            disc.trainable = True

        #log.write(str(e) + ", " + str(s) + ", " + str(dr_loss) + ", " + str(df_loss) + ", " + str(g_loss[0]) + ", " + str(g_loss[1]) + ", " + str(opt_dcgan.get_config()["lr"]) + "\n")

        # output random result
        #val_sequence = sequences[train_offset:]
        #generated_y = gen.predict(x[random_index])
        #save_image(strip(x[random_index]) / 2 + 0.5, y[random_index], re_shape(generated_y), "validation/e{}_{}.png".format(e, s))

    # save weights
    gen.save_weights(checkpoint_gen_name, overwrite=True)
    disc.save_weights(checkpoint_disc_name, overwrite=True)

    tensorlog.on_epoch_end(e)

tensorlog.on_train_end()

0 commit comments

Comments
 (0)