Skip to content

Commit 63d050e

Browse files
committed
ise new RDP also in simulations
1 parent b3179e3 commit 63d050e

File tree

2 files changed

+170
-152
lines changed

2 files changed

+170
-152
lines changed

simulations/stochastic_grad.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818

1919
from utils import (
2020
SubsetNegPoissonLogLWithPrior,
21-
RDP,
2221
split_fwd_model,
2322
OSEM,
2423
SGD,
2524
SVRG,
2625
rdp_preconditioner,
27-
neighbor_product,
2826
)
2927

28+
import sys
29+
30+
sys.path.append("../")
31+
from rdp import RDP, neighbor_product
32+
3033
# choose a device (CPU or CUDA GPU)
3134
if "numpy" in xp.__name__:
3235
# using numpy, device must be cpu
@@ -71,9 +74,9 @@
7174
svrg_gradient_recalc_periods = [x for x in range(0, num_epochs_sgd, 2)]
7275
# (initial) step sizes to try
7376
step_sizes = [
74-
0.5,
77+
0.3,
7578
1.0,
76-
2.0,
79+
3.0,
7780
] # step size 0 means that quick and dirty line search is used
7881

7982
# max number of updates for reference L-BFGS-B solution
@@ -289,14 +292,32 @@
289292
prior = RDP(
290293
img_shape,
291294
xp=xp,
292-
eps=float(xp.max(x_osem)) / 100,
293295
dev=dev,
296+
voxel_size=xp.asarray(voxel_size, device=dev),
297+
eps=float(xp.max(x_osem)) / 100,
294298
gamma=gamma_rdp,
295299
)
296300

297-
prior.weights = neighbor_product(kappa_img, xp)
301+
prior.kappa = kappa_img
298302
prior.scale = beta
299303

304+
adjoint_ones = pet_lin_op.adjoint(
305+
xp.ones(pet_lin_op.out_shape, device=dev, dtype=xp.float32)
306+
)
307+
308+
## %%
309+
# x = x_init.copy()
310+
# h = prior.diag_hessian(x)
311+
# d_data = to_device(adjoint_ones, "cpu")
312+
# d_prior = to_device(h * x, "cpu")
313+
#
314+
## %%
315+
# import pymirc.viewer as pv
316+
# vi = pv.ThreeAxisViewer([d_data, d_prior, d_data > d_prior])
317+
318+
# %%
319+
320+
300321
pet_subset_lin_op_seq, subset_slices = split_fwd_model(pet_lin_op, num_subsets_sgd)
301322

302323
cost_function = SubsetNegPoissonLogLWithPrior(
@@ -335,9 +356,6 @@
335356

336357
x_osem_scale = float(xp.mean(x_init))
337358

338-
adjoint_ones = pet_lin_op.adjoint(
339-
xp.ones(pet_lin_op.out_shape, device=dev, dtype=xp.float32)
340-
)
341359

342360
# %%
343361
cost_osem = cost_function(x_osem)

simulations/utils.py

+143-143
Original file line numberDiff line numberDiff line change
@@ -305,147 +305,147 @@ def _approx_diag_hessian(self, x: Array) -> Array:
305305
return diag_hes
306306

307307

308-
def neighbor_difference_and_sum(
309-
x: Array, xp: ModuleType, padding: str = "edge"
310-
) -> tuple[Array, Array]:
311-
"""get differences and sums with nearest neighbors for an n-dimensional array x
312-
using padding (by default in edge mode)
313-
a x.ndim*(3,) neighborhood around each element is used
314-
"""
315-
x_padded = xp.pad(x, 1, mode=padding)
316-
317-
# number of nearest neighbors
318-
num_neigh = 3**x.ndim - 1
319-
320-
# array for differences and sums with nearest neighbors
321-
d = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
322-
s = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
323-
324-
for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
325-
if i != (num_neigh // 2):
326-
sl = []
327-
for j in ind:
328-
if j - 2 < 0:
329-
sl.append(slice(j, j - 2))
330-
else:
331-
sl.append(slice(j, None))
332-
sl = tuple(sl)
333-
334-
if i < num_neigh // 2:
335-
d[i] = x - x_padded[sl]
336-
s[i] = x + x_padded[sl]
337-
else:
338-
d[i - 1] = x - x_padded[sl]
339-
s[i - 1] = x + x_padded[sl]
340-
341-
return d, s
342-
343-
344-
def neighbor_product(x: Array, xp: ModuleType, padding: str = "edge") -> Array:
345-
"""get backward and forward neighbor products for each dimension of an array x
346-
using padding (by default in edge mode)
347-
"""
348-
x_padded = xp.pad(x, 1, mode=padding)
349-
350-
# number of nearest neighbors
351-
num_neigh = 3**x.ndim - 1
352-
353-
# array for differences and sums with nearest neighbors
354-
p = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
355-
356-
for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
357-
if i != (num_neigh // 2):
358-
sl = []
359-
for j in ind:
360-
if j - 2 < 0:
361-
sl.append(slice(j, j - 2))
362-
else:
363-
sl.append(slice(j, None))
364-
sl = tuple(sl)
365-
366-
if i < num_neigh // 2:
367-
p[i] = x * x_padded[sl]
368-
else:
369-
p[i - 1] = x * x_padded[sl]
370-
371-
return p
372-
373-
374-
class RDP(SmoothFunctionWithApproxHessian):
375-
def __init__(
376-
self,
377-
in_shape: tuple[int, ...],
378-
xp: ModuleType,
379-
dev: str,
380-
eps: float | None = None,
381-
gamma: float = 2.0,
382-
padding: str = "edge",
383-
) -> None:
384-
self._gamma = gamma
385-
386-
if eps is None:
387-
self._eps = xp.finfo(xp.float32).eps
388-
else:
389-
self._eps = eps
390-
391-
self._padding = padding
392-
393-
self._weights = None
394-
395-
super().__init__(in_shape=in_shape, xp=xp, dev=dev)
396-
397-
@property
398-
def gamma(self) -> float:
399-
return self._gamma
400-
401-
@property
402-
def eps(self) -> float:
403-
return self._eps
404-
405-
@property
406-
def weights(self) -> Array | None:
407-
return self._weights
408-
409-
@weights.setter
410-
def weights(self, weights: Array) -> None:
411-
self._weights = weights
412-
413-
def _call(self, x: Array) -> float:
414-
415-
if float(self.xp.min(x)) < 0:
416-
return self.xp.inf
417-
418-
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
419-
phi = s + self.gamma * self.xp.abs(d) + self.eps
420-
421-
tmp = (d**2) / phi
422-
423-
if self._weights is not None:
424-
tmp *= self._weights
425-
426-
return float(self.xp.sum(tmp))
427-
428-
def _gradient(self, x: Array) -> Array:
429-
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
430-
phi = s + self.gamma * self.xp.abs(d) + self.eps
431-
432-
tmp = d * (2 * phi - (d + self.gamma * self.xp.abs(d))) / (phi**2)
433-
434-
if self._weights is not None:
435-
tmp *= self._weights
436-
437-
return 2 * tmp.sum(axis=0)
438-
439-
def _approx_diag_hessian(self, x: Array) -> Array:
440-
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
441-
phi = s + self.gamma * self.xp.abs(d) + self.eps
442-
443-
tmp = ((s - d + self.eps) ** 2) / (phi**3)
444-
445-
if self._weights is not None:
446-
tmp *= self._weights
447-
448-
return 4 * tmp.sum(axis=0)
308+
# def neighbor_difference_and_sum(
309+
# x: Array, xp: ModuleType, padding: str = "edge"
310+
# ) -> tuple[Array, Array]:
311+
# """get differences and sums with nearest neighbors for an n-dimensional array x
312+
# using padding (by default in edge mode)
313+
# a x.ndim*(3,) neighborhood around each element is used
314+
# """
315+
# x_padded = xp.pad(x, 1, mode=padding)
316+
#
317+
# # number of nearest neighbors
318+
# num_neigh = 3**x.ndim - 1
319+
#
320+
# # array for differences and sums with nearest neighbors
321+
# d = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
322+
# s = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
323+
#
324+
# for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
325+
# if i != (num_neigh // 2):
326+
# sl = []
327+
# for j in ind:
328+
# if j - 2 < 0:
329+
# sl.append(slice(j, j - 2))
330+
# else:
331+
# sl.append(slice(j, None))
332+
# sl = tuple(sl)
333+
#
334+
# if i < num_neigh // 2:
335+
# d[i] = x - x_padded[sl]
336+
# s[i] = x + x_padded[sl]
337+
# else:
338+
# d[i - 1] = x - x_padded[sl]
339+
# s[i - 1] = x + x_padded[sl]
340+
#
341+
# return d, s
342+
#
343+
#
344+
# def neighbor_product(x: Array, xp: ModuleType, padding: str = "edge") -> Array:
345+
# """get backward and forward neighbor products for each dimension of an array x
346+
# using padding (by default in edge mode)
347+
# """
348+
# x_padded = xp.pad(x, 1, mode=padding)
349+
#
350+
# # number of nearest neighbors
351+
# num_neigh = 3**x.ndim - 1
352+
#
353+
# # array for differences and sums with nearest neighbors
354+
# p = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
355+
#
356+
# for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
357+
# if i != (num_neigh // 2):
358+
# sl = []
359+
# for j in ind:
360+
# if j - 2 < 0:
361+
# sl.append(slice(j, j - 2))
362+
# else:
363+
# sl.append(slice(j, None))
364+
# sl = tuple(sl)
365+
#
366+
# if i < num_neigh // 2:
367+
# p[i] = x * x_padded[sl]
368+
# else:
369+
# p[i - 1] = x * x_padded[sl]
370+
#
371+
# return p
372+
#
373+
#
374+
# class RDP(SmoothFunctionWithApproxHessian):
375+
# def __init__(
376+
# self,
377+
# in_shape: tuple[int, ...],
378+
# xp: ModuleType,
379+
# dev: str,
380+
# eps: float | None = None,
381+
# gamma: float = 2.0,
382+
# padding: str = "edge",
383+
# ) -> None:
384+
# self._gamma = gamma
385+
#
386+
# if eps is None:
387+
# self._eps = xp.finfo(xp.float32).eps
388+
# else:
389+
# self._eps = eps
390+
#
391+
# self._padding = padding
392+
#
393+
# self._weights = None
394+
#
395+
# super().__init__(in_shape=in_shape, xp=xp, dev=dev)
396+
#
397+
# @property
398+
# def gamma(self) -> float:
399+
# return self._gamma
400+
#
401+
# @property
402+
# def eps(self) -> float:
403+
# return self._eps
404+
#
405+
# @property
406+
# def weights(self) -> Array | None:
407+
# return self._weights
408+
#
409+
# @weights.setter
410+
# def weights(self, weights: Array) -> None:
411+
# self._weights = weights
412+
#
413+
# def _call(self, x: Array) -> float:
414+
#
415+
# if float(self.xp.min(x)) < 0:
416+
# return self.xp.inf
417+
#
418+
# d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
419+
# phi = s + self.gamma * self.xp.abs(d) + self.eps
420+
#
421+
# tmp = (d**2) / phi
422+
#
423+
# if self._weights is not None:
424+
# tmp *= self._weights
425+
#
426+
# return float(self.xp.sum(tmp))
427+
#
428+
# def _gradient(self, x: Array) -> Array:
429+
# d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
430+
# phi = s + self.gamma * self.xp.abs(d) + self.eps
431+
#
432+
# tmp = d * (2 * phi - (d + self.gamma * self.xp.abs(d))) / (phi**2)
433+
#
434+
# if self._weights is not None:
435+
# tmp *= self._weights
436+
#
437+
# return 2 * tmp.sum(axis=0)
438+
#
439+
# def _approx_diag_hessian(self, x: Array) -> Array:
440+
# d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
441+
# phi = s + self.gamma * self.xp.abs(d) + self.eps
442+
#
443+
# tmp = ((s - d + self.eps) ** 2) / (phi**3)
444+
#
445+
# if self._weights is not None:
446+
# tmp *= self._weights
447+
#
448+
# return 4 * tmp.sum(axis=0)
449449

450450

451451
class L2DataFidelity(SmoothFunction):
@@ -1129,14 +1129,14 @@ def split_fwd_model(
11291129
def rdp_preconditioner(
11301130
x: Array,
11311131
adjoint_ones: Array,
1132-
prior: SmoothFunctionWithApproxHessian,
1132+
prior,
11331133
version: int = 1,
11341134
delta: float = 1e-6,
11351135
) -> Array:
11361136
if version == 1:
11371137
precond = (x + delta) / adjoint_ones
11381138
elif version == 2:
1139-
precond = (x + delta) / (adjoint_ones + prior.approx_diag_hessian(x) * x)
1139+
precond = (x + delta) / (adjoint_ones + prior.diag_hessian(x) * x)
11401140
else:
11411141
raise ValueError("precond_version must be 1 or 2")
11421142

0 commit comments

Comments
 (0)