-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add symmetric linear diff and sum operators
- Loading branch information
Showing
2 changed files
with
351 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,264 @@ | ||
"""script to test RDP prior""" | ||
|
||
# %% | ||
from __future__ import annotations | ||
|
||
try: | ||
import array_api_compat.cupy as xp | ||
except ImportError: | ||
import array_api_compat.numpy as xp | ||
|
||
import parallelproj | ||
import array_api_compat.numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from array_api_compat import to_device | ||
from scipy.optimize import fmin_l_bfgs_b | ||
from pathlib import Path | ||
|
||
from utils import ( | ||
LinOp, | ||
SubsetNegPoissonLogLWithPrior, | ||
FwdDiff, | ||
FwdDiffSymm, | ||
FwdSum, | ||
FwdSumSymm, | ||
FwdMult, | ||
RDP, | ||
split_fwd_model, | ||
OSEM, | ||
rdp_preconditioner, | ||
) | ||
|
||
# choose a device (CPU or CUDA GPU) | ||
if "numpy" in xp.__name__: | ||
# using numpy, device must be cpu | ||
dev = "cpu" | ||
elif "cupy" in xp.__name__: | ||
# using cupy, only cuda devices are possible | ||
dev = xp.cuda.Device(0) | ||
|
||
seed = 1 | ||
|
||
# true counts, reasonable range: 1e6, 1e7 (high counts), 1e5 (low counts) | ||
true_counts = 1e6 | ||
# regularization weight, reasonable range: 0.14, 0.014, 1.4 for RDP | ||
# regularization weight, reasonable range: for quad 30 * 0.14 for 1e6, 3*0.14 for 1e7, 300*0.14 for 1e5 | ||
|
||
beta = 3 * 0.14 | ||
|
||
# RDP gamma parameter | ||
gamma_rdp = 2.0 | ||
|
||
# max number of updates for reference L-BFGS-B solution | ||
num_iter_bfgs_ref = 400 | ||
|
||
# SDPHG parameters | ||
rho = 3.0 # up to rho = 3 seems also to work for some gammas | ||
# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image | ||
gammas = np.array([0.3, 1.0, 3.0, 10.0]) | ||
# number of iterations to numerically evaluate the proximal operator of the prior (function G) | ||
num_iter_prox_g = 20 | ||
|
||
# number of rings of simulated PET scanner, should be odd in this example | ||
num_rings = 2 | ||
# resolution of the simulated PET scanner in mm | ||
fwhm_data_mm = 4.5 | ||
# simulated TOF or non-TOF system | ||
tof = False | ||
# mean of contamination sinogram, relative to mean of trues sinogram, reasonable range: 0.5 - 1.0 | ||
contam_fraction = 0.5 | ||
|
||
# number of epochs / subsets for intial OSEM | ||
num_epochs_osem = 1 | ||
num_subsets_osem = 54 | ||
|
||
# random seed | ||
np.random.seed(seed) | ||
|
||
# Setup of the forward model :math:`\bar{y}(x) = A x + s` | ||
# -------------------------------------------------------- | ||
# | ||
# We setup a linear forward operator :math:`A` consisting of an | ||
# image-based resolution model, a non-TOF PET projector and an attenuation model | ||
# | ||
# .. note:: | ||
# The OSEM implementation below works with all linear operators that | ||
# subclass :class:`.LinearOperator` (e.g. the high-level projectors). | ||
|
||
scanner = parallelproj.RegularPolygonPETScannerGeometry( | ||
xp, | ||
dev, | ||
radius=300.0, | ||
num_sides=36, | ||
num_lor_endpoints_per_side=12, | ||
lor_spacing=4.0, | ||
ring_positions=xp.linspace( | ||
-5 * (num_rings - 1) / 2, 5 * (num_rings - 1) / 2, num_rings | ||
), | ||
symmetry_axis=2, | ||
) | ||
|
||
# setup the LOR descriptor that defines the sinogram | ||
|
||
img_shape = (100, 100, 2 * num_rings - 1) | ||
voxel_size = (2.0, 2.0, 2.0) | ||
|
||
lor_desc = parallelproj.RegularPolygonPETLORDescriptor( | ||
scanner, | ||
radial_trim=140, | ||
sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP, | ||
) | ||
|
||
proj = parallelproj.RegularPolygonPETProjector( | ||
lor_desc, img_shape=img_shape, voxel_size=voxel_size | ||
) | ||
|
||
# setup a simple test image containing a few "hot rods" | ||
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32) | ||
c0 = proj.in_shape[0] // 2 | ||
c1 = proj.in_shape[1] // 2 | ||
x_true[(c0 - 4) : (c0 + 4), (c1 - 4) : (c1 + 4), :] = 3.0 | ||
|
||
x_true[28:32, c1 : (c1 + 4), :] = 5.0 | ||
x_true[c0 : (c0 + 4), 20:24, :] = 5.0 | ||
|
||
x_true[-32:-28, c1 : (c1 + 4), :] = 0.1 | ||
x_true[c0 : (c0 + 4), -24:-20, :] = 0.1 | ||
|
||
x_true[:25, :, :] = 0 | ||
x_true[-25:, :, :] = 0 | ||
x_true[:, :10, :] = 0 | ||
x_true[:, -10:, :] = 0 | ||
|
||
# Attenuation image and sinogram setup | ||
# ------------------------------------ | ||
|
||
# setup an attenuation image | ||
x_att = 0.01 * xp.astype(x_true > 0, xp.float32) | ||
# calculate the attenuation sinogram | ||
att_sino = xp.exp(-proj(x_att)) | ||
|
||
# Complete PET forward model setup | ||
# -------------------------------- | ||
# | ||
# We combine an image-based resolution model, | ||
# a non-TOF or TOF PET projector and an attenuation model | ||
# into a single linear operator. | ||
|
||
# enable TOF - comment if you want to run non-TOF | ||
if tof is True: | ||
proj.tof_parameters = parallelproj.TOFParameters( | ||
num_tofbins=13, tofbin_width=12.0, sigma_tof=12.0 | ||
) | ||
|
||
# setup the attenuation multiplication operator which is different | ||
# for TOF and non-TOF since the attenuation sinogram is always non-TOF | ||
if proj.tof: | ||
att_op = parallelproj.TOFNonTOFElementwiseMultiplicationOperator( | ||
proj.out_shape, att_sino | ||
) | ||
else: | ||
att_op = parallelproj.ElementwiseMultiplicationOperator(att_sino) | ||
|
||
res_model = parallelproj.GaussianFilterOperator( | ||
proj.in_shape, sigma=fwhm_data_mm / (2.35 * proj.voxel_size) | ||
) | ||
|
||
# compose all 3 operators into a single linear operator | ||
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model)) | ||
|
||
# Simulation of projection data | ||
# ----------------------------- | ||
# | ||
# We setup an arbitrary ground truth :math:`x_{true}` and simulate | ||
# noise-free and noisy data :math:`y` by adding Poisson noise. | ||
|
||
# simulated noise-free data | ||
noise_free_data = pet_lin_op(x_true) | ||
|
||
if true_counts > 0: | ||
scale_fac = true_counts / float(xp.sum(noise_free_data)) | ||
noise_free_data *= scale_fac | ||
x_true *= scale_fac | ||
|
||
# generate a contant contamination sinogram | ||
contamination = xp.full( | ||
noise_free_data.shape, | ||
contam_fraction * float(xp.mean(noise_free_data)), | ||
device=dev, | ||
dtype=xp.float32, | ||
) | ||
|
||
noise_free_data += contamination | ||
|
||
# add Poisson noise | ||
data = xp.asarray( | ||
np.random.poisson(parallelproj.to_numpy_array(noise_free_data)), | ||
device=dev, | ||
dtype=xp.float32, | ||
) | ||
|
||
# run quick OSEM with one iteration | ||
|
||
pet_subset_lin_op_seq_osem, subset_slices_osem = split_fwd_model( | ||
pet_lin_op, num_subsets_osem | ||
) | ||
|
||
data_fidelity = SubsetNegPoissonLogLWithPrior( | ||
data, pet_subset_lin_op_seq_osem, contamination, subset_slices_osem | ||
) | ||
|
||
x0 = xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32) | ||
osem_alg = OSEM(data_fidelity) | ||
x_osem = osem_alg.run(x0, num_epochs_osem) | ||
|
||
# calculate the kappa image for the prior weights | ||
# fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32)) | ||
# fwd_osem = pet_lin_op(x_osem) + contamination | ||
# | ||
# kappa = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2))) | ||
|
||
# setup of the cost function | ||
|
||
fwd_diff = FwdDiff(img_shape, xp) | ||
fwd_sum = FwdSum(img_shape, xp) | ||
fwd_mult = FwdMult(img_shape, xp) | ||
|
||
fwd_diff_symm = FwdDiffSymm(img_shape, xp) | ||
|
||
fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32)) | ||
fwd_osem = pet_lin_op(x_osem) + contamination | ||
kappa = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2))) | ||
|
||
prior = RDP( | ||
fwd_diff, | ||
fwd_sum, | ||
eps=float(xp.max(x_osem)) / 10, | ||
xp=xp, | ||
dev=dev, | ||
gamma=gamma_rdp, | ||
) | ||
|
||
prior.weights = xp.sqrt(fwd_mult(kappa)) | ||
prior.scale = beta | ||
|
||
|
||
# %% | ||
def adjoint_test(A, B): | ||
x = xp.random.rand(*A.in_shape) | ||
Ax = A(x) | ||
y = xp.random.rand(*Ax.shape) | ||
By = B(y) | ||
AxTy = xp.sum(Ax * y) | ||
ByTx = xp.sum(x * By) | ||
print(AxTy, ByTx, AxTy - ByTx, AxTy / ByTx) | ||
|
||
x = np.reshape(np.arange(24), (2,3,4)) | ||
fwd_diff_symm = FwdDiffSymm(x.shape, xp) | ||
fwd_sum_symm = FwdSumSymm(x.shape, xp) | ||
|
||
print('test diff symm') | ||
adjoint_test(fwd_diff_symm, fwd_diff_symm.adjoint) | ||
print('test sum symm') | ||
adjoint_test(fwd_sum_symm, fwd_sum_symm.adjoint) |
Oops, something went wrong.