forked from apple1986/SAMatch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathval_3D.py
More file actions
107 lines (95 loc) · 3.98 KB
/
Copy pathval_3D.py
File metadata and controls
107 lines (95 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math
from glob import glob
import h5py
# import nibabel as nib
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn.functional as F
from medpy import metric
from tqdm import tqdm
def _symmetric_pad(size, target):
    """Return (before, after) zero-padding that grows ``size`` up to ``target``.

    Returns (0, 0) when ``size`` already reaches ``target``; an odd amount puts
    the extra element on the ``after`` side.
    """
    total = max(target - size, 0)
    return total // 2, total - total // 2


def _window_starts(length, patch, stride):
    """Return start offsets of sliding windows of width ``patch`` along one axis.

    Windows advance by ``stride``; the last one is clamped so it ends exactly at
    ``length``, so every position along the axis is covered at least once.
    """
    count = math.ceil((length - patch) / stride) + 1
    return [min(stride * i, length - patch) for i in range(count)]


def _infer_device(net):
    """Return the device of ``net``'s first parameter, defaulting to CUDA.

    Falls back to ``torch.device("cuda")`` (the historical ``.cuda()`` target)
    when ``net`` has no ``parameters`` method or no parameters at all.
    """
    try:
        return next(net.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cuda")


def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1,
                     device=None):
    """Predict a label map for one 3-D volume using sliding-window inference.

    The volume is zero-padded symmetrically so every axis is at least one patch
    long. Overlapping patches of size ``patch_size`` are fed to ``net``, and the
    softmax probabilities are averaged over overlaps. The per-voxel argmax is
    then cropped back to the original extent.

    Args:
        net: Callable taking a float32 tensor of shape (1, 1, *patch_size) and
            returning logits that broadcast into (num_classes, *patch_size)
            after indexing ``[0]``; softmax is applied over dim 1.
        image: 3-D numpy array of shape (w, h, d).
        stride_xy: Window stride along the first two axes.
        stride_z: Window stride along the third axis.
        patch_size: Three ints, the window size per axis.
        num_classes: Number of output channels accumulated per voxel.
        device: Device the patches are moved to. When None, it is taken from
            ``net``'s parameters, falling back to CUDA.

    Returns:
        Integer array of shape (w, h, d) with the predicted class per voxel.
    """
    w, h, d = image.shape
    # Pad with zeros so each axis is at least one patch long.
    pads = [_symmetric_pad(size, target)
            for size, target in zip((w, h, d), patch_size)]
    if any(before or after for before, after in pads):
        image = np.pad(image, pads, mode='constant', constant_values=0)
    if device is None:
        device = _infer_device(net)

    ww, hh, dd = image.shape
    px, py, pz = patch_size
    score_map = np.zeros((num_classes,) + image.shape, dtype=np.float32)
    cnt = np.zeros(image.shape, dtype=np.float32)
    with torch.no_grad():
        for xs in _window_starts(ww, px, stride_xy):
            for ys in _window_starts(hh, py, stride_xy):
                for zs in _window_starts(dd, pz, stride_z):
                    window = (slice(xs, xs + px), slice(ys, ys + py),
                              slice(zs, zs + pz))
                    patch = image[window][np.newaxis, np.newaxis].astype(np.float32)
                    logits = net(torch.from_numpy(patch).to(device))
                    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
                    score_map[(slice(None),) + window] += probs
                    cnt[window] += 1

    # Average the overlapping windows; the clamped last window guarantees cnt > 0.
    label_map = np.argmax(score_map / cnt[np.newaxis], axis=0)
    # Remove the padding (a full-extent slice when nothing was padded).
    (wl, _), (hl, _), (dl, _) = pads
    return label_map[wl:wl + w, hl:hl + h, dl:dl + d]


# The ``test_`` prefix is historical; stop pytest from collecting this helper as a test.
test_single_case.__test__ = False
def cal_metric(gt, pred):
    """Return ``np.array([dice, hd95])`` comparing binary masks ``pred`` and ``gt``.

    Both metrics come from ``medpy.metric.binary``. If either mask has no
    positive voxels, ``np.zeros(2)`` is returned instead.

    NOTE(review): this means a structure the prediction misses entirely
    scores hd95 = 0, which looks optimistic when averaged.
    """
    if not (pred.sum() > 0 and gt.sum() > 0):
        return np.zeros(2)
    return np.array([metric.binary.dc(pred, gt), metric.binary.hd95(pred, gt)])
def test_all_case(net, base_dir, test_list="full_test.list", num_classes=4, patch_size=(48, 160, 160), stride_xy=32, stride_z=24):
    """Evaluate ``net`` on every case listed in ``base_dir/test_list``.

    Each non-blank line of the list file names a case by its first
    comma-separated field. That case is loaded from
    ``base_dir/data/<case>.h5`` (datasets ``'image'`` and ``'label'``) and
    segmented with ``test_single_case``. Dice and HD95 are then computed
    via ``cal_metric`` for every foreground class 1..num_classes-1.

    Returns:
        Array of shape (num_classes - 1, 2): per-class [dice, hd95], averaged
        over all listed cases.

    Raises:
        ValueError: If the list file names no cases.
    """
    with open(base_dir + '/{}'.format(test_list), 'r') as f:
        # strip() also drops '\r' from CRLF files; blank lines are skipped below.
        case_ids = [line.split(",")[0].strip() for line in f]
    image_list = [base_dir + "/data/{}.h5".format(case_id)
                  for case_id in case_ids if case_id]
    if not image_list:
        raise ValueError("no test cases listed in {}".format(test_list))

    total_metric = np.zeros((num_classes - 1, 2))
    print("Validation begin")
    for image_path in tqdm(image_list):
        # Close each HDF5 file as soon as its arrays are read.
        with h5py.File(image_path, 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        prediction = test_single_case(
            net, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        for i in range(1, num_classes):
            total_metric[i - 1, :] += cal_metric(label == i, prediction == i)
    print("Validation end")
    return total_metric / len(image_list)


# The ``test_`` prefix is historical; stop pytest from collecting this helper as a test.
test_all_case.__test__ = False