-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdice_loss.py
35 lines (24 loc) · 1.03 KB
/
dice_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# model.py
import torch
import unigradicon

# NOTE(review): `icon`, `SimilarityBase` and `config` are used below but are not
# imported in this file. They presumably come from the `icon_registration`
# package (e.g. `import icon_registration as icon`,
# `from icon_registration.losses import SimilarityBase`,
# `from icon_registration import config`) — confirm and add the imports.
# Input volume shape. Under the layout described below this reads as
# (batch, channels, H, W, D) — presumably batch 1, a single intensity channel,
# and a 175^3 voxel grid; confirm against the model that consumes it.
input_shape = [1, 1] + [175] * 3

# Image format: the one-hot segmentations are concatenated with the
# preprocessed intensity image along the channel dimension, giving tensors of
# shape B x (1 + num_segs) x H x W x D. Channel 0 is the intensity image and
# channels 1: are the segmentations.
class SegmentationSSD(SimilarityBase):
    """Sum-of-squared-differences similarity over the segmentation channels only.

    Inputs use the layout ``B x (1 + num_segs) x H x W x D``. Channel 0 is the
    intensity image and channels ``1:`` are one-hot segmentations. The intensity
    channel is excluded, so only segmentation disagreement contributes to the
    loss.
    """

    def __init__(self):
        super().__init__(isInterpolated=False)

    def __call__(self, image_A, image_B):
        """Return the mean squared difference of the segmentation channels.

        Args:
            image_A: tensor laid out as ``B x (1 + num_segs) x ...``.
            image_B: tensor with exactly the same shape as ``image_A``.

        Returns:
            A scalar tensor: the mean of ``(A_seg - B_seg) ** 2`` taken over the
            batch, the segmentation channels and all spatial dimensions.

        Raises:
            AssertionError: if ``image_A`` and ``image_B`` differ in shape.
        """
        assert image_A.shape == image_B.shape, "The shape of image_A and image_B sould be the same."
        # Slice the channel axis (dim 1), not the batch axis. ``image[1:]``
        # would drop the first batch element instead of the intensity channel.
        # For batch size 1 it leaves an empty tensor, whose mean is NaN.
        return torch.mean((image_A[:, 1:] - image_B[:, 1:]) ** 2)
class StripSegmentations(icon.RegistrationModule):
    """Wrap a registration network so it only receives the intensity channel.

    Inputs follow the ``B x (1 + num_segs) x H x W x D`` layout. Only channel 0
    of each input is forwarded to the wrapped ``net``. The segmentation
    channels therefore never reach the network itself.
    """

    def __init__(self, net):
        # Initialise the base class before assigning attributes.
        # NOTE(review): icon.RegistrationModule is presumably a torch.nn.Module
        # subclass. torch.nn.Module raises AttributeError when a submodule is
        # assigned before Module.__init__ has run, so the original code could
        # not construct this wrapper.
        super().__init__()
        self.net = net

    def forward(self, moving, fixed):
        """Run the wrapped network on channel 0 of ``moving`` and ``fixed``.

        ``[:, :1]`` keeps the channel axis, so ``net`` still sees tensors of
        shape ``B x 1 x ...``.
        """
        return self.net(moving[:, :1], fixed[:, :1])
def make_network(device=None):
    """Build a multiGradICON model that uses ``SegmentationSSD`` as its loss.

    The model's registration network is wrapped in ``StripSegmentations``, so
    the network only sees the intensity channel. ``SegmentationSSD`` then
    compares the segmentation channels.

    Args:
        device: device to move the model to. When ``None`` (the default, which
            matches the original zero-argument call), ``config.device`` is used.

    Returns:
        The configured multigradicon model, moved to ``device``.
    """
    multigradicon = unigradicon.get_multigradicon(loss_fn=SegmentationSSD())
    multigradicon.regis_net = StripSegmentations(multigradicon.regis_net)
    # Re-assign the identity map after swapping in the wrapper. This is
    # presumably so the newly inserted module also receives one; the call
    # reuses the model's existing identity-map shape.
    multigradicon.assign_identity_map(multigradicon.identity_map.shape)
    if device is None:
        # NOTE(review): `config` is not imported in this file. It is presumably
        # `icon_registration.config` — confirm and import it, or pass
        # `device=` explicitly.
        device = config.device
    multigradicon.to(device)
    return multigradicon