Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 27, 2024
1 parent 25c20ca commit e02695b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 54 deletions.
57 changes: 46 additions & 11 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from sirf.contrib.partitioner import partitioner

import numpy as np
import array_api_compat.numpy as xp

from scipy.ndimage import gaussian_filter

# import pure python re-implementation of the RDP -> only used to get diagonal of the RDP Hessian!
from rdp import RDP


class MaxIteration(callbacks.Callback):
Expand Down Expand Up @@ -61,7 +67,12 @@ def __init__(
"""

self.subset = 0

# setup the initial image as a slightly smoothed version of the OSEM image
self.x = data.OSEM_image.clone()
# sig = 5.0 / (2.35 * np.array(data.OSEM_image.spacing))
# self.x.fill(gaussian_filter(self.x.as_array(), sig))

self._verbose = verbose

self._num_subsets = num_subsets
Expand Down Expand Up @@ -101,39 +112,63 @@ def __init__(
tmp = 1.0 * (self._adjoint_ones.as_array() > 0)
self._fov_mask.fill(tmp)

# add a small number to avoid NaN in division
self._adjoint_ones += self._adjoint_ones.max() * 1e-6

# calculate the initial preconditioner based on the initial image
self._precond = self.calc_data_fidelity_precond(self.x)
# add a small number in the adjoint ones outside the FOV to avoid NaN in division
self._adjoint_ones += 1e-6 * (-self._fov_mask + 1.0)

# initialize list / ImageData for all subset gradients and sum of gradients
self._summed_subset_gradients = self.x.get_uniform_copy(0)
self._subset_gradients = []

self._precond_update_epochs: list[int] = [2]

if complete_gradient_epochs is None:
self._complete_gradient_epochs: list[int] = [x for x in range(0, 100, 2)]
else:
self._complete_gradient_epochs = complete_gradient_epochs

if precond_update_epochs is None:
self._precond_update_epochs: list[int] = [2]
self._precond_update_epochs: list[int] = [1]
else:
self._precond_update_epochs = precond_update_epochs

# setup python re-implementation of the RDP
# only used to get the diagonal of the RDP Hessian for preconditioning!
# (diag of RDP Hessian is not available in SIRF yet)
self._python_prior = RDP(
data.OSEM_image.shape,
xp,
"cpu",
xp.asarray(data.OSEM_image.spacing, device="cpu"),
eps=data.prior.get_epsilon(),
gamma=data.prior.get_gamma(),
)
self._python_prior.kappa = data.kappa.as_array()
self._python_prior.scale = data.prior.get_penalisation_factor()

# calculate the initial preconditioner based on the initial image
self._precond = self.calc_precond(self.x)

super().__init__(update_objective_interval=update_objective_interval, **kwargs)
self.configured = True # required by Algorithm

@property
def epoch(self):
return self._update // self._num_subsets

def calc_data_fidelity_precond(
# def calc_data_fidelity_precond(
# self, x: STIR.ImageData, delta_rel: float = 1e-6
# ) -> STIR.ImageData:
# return (x + delta_rel * x.max()) / self._adjoint_ones

def calc_precond(
self, x: STIR.ImageData, delta_rel: float = 1e-6
) -> STIR.ImageData:
return (x + delta_rel * x.max()) / self._adjoint_ones

delta = delta_rel * x.max()

prior_diag_hess = 0.0 * x
tmp = self._python_prior.diag_hessian(x.as_array())
prior_diag_hess.fill(1 / (tmp + delta_rel))

return (x + delta) / (self._adjoint_ones + prior_diag_hess * x)

def update_all_subset_gradients(self) -> None:

Expand All @@ -157,7 +192,7 @@ def update(self):
if update_precond:
if self._verbose:
print(f" {self._update}, updating preconditioner")
self._precond = self.calc_data_fidelity_precond(self.x)
self._precond = self.calc_precond(self.x)

if update_all_subset_gradients:
if self._verbose:
Expand Down
58 changes: 29 additions & 29 deletions test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,39 +481,39 @@ def test_petric(
precond_update_epochs=precond_update_epochs,
)
else:
for step_size in [2.0, 0.3]:
# # data set 0 "mMR_NEMA_IQ" - num views 252
# for num_subsets in [2, 9, 28, 42, 63]:
# test_petric(
# step_size,
# 0,
# 200,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )
#
# # data set 1 "neuro LF" - num views 128
# for num_subsets in [2, 8, 32, 64]:
# test_petric(
# step_size,
# 1,
# 200,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )

# data set 2 "vision" - num views 50
for num_subsets in [5, 10, 25]:
for step_size in [0.3, 0.5, 1.0]:
# data set 0 "mMR_NEMA_IQ" - num views 252
for num_subsets in [9]:
test_petric(
step_size,
2,
100,
0,
200,
num_subsets,
int(args["--metric_period"]),
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)

## data set 1 "neuro LF" - num views 128
# for num_subsets in [8, 32, 64]:
# test_petric(
# step_size,
# 1,
# 200,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )

## data set 2 "vision" - num views 50
# for num_subsets in [5, 10, 25]:
# test_petric(
# step_size,
# 2,
# 100,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )
29 changes: 15 additions & 14 deletions test_rdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,28 +388,29 @@ def get_image(fname):
# x_arr = np.pad(x_arr[1:-1, 1:-1, 1:-1], 1)
x.fill(x_arr)

prior = data.prior

eps = prior.get_epsilon()
beta = prior.get_penalisation_factor()
gamma = prior.get_gamma()

# the RDP implementation can use numpy or cupy. for the latter you have to adjust the device
prior2 = RDP(x.shape, xp, "cpu", xp.asarray(x.spacing), eps=eps, gamma=gamma)
prior2.kappa = data.kappa.as_array()
prior2.scale = beta
python_prior = RDP(
x.shape,
xp,
"cpu",
xp.asarray(x.spacing, device="cpu"),
eps=data.prior.get_epsilon(),
gamma=data.prior.get_gamma(),
)
python_prior.kappa = data.kappa.as_array()
python_prior.scale = data.prior.get_penalisation_factor()

v1 = prior(x)
v2 = prior2(x.as_array())
v1 = data.prior(x)
v2 = python_prior(x.as_array())

g1 = prior.gradient(x).as_array()
g2 = prior2.gradient(x.as_array())
g1 = data.prior.gradient(x).as_array()
g2 = python_prior.gradient(x.as_array())

# assert np.isclose(v1, v2)
# assert np.all(np.isclose(g1, g2, atol=np.abs(g1).max() / 1e6))

# get the diagonal of the Hessian of the RDP (as numpy array)
h = prior2.diag_hessian(x.as_array())
h = python_prior.diag_hessian(x.as_array())

# %%

Expand Down

0 comments on commit e02695b

Please sign in to comment.