From 7ac0e2bc4ecf60157dd35531bfe8d6018162fcac Mon Sep 17 00:00:00 2001 From: Matthias Ehrhardt Date: Fri, 9 Aug 2024 18:19:33 +0200 Subject: [PATCH] add first version of RDP prox --- simulations/test_RDP.py | 84 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 77 insertions(+), 7 deletions(-) diff --git a/simulations/test_RDP.py b/simulations/test_RDP.py index f7c58ed..6e24621 100644 --- a/simulations/test_RDP.py +++ b/simulations/test_RDP.py @@ -8,6 +8,7 @@ except ImportError: import array_api_compat.numpy as xp +import abc import parallelproj import array_api_compat.numpy as np import matplotlib.pyplot as plt @@ -54,13 +55,6 @@ # max number of updates for reference L-BFGS-B solution num_iter_bfgs_ref = 400 -# SDPHG parameters -rho = 3.0 # up to rho = 3 seems also to work for some gammas -# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image -gammas = np.array([0.3, 1.0, 3.0, 10.0]) -# number of iterations to numerically evaluate the proximal operator of the prior (function G) -num_iter_prox_g = 20 - # number of rings of simulated PET scanner, should be odd in this example num_rings = 2 # resolution of the simulated PET scanner in mm @@ -343,3 +337,79 @@ def shift(x, axis): print('{:.2e}'.format(eps), '{:.2e}'.format(abs_diff)) # %% + + + +class ProxRDP(abc.ABC): + + def __init__( + self, + rdp, + rdp_lipschitz_estimate: float = 1.0, + niter: int = 20, + tol: float = 1e-3, + min_iter: int = 5, + ) -> None: + + self._rdp = rdp + self._niter = niter + self._solution = None + self._rdp_lipschitz_estimate = rdp_lipschitz_estimate + self._tol = tol + self._min_iter = min_iter + + def __call__(self, sigma: float, x: Array) -> float: + + if self._solution is None: + z = x.copy() + else: + z = self._solution + + for k in range(self._niter): + fun_val = self._rdp(z) + if k == 0: + g = self._rdp.prox_gradient(z, x, sigma) + else: + g = gnew + + L = 1 + sigma * self._rdp_lipschitz_estimate + step_size = 1/L + + z_new = z - step_size * g + z_new = xp.maximum(z_new, 0) + + gnew = self._rdp.prox_gradient(z, x, sigma) + fun_val_new = self._rdp(z_new) + + if fun_val_new < fun_val: + if k > self._min_iter: + #criterion = xp.linalg.norm(z_new - z)/xp.linalg.norm(z) + criterion = xp.linalg.norm(gnew)/xp.linalg.norm(g) + + print(k, 'not converged', criterion, self._tol) + if criterion < self._tol: + self._solution = z_new + print(k, 'converged', criterion, self._tol) + return self._solution + z = z_new + self._rdp_lipschitz_estimate *= 0.8 + print(k, 'good step', self._rdp_lipschitz_estimate) + else: + self._rdp_lipschitz_estimate *= 2 + print(k, 'bad step', self._rdp_lipschitz_estimate) + + self._solution = z + return self._solution + +x = x_osem + 0.2*xp.random.randn(*x_osem.shape) + +plt.figure() +plt.imshow(x[:,:,0]); plt.colorbar() + +prox = ProxRDP(prior, rdp_lipschitz_estimate=100, niter=200) +x_clean = prox(1e-1, x) + +plt.figure() +plt.imshow(x_clean[:,:,0]); plt.colorbar() + +# %%