Skip to content

Commit 94d1ee3

Browse files
committed
clean up SVRG code (remove direct python calls)
1 parent 9d296fd commit 94d1ee3

File tree

2 files changed

+86
-89
lines changed

2 files changed

+86
-89
lines changed

main_SVRG.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
verbose: bool = False,
5959
complete_gradient_epochs: None | list[int] = None,
6060
precond_update_epochs: None | list[int] = None,
61+
precond_hessian_factor: float = 16.0,
6162
**kwargs,
6263
):
6364
"""
@@ -79,6 +80,7 @@ def __init__(
7980
self._update = 0
8081
self._step_size = initial_step_size
8182
self._subset_number_list = []
83+
self._precond_hessian_factor = precond_hessian_factor
8284

8385
self._data_sub, self._acq_models, self._obj_funs = partitioner.data_partition(
8486
data.acquired_data,
@@ -98,15 +100,12 @@ def __init__(
98100
for f in self._obj_funs: # add prior evenly to every objective function
99101
f.set_prior(data.prior)
100102

101-
self._subset_adjoint_ones = []
103+
self._adjoint_ones = self.x.get_uniform_copy(0)
102104

103105
for i in range(num_subsets):
104106
if self._verbose:
105107
print(f"Calculating subset {i} sensitivity")
106-
subset_adjoint_ones = self._obj_funs[i].get_subset_sensitivity(0)
107-
self._subset_adjoint_ones.append(subset_adjoint_ones)
108-
109-
self._adjoint_ones = np.sum(self._subset_adjoint_ones)
108+
self._adjoint_ones += self._obj_funs[i].get_subset_sensitivity(0)
110109

111110
self._fov_mask = self.x.get_uniform_copy(0)
112111
tmp = 1.0 * (self._adjoint_ones.as_array() > 0)
@@ -143,6 +142,10 @@ def __init__(
143142
self._python_prior.kappa = data.kappa.as_array()
144143
self._python_prior.scale = data.prior.get_penalisation_factor()
145144

145+
self._precond_filter = STIR.SeparableGaussianImageFilter()
146+
self._precond_filter.set_fwhms([5.0, 5.0, 5.0])
147+
self._precond_filter.set_up(data.OSEM_image)
148+
146149
# calculate the initial preconditioner based on the initial image
147150
self._precond = self.calc_precond(self.x)
148151

@@ -157,22 +160,23 @@ def calc_precond(
157160
self,
158161
x: STIR.ImageData,
159162
delta_rel: float = 1e-6,
160-
prior_diag_factor: float = 16.0,
161163
) -> STIR.ImageData:
162164

163165
# generate a smoothed version of the input image
164166
# to avoid high values, especially in first and last slices
165-
xx = x.get_uniform_copy(0)
166-
sig = 5.0 / (2.35 * np.array(xx.spacing))
167-
xx.fill(gaussian_filter(x.as_array(), sig))
168-
169-
delta = delta_rel * xx.max()
170-
prior_diag_hess = xx.get_uniform_copy(0)
171-
172-
prior_diag_hess.fill(self._python_prior.diag_hessian(xx.as_array()))
173-
174-
precond = (xx + delta) / (
175-
self._adjoint_ones + prior_diag_factor * prior_diag_hess * xx
167+
x_sm = self._precond_filter.process(x)
168+
delta = delta_rel * x_sm.max()
169+
170+
prior_diag_hess = x_sm.get_uniform_copy(0)
171+
prior_diag_hess.fill(self._python_prior.diag_hessian(x_sm.as_array()))
172+
173+
precond = (
174+
self._fov_mask
175+
* (x_sm + delta)
176+
/ (
177+
self._adjoint_ones
178+
+ self._precond_hessian_factor * prior_diag_hess * x_sm
179+
)
176180
)
177181

178182
return precond
@@ -226,12 +230,10 @@ def update(self):
226230
)
227231

228232
### Objective has to be maximized -> "+" for gradient ascent
229-
self.x = self.x + self._step_size * self._precond * self._fov_mask * grad
233+
self.x = self.x + self._step_size * self._precond * grad
230234

231235
# enforce non-negative constraint
232-
tmp = self.x.as_array()
233-
np.clip(tmp, 0, None, out=tmp)
234-
self.x.fill(tmp)
236+
self.x.maximum(0, out=self.x)
235237

236238
self._update += 1
237239

test_petric.py

+63-68
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def test_petric(
335335
metric_period: int,
336336
complete_gradient_epochs: None | list[int] = None,
337337
precond_update_epochs: None | list[int] = None,
338+
precond_hessian_factor: float = 16.0,
338339
):
339340

340341
# get arguments and values such that we can dump them in the outdir
@@ -346,39 +347,25 @@ def test_petric(
346347
current_datetime = datetime.now()
347348
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
348349

350+
sdir_name = f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}_phf_{precond_hessian_factor}"
351+
349352
if ds == 0:
350353
srcdir = SRCDIR / "Siemens_mMR_NEMA_IQ"
351-
outdir = (
352-
OUTDIR
353-
/ "mMR_NEMA"
354-
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
355-
)
354+
outdir = OUTDIR / "mMR_NEMA" / sdir_name
356355
metrics = [
357356
MetricsWithTimeout(outdir=outdir, transverse_slice=72, coronal_slice=109)
358357
]
359358
elif ds == 1:
360359
srcdir = SRCDIR / "NeuroLF_Hoffman_Dataset"
361-
outdir = (
362-
OUTDIR
363-
/ "NeuroLF_Hoffman"
364-
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
365-
)
360+
outdir = OUTDIR / "NeuroLF_Hoffman" / sdir_name
366361
metrics = [MetricsWithTimeout(outdir=outdir, transverse_slice=72)]
367362
elif ds == 2:
368363
srcdir = SRCDIR / "Siemens_Vision600_thorax"
369-
outdir = (
370-
OUTDIR
371-
/ "Vision600_thorax"
372-
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
373-
)
364+
outdir = OUTDIR / "Vision600_thorax" / sdir_name
374365
metrics = [MetricsWithTimeout(outdir=outdir)]
375366
elif ds == 3:
376367
srcdir = SRCDIR / "Siemens_mMR_ACR"
377-
outdir = (
378-
OUTDIR
379-
/ "Siemens_mMR_ACR"
380-
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
381-
)
368+
outdir = OUTDIR / "Siemens_mMR_ACR" / sdir_name
382369
metrics = [MetricsWithTimeout(outdir=outdir)]
383370
else:
384371
raise ValueError(f"Unknown data set {ds}")
@@ -446,6 +433,7 @@ def test_petric(
446433
num_subsets=num_subsets,
447434
complete_gradient_epochs=complete_gradient_epochs,
448435
precond_update_epochs=precond_update_epochs,
436+
precond_hessian_factor=precond_hessian_factor,
449437
)
450438
algo.run(num_iter, callbacks=metrics + submission_callbacks)
451439

@@ -489,51 +477,58 @@ def test_petric(
489477
precond_update_epochs=precond_update_epochs,
490478
)
491479
else:
492-
for step_size in [1.0, 1.5]:
493-
# data set 0 "mMR_NEMA_IQ" - num views 252
494-
for num_subsets in [28]:
495-
test_petric(
496-
step_size=step_size,
497-
ds=0,
498-
num_iter=300,
499-
num_subsets=num_subsets,
500-
metric_period=num_subsets,
501-
complete_gradient_epochs=complete_gradient_epochs,
502-
precond_update_epochs=precond_update_epochs,
503-
)
504-
505-
# data set 1 "neuro LF" - num views 128
506-
for num_subsets in [16]:
507-
test_petric(
508-
step_size=step_size,
509-
ds=1,
510-
num_iter=300,
511-
num_subsets=num_subsets,
512-
metric_period=num_subsets,
513-
complete_gradient_epochs=complete_gradient_epochs,
514-
precond_update_epochs=precond_update_epochs,
515-
)
516-
517-
# data set 2 "vision" - num views 50
518-
for num_subsets in [25]:
519-
test_petric(
520-
step_size=step_size,
521-
ds=2,
522-
num_iter=200,
523-
num_subsets=num_subsets,
524-
metric_period=num_subsets,
525-
complete_gradient_epochs=complete_gradient_epochs,
526-
precond_update_epochs=precond_update_epochs,
527-
)
528-
529-
# data set 4 "mMR_ACR" - num views 252
530-
for num_subsets in [28]:
531-
test_petric(
532-
step_size=1.0,
533-
ds=3,
534-
num_iter=3 * 28 + 1,
535-
num_subsets=num_subsets,
536-
metric_period=num_subsets,
537-
complete_gradient_epochs=complete_gradient_epochs,
538-
precond_update_epochs=precond_update_epochs,
539-
)
480+
# for phf in [8.0, 32.0, 4.0]:
481+
for phf in [16]:
482+
for step_size in [1.0]:
483+
# data set 0 "mMR_NEMA_IQ" - num views 252
484+
for num_subsets in [28]:
485+
test_petric(
486+
step_size=step_size,
487+
ds=0,
488+
num_iter=300,
489+
num_subsets=num_subsets,
490+
metric_period=num_subsets,
491+
complete_gradient_epochs=complete_gradient_epochs,
492+
precond_update_epochs=precond_update_epochs,
493+
precond_hessian_factor=phf,
494+
)
495+
496+
# # data set 1 "neuro LF" - num views 128
497+
# for num_subsets in [32]:
498+
# test_petric(
499+
# step_size=step_size,
500+
# ds=1,
501+
# num_iter=300,
502+
# num_subsets=num_subsets,
503+
# metric_period=num_subsets,
504+
# complete_gradient_epochs=complete_gradient_epochs,
505+
# precond_update_epochs=precond_update_epochs,
506+
# precond_hessian_factor=phf,
507+
# )
508+
#
509+
# # data set 2 "vision" - num views 50
510+
# for num_subsets in [25]:
511+
# test_petric(
512+
# step_size=step_size,
513+
# ds=2,
514+
# num_iter=300,
515+
# num_subsets=num_subsets,
516+
# metric_period=num_subsets,
517+
# complete_gradient_epochs=complete_gradient_epochs,
518+
# precond_update_epochs=precond_update_epochs,
519+
# precond_hessian_factor=phf,
520+
# )
521+
#
522+
# # data set 4 "mMR_ACR" - num views 252
523+
# for num_subsets in [28]:
524+
# test_petric(
525+
# step_size=step_size,
526+
# ds=3,
527+
# num_iter=300,
528+
# num_subsets=num_subsets,
529+
# metric_period=num_subsets,
530+
# complete_gradient_epochs=complete_gradient_epochs,
531+
# precond_update_epochs=precond_update_epochs,
532+
# precond_hessian_factor=phf,
533+
# )
534+
#

0 commit comments

Comments
 (0)