Skip to content

Commit 7f03752

Browse files
Fix memory issues (#6)
Store intermediate arrays automatically. Do not keep heavy data loaded. Fetch it at training time instead.
1 parent e97be86 commit 7f03752

File tree

10 files changed

+537
-244
lines changed

10 files changed

+537
-244
lines changed

core_analysis/__main__.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66

77
from core_analysis.dataset import Dataset
88
from core_analysis.architecture import Model
9-
from core_analysis.utils.visualize import plot_loss, plot_predictions, plot_test_results
10-
from core_analysis.utils.constants import MODEL_FILENAME, LABELS_PATH
9+
from core_analysis.utils.visualize import (
10+
turn_plot_off,
11+
Figure,
12+
Image,
13+
Mask,
14+
Loss,
15+
)
16+
from core_analysis.utils.constants import MODEL_FILENAME, LABELS_PATH, TODAY
1117

1218

1319
# Check the number of available GPUs.
@@ -23,25 +29,59 @@
2329
parser.add_argument("--test", action="store_true")
2430
parser.add_argument("-p", "--plot", action="store_true")
2531
parser.add_argument("-w", "--weights-filename", default=MODEL_FILENAME)
26-
parser.add_argument("-a", "--do_augment", action="store_true")
32+
parser.add_argument("-a", "--do-augment", action="store_true")
33+
parser.add_argument("-e", "--run-eagerly", action="store_true")
2734

2835

2936
def main(args):
30-
model = Model()
37+
if not args.plot:
38+
turn_plot_off()
39+
40+
model = Model(args.weights_filename, args.run_eagerly)
3141
dataset = Dataset(LABELS_PATH)
3242

3343
if args.train:
34-
history = model.train(
35-
dataset.subset("train"), dataset.subset("val"), args.weights_filename
44+
train_subset = dataset.subset("train")
45+
val_subset = dataset.subset("val")
46+
47+
image = next(iter(train_subset.imgs.values()))
48+
Figure(
49+
filename="image_masks",
50+
subplots=[
51+
Image(image, draw_boxes=True),
52+
*(Mask(image.masks[..., i]) for i in range(3)),
53+
],
3654
)
37-
if args.plot:
38-
plot_loss(history)
39-
plot_predictions(model, dataset.subset("val"), begin=600, end=610)
55+
Figure(subplots=[Image(image=image, mask=image.masks[..., 1], draw_boxes=True)])
56+
patches, masks = next(iter(train_subset))
57+
Figure(
58+
filename="tiles",
59+
subplots=[Image(patches[0]), *(Mask(masks[0, ..., i]) for i in range(3))],
60+
)
61+
62+
history = model.train(train_subset, val_subset)
63+
64+
Figure(filename=f"graph_losses_{TODAY}", subplots=[Loss(history)])
4065

4166
if args.test:
4267
results = model.test(dataset.subset("test"))
43-
if args.plot:
44-
plot_test_results(results)
68+
69+
image = next(iter(dataset.subset("test").imgs.values()))
70+
pred = model.predict([image])
71+
Figure(
72+
filename="predictions",
73+
subplots=[
74+
Image(image),
75+
Mask(image.masks[..., 1]),
76+
*(Mask(pred[..., i]) for i in range(3)),
77+
],
78+
)
79+
Figure(
80+
filename="predictions_with_images",
81+
subplots=[
82+
Image(image.without_background(), mask=pred[..., i]) for i in range(3)
83+
],
84+
)
4585

4686

4787
if __name__ == "__main__":

core_analysis/architecture.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
import numpy as np
99
import tensorflow as tf
10-
from keras import callbacks
11-
from keras import backend as K
10+
from keras import callback
1211
import segmentation_models as sm
1312

1413
from core_analysis.postprocess import predict_tiles
@@ -20,12 +19,14 @@
2019
TODAY,
2120
)
2221

22+
tf.sum = tf.reduce_sum
23+
2324

2425
class Model:
2526
BACKBONE = "efficientnetb7"
26-
BATCH_SIZE = 16
27+
EPOCHS = 100
2728

28-
def __init__(self, weights_filename=None):
29+
def __init__(self, weights_filename=None, run_eagerly=False):
2930
if weights_filename is not None:
3031
self.model = tf.keras.models.load_model(
3132
join(MODEL_DIR, weights_filename),
@@ -46,9 +47,10 @@ def __init__(self, weights_filename=None):
4647
optimizer=optimizer,
4748
loss=loss.contrastive_loss,
4849
metrics=["acc"],
50+
run_eagerly=run_eagerly,
4951
)
5052

51-
def train(self, train_iterator, val_iterator):
53+
def train(self, train_dataset, val_dataset):
5254
checkpoint_filename = f"linknet_{self.BACKBONE}_weights_{TODAY}.h5"
5355
checkpointer = callbacks.ModelCheckpoint(
5456
filepath=join(MODEL_DIR, checkpoint_filename),
@@ -62,13 +64,16 @@ def train(self, train_iterator, val_iterator):
6264
min_delta=10e-4,
6365
patience=50,
6466
)
67+
batch_size = train_dataset.BATCH_SIZE
68+
steps_per_epoch = self.N_PATCHES // batch_size
69+
val_steps_per_epoch = steps_per_epoch // 50
6570
history = self.model.fit(
66-
X_train,
67-
Y_train,
68-
batch_size=self.BATCH_SIZE,
69-
validation_data=(X_test, Y_test),
71+
iter(train_dataset),
72+
validation_data=iter(val_dataset),
73+
epochs=self.EPOCHS,
74+
steps_per_epoch=steps_per_epoch,
75+
validation_steps=val_steps_per_epoch,
7076
callbacks=[checkpointer, early_stopping],
71-
epochs=250,
7277
)
7378
return history
7479

@@ -90,7 +95,7 @@ def __init__(self, dim, ths, wmatrix=0, use_weights=False, hold_out=0.1):
9095

9196
def masked_rmse(self, y_true, y_pred):
9297
# Distance between the predictions and simulation probabilities.
93-
squared_diff = K.square(y_true - y_pred)
98+
squared_diff = (y_true - y_pred) ** 2
9499

95100
# Give different weights by class.
96101
if self.use_weights:
@@ -101,31 +106,25 @@ def masked_rmse(self, y_true, y_pred):
101106

102107
# Take some of the training points out at random.
103108
if self.hold_out > 0:
104-
mask *= tf.where(
105-
tf.random.uniform(
106-
shape=(1, *squared_diff.shape[1:]), minval=0.0, maxval=1.0
107-
)
108-
> self.hold_out,
109-
1.0,
110-
0.0,
109+
random = tf.random.uniform(
110+
shape=[1, *DIM[:2], N_CLASSES], minval=0.0, maxval=1.0
111111
)
112+
mask *= tf.where(random > self.hold_out, 1.0, 0.0)
112113

113-
denominator = K.sum(mask) # Number of pixels.
114+
denominator = tf.sum(mask) # Number of pixels.
114115
if self.use_weights:
115-
denominator = K.sum(mask * self.wmatrix)
116+
denominator = tf.sum(mask * self.wmatrix)
116117

117-
# Sum of squared differences at sampled locations,
118-
summ = K.sum(squared_diff * mask)
119-
# Compute error,
120-
rmse = K.sqrt(summ / denominator)
118+
# Compute error.
119+
rmse = tf.sqrt(tf.sum(squared_diff * mask) / denominator)
121120

122121
return rmse
123122

124123
def dice_loss(self, y_true, y_pred):
125124
# Dice coefficient loss.
126125
y_pred = tf.cast(y_pred > 0.5, dtype=tf.float32)
127-
intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
128-
union = K.sum(y_true + y_pred, axis=[1, 2, 3]) - intersection
126+
intersection = tf.sum(y_true * y_pred, axis=[1, 2, 3])
127+
union = tf.sum(y_true + y_pred, axis=[1, 2, 3]) - intersection
129128
dice_loss = 1.0 - (2.0 * intersection + 1.0) / (union + 1.0)
130129

131130
return dice_loss

0 commit comments

Comments
 (0)