7
7
8
8
import numpy as np
9
9
import tensorflow as tf
10
# NOTE: `callbacks` (plural) is required — Model.train() calls
# callbacks.ModelCheckpoint and callbacks.EarlyStopping.
from keras import callbacks
import segmentation_models as sm
13
12
14
13
from core_analysis .postprocess import predict_tiles
20
19
TODAY ,
21
20
)
22
21
22
# HACK(review): monkey-patch an alias onto the tensorflow namespace so loss
# code can call tf.sum like the old keras.backend K.sum. This mutates a
# third-party module globally — prefer calling tf.reduce_sum directly and
# removing this line once all call sites are migrated.
tf.sum = tf.reduce_sum
23
24
24
25
class Model :
25
26
BACKBONE = "efficientnetb7"
26
- BATCH_SIZE = 16
27
+ EPOCHS = 100
27
28
28
- def __init__ (self , weights_filename = None ):
29
+ def __init__ (self , weights_filename = None , run_eagerly = False ):
29
30
if weights_filename is not None :
30
31
self .model = tf .keras .models .load_model (
31
32
join (MODEL_DIR , weights_filename ),
@@ -46,9 +47,10 @@ def __init__(self, weights_filename=None):
46
47
optimizer = optimizer ,
47
48
loss = loss .contrastive_loss ,
48
49
metrics = ["acc" ],
50
+ run_eagerly = run_eagerly ,
49
51
)
50
52
51
- def train (self , train_iterator , val_iterator ):
53
+ def train (self , train_dataset , val_dataset ):
52
54
checkpoint_filename = f"linknet_{ self .BACKBONE } _weights_{ TODAY } .h5"
53
55
checkpointer = callbacks .ModelCheckpoint (
54
56
filepath = join (MODEL_DIR , checkpoint_filename ),
@@ -62,13 +64,16 @@ def train(self, train_iterator, val_iterator):
62
64
min_delta = 10e-4 ,
63
65
patience = 50 ,
64
66
)
67
+ batch_size = train_dataset .BATCH_SIZE
68
+ steps_per_epoch = self .N_PATCHES // batch_size
69
+ val_steps_per_epoch = steps_per_epoch // 50
65
70
history = self .model .fit (
66
- X_train ,
67
- Y_train ,
68
- batch_size = self .BATCH_SIZE ,
69
- validation_data = (X_test , Y_test ),
71
+ iter (train_dataset ),
72
+ validation_data = iter (val_dataset ),
73
+ epochs = self .EPOCHS ,
74
+ steps_per_epoch = steps_per_epoch ,
75
+ validation_steps = val_steps_per_epoch ,
70
76
callbacks = [checkpointer , early_stopping ],
71
- epochs = 250 ,
72
77
)
73
78
return history
74
79
@@ -90,7 +95,7 @@ def __init__(self, dim, ths, wmatrix=0, use_weights=False, hold_out=0.1):
90
95
91
96
def masked_rmse (self , y_true , y_pred ):
92
97
# Distance between the predictions and simulation probabilities.
93
- squared_diff = K . square (y_true - y_pred )
98
+ squared_diff = (y_true - y_pred ) ** 2
94
99
95
100
# Give different weights by class.
96
101
if self .use_weights :
@@ -101,31 +106,25 @@ def masked_rmse(self, y_true, y_pred):
101
106
102
107
# Take some of the training points out at random.
103
108
if self .hold_out > 0 :
104
- mask *= tf .where (
105
- tf .random .uniform (
106
- shape = (1 , * squared_diff .shape [1 :]), minval = 0.0 , maxval = 1.0
107
- )
108
- > self .hold_out ,
109
- 1.0 ,
110
- 0.0 ,
109
+ random = tf .random .uniform (
110
+ shape = [1 , * DIM [:2 ], N_CLASSES ], minval = 0.0 , maxval = 1.0
111
111
)
112
+ mask *= tf .where (random > self .hold_out , 1.0 , 0.0 )
112
113
113
- denominator = K .sum (mask ) # Number of pixels.
114
+ denominator = tf .sum (mask ) # Number of pixels.
114
115
if self .use_weights :
115
- denominator = K .sum (mask * self .wmatrix )
116
+ denominator = tf .sum (mask * self .wmatrix )
116
117
117
- # Sum of squared differences at sampled locations,
118
- summ = K .sum (squared_diff * mask )
119
- # Compute error,
120
- rmse = K .sqrt (summ / denominator )
118
+ # Compute error.
119
+ rmse = tf .sqrt (tf .sum (squared_diff * mask ) / denominator )
121
120
122
121
return rmse
123
122
124
123
def dice_loss (self , y_true , y_pred ):
125
124
# Dice coefficient loss.
126
125
y_pred = tf .cast (y_pred > 0.5 , dtype = tf .float32 )
127
- intersection = K .sum (y_true * y_pred , axis = [1 , 2 , 3 ])
128
- union = K .sum (y_true + y_pred , axis = [1 , 2 , 3 ]) - intersection
126
+ intersection = tf .sum (y_true * y_pred , axis = [1 , 2 , 3 ])
127
+ union = tf .sum (y_true + y_pred , axis = [1 , 2 , 3 ]) - intersection
129
128
dice_loss = 1.0 - (2.0 * intersection + 1.0 ) / (union + 1.0 )
130
129
131
130
return dice_loss
0 commit comments