Skip to content

Commit

Permalink
add first version of RDP prox
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrhardt committed Aug 9, 2024
1 parent a8f27ac commit 7ac0e2b
Showing 1 changed file with 77 additions and 7 deletions.
84 changes: 77 additions & 7 deletions simulations/test_RDP.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
except ImportError:
import array_api_compat.numpy as xp

import abc
import parallelproj
import array_api_compat.numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -54,13 +55,6 @@
# max number of updates for reference L-BFGS-B solution
num_iter_bfgs_ref = 400

# SDPHG parameters
rho = 3.0 # up to rho = 3 seems also to work for some gammas
# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image
gammas = np.array([0.3, 1.0, 3.0, 10.0])
# number of iterations to numerically evaluate the proximal operator of the prior (function G)
num_iter_prox_g = 20

# number of rings of simulated PET scanner, should be odd in this example
num_rings = 2
# resolution of the simulated PET scanner in mm
Expand Down Expand Up @@ -343,3 +337,79 @@ def shift(x, axis):
print('{:.2e}'.format(eps),
'{:.2e}'.format(abs_diff))
# %%



class ProxRDP(abc.ABC):

def __init__(
self,
rdp,
rdp_lipschitz_estimate: float = 1.0,
niter: int = 20,
tol: float = 1e-3,
min_iter: int = 5,
) -> None:

self._rdp = rdp
self._niter = niter
self._solution = None
self._rdp_lipschitz_estimate = rdp_lipschitz_estimate
self._tol = tol
self._min_iter = min_iter

def __call__(self, sigma: float, x: Array) -> float:

if self._solution is None:
z = x.copy()
else:
z = self._solution

for k in range(self._niter):
fun_val = self._rdp(z)
if k == 0:
g = self._rdp.prox_gradient(z, x, sigma)
else:
g = gnew

L = 1 + sigma * self._rdp_lipschitz_estimate
step_size = 1/L

z_new = z - step_size * g
z_new = xp.maximum(z_new, 0)

gnew = self._rdp.prox_gradient(z, x, sigma)
fun_val_new = self._rdp(z_new)

if fun_val_new < fun_val:
if k > self._min_iter:
#criterion = xp.linalg.norm(z_new - z)/xp.linalg.norm(z)
criterion = xp.linalg.norm(gnew)/xp.linalg.norm(g)

print(k, 'not converged', criterion, self._tol)
if criterion < self._tol:
self._solution = z_new
print(k, 'converged', criterion, self._tol)
return self._solution
z = z_new
self._rdp_lipschitz_estimate *= 0.8
print(k, 'good step', self._rdp_lipschitz_estimate)
else:
self._rdp_lipschitz_estimate *= 2
print(k, 'bad step', self._rdp_lipschitz_estimate)

self._solution = z
return self._solution

x = x_osem + 0.2*xp.random.randn(*x_osem.shape)

plt.figure()
plt.imshow(x[:,:,0]); plt.colorbar()

prox = ProxRDP(prior, rdp_lipschitz_estimate=100, niter=200)
x_clean = prox(1e-1, x)

plt.figure()
plt.imshow(x_clean[:,:,0]); plt.colorbar()

# %%

0 comments on commit 7ac0e2b

Please sign in to comment.