Skip to content

Commit

Permalink
add symmetric linear diff and sum operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrhardt committed Aug 9, 2024
1 parent 60cde82 commit 2aec66e
Show file tree
Hide file tree
Showing 2 changed files with 351 additions and 19 deletions.
264 changes: 264 additions & 0 deletions simulations/test_Ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""script to test RDP prior"""

# %%
from __future__ import annotations

try:
import array_api_compat.cupy as xp
except ImportError:
import array_api_compat.numpy as xp

import parallelproj
import array_api_compat.numpy as np
import matplotlib.pyplot as plt

from array_api_compat import to_device
from scipy.optimize import fmin_l_bfgs_b
from pathlib import Path

from utils import (
LinOp,
SubsetNegPoissonLogLWithPrior,
FwdDiff,
FwdDiffSymm,
FwdSum,
FwdSumSymm,
FwdMult,
RDP,
split_fwd_model,
OSEM,
rdp_preconditioner,
)

# choose a device (CPU or CUDA GPU)
if "numpy" in xp.__name__:
# using numpy, device must be cpu
dev = "cpu"
elif "cupy" in xp.__name__:
# using cupy, only cuda devices are possible
dev = xp.cuda.Device(0)

seed = 1

# true counts, reasonable range: 1e6, 1e7 (high counts), 1e5 (low counts)
true_counts = 1e6
# regularization weight, reasonable range: 0.14, 0.014, 1.4 for RDP
# regularization weight, reasonable range: for quad 30 * 0.14 for 1e6, 3*0.14 for 1e7, 300*0.14 for 1e5

beta = 3 * 0.14

# RDP gamma parameter
gamma_rdp = 2.0

# max number of updates for reference L-BFGS-B solution
num_iter_bfgs_ref = 400

# SDPHG parameters
rho = 3.0 # up to rho = 3 seems also to work for some gammas
# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image
gammas = np.array([0.3, 1.0, 3.0, 10.0])
# number of iterations to numerically evaluate the proximal operator of the prior (function G)
num_iter_prox_g = 20

# number of rings of simulated PET scanner, should be odd in this example
num_rings = 2
# resolution of the simulated PET scanner in mm
fwhm_data_mm = 4.5
# simulated TOF or non-TOF system
tof = False
# mean of contamination sinogram, relative to mean of trues sinogram, reasonable range: 0.5 - 1.0
contam_fraction = 0.5

# number of epochs / subsets for intial OSEM
num_epochs_osem = 1
num_subsets_osem = 54

# random seed
np.random.seed(seed)

# Setup of the forward model :math:`\bar{y}(x) = A x + s`
# --------------------------------------------------------
#
# We setup a linear forward operator :math:`A` consisting of an
# image-based resolution model, a non-TOF PET projector and an attenuation model
#
# .. note::
# The OSEM implementation below works with all linear operators that
# subclass :class:`.LinearOperator` (e.g. the high-level projectors).

scanner = parallelproj.RegularPolygonPETScannerGeometry(
xp,
dev,
radius=300.0,
num_sides=36,
num_lor_endpoints_per_side=12,
lor_spacing=4.0,
ring_positions=xp.linspace(
-5 * (num_rings - 1) / 2, 5 * (num_rings - 1) / 2, num_rings
),
symmetry_axis=2,
)

# setup the LOR descriptor that defines the sinogram

img_shape = (100, 100, 2 * num_rings - 1)
voxel_size = (2.0, 2.0, 2.0)

lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
scanner,
radial_trim=140,
sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

proj = parallelproj.RegularPolygonPETProjector(
lor_desc, img_shape=img_shape, voxel_size=voxel_size
)

# setup a simple test image containing a few "hot rods"
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32)
c0 = proj.in_shape[0] // 2
c1 = proj.in_shape[1] // 2
x_true[(c0 - 4) : (c0 + 4), (c1 - 4) : (c1 + 4), :] = 3.0

x_true[28:32, c1 : (c1 + 4), :] = 5.0
x_true[c0 : (c0 + 4), 20:24, :] = 5.0

x_true[-32:-28, c1 : (c1 + 4), :] = 0.1
x_true[c0 : (c0 + 4), -24:-20, :] = 0.1

x_true[:25, :, :] = 0
x_true[-25:, :, :] = 0
x_true[:, :10, :] = 0
x_true[:, -10:, :] = 0

# Attenuation image and sinogram setup
# ------------------------------------

# setup an attenuation image
x_att = 0.01 * xp.astype(x_true > 0, xp.float32)
# calculate the attenuation sinogram
att_sino = xp.exp(-proj(x_att))

# Complete PET forward model setup
# --------------------------------
#
# We combine an image-based resolution model,
# a non-TOF or TOF PET projector and an attenuation model
# into a single linear operator.

# enable TOF - comment if you want to run non-TOF
if tof is True:
proj.tof_parameters = parallelproj.TOFParameters(
num_tofbins=13, tofbin_width=12.0, sigma_tof=12.0
)

# setup the attenuation multiplication operator which is different
# for TOF and non-TOF since the attenuation sinogram is always non-TOF
if proj.tof:
att_op = parallelproj.TOFNonTOFElementwiseMultiplicationOperator(
proj.out_shape, att_sino
)
else:
att_op = parallelproj.ElementwiseMultiplicationOperator(att_sino)

res_model = parallelproj.GaussianFilterOperator(
proj.in_shape, sigma=fwhm_data_mm / (2.35 * proj.voxel_size)
)

# compose all 3 operators into a single linear operator
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model))

# Simulation of projection data
# -----------------------------
#
# We setup an arbitrary ground truth :math:`x_{true}` and simulate
# noise-free and noisy data :math:`y` by adding Poisson noise.

# simulated noise-free data
noise_free_data = pet_lin_op(x_true)

if true_counts > 0:
scale_fac = true_counts / float(xp.sum(noise_free_data))
noise_free_data *= scale_fac
x_true *= scale_fac

# generate a contant contamination sinogram
contamination = xp.full(
noise_free_data.shape,
contam_fraction * float(xp.mean(noise_free_data)),
device=dev,
dtype=xp.float32,
)

noise_free_data += contamination

# add Poisson noise
data = xp.asarray(
np.random.poisson(parallelproj.to_numpy_array(noise_free_data)),
device=dev,
dtype=xp.float32,
)

# run quick OSEM with one iteration

pet_subset_lin_op_seq_osem, subset_slices_osem = split_fwd_model(
pet_lin_op, num_subsets_osem
)

data_fidelity = SubsetNegPoissonLogLWithPrior(
data, pet_subset_lin_op_seq_osem, contamination, subset_slices_osem
)

x0 = xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32)
osem_alg = OSEM(data_fidelity)
x_osem = osem_alg.run(x0, num_epochs_osem)

# calculate the kappa image for the prior weights
# fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32))
# fwd_osem = pet_lin_op(x_osem) + contamination
#
# kappa = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2)))

# setup of the cost function

fwd_diff = FwdDiff(img_shape, xp)
fwd_sum = FwdSum(img_shape, xp)
fwd_mult = FwdMult(img_shape, xp)

fwd_diff_symm = FwdDiffSymm(img_shape, xp)

fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32))
fwd_osem = pet_lin_op(x_osem) + contamination
kappa = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2)))

prior = RDP(
fwd_diff,
fwd_sum,
eps=float(xp.max(x_osem)) / 10,
xp=xp,
dev=dev,
gamma=gamma_rdp,
)

prior.weights = xp.sqrt(fwd_mult(kappa))
prior.scale = beta


# %%
def adjoint_test(A, B):
x = xp.random.rand(*A.in_shape)
Ax = A(x)
y = xp.random.rand(*Ax.shape)
By = B(y)
AxTy = xp.sum(Ax * y)
ByTx = xp.sum(x * By)
print(AxTy, ByTx, AxTy - ByTx, AxTy / ByTx)

x = np.reshape(np.arange(24), (2,3,4))
fwd_diff_symm = FwdDiffSymm(x.shape, xp)
fwd_sum_symm = FwdSumSymm(x.shape, xp)

print('test diff symm')
adjoint_test(fwd_diff_symm, fwd_diff_symm.adjoint)
print('test sum symm')
adjoint_test(fwd_sum_symm, fwd_sum_symm.adjoint)
Loading

0 comments on commit 2aec66e

Please sign in to comment.