diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..2e51c7e --- /dev/null +++ b/environment.yml @@ -0,0 +1,5 @@ +name: magez +channels: [conda-forge] +dependencies: +- array-api-compat +- cupy \ No newline at end of file diff --git a/main_SVRG.py b/main_SVRG.py index bb9a62a..ba98459 100644 --- a/main_SVRG.py +++ b/main_SVRG.py @@ -12,18 +12,18 @@ import sirf.STIR as STIR from cil.optimisation.algorithms import Algorithm from cil.optimisation.utilities import callbacks -from petric import Dataset from sirf.contrib.partitioner.partitioner import partition_indices from sirf.contrib.partitioner import partitioner import numpy as np -import array_api_compat.numpy as xp - -from scipy.ndimage import gaussian_filter +import array_api_compat.cupy as xp +from array_api_compat import to_device # import pure python re-implementation of the RDP -> only used to get diagonal of the RDP Hessian! from rdp import RDP +from petric import Dataset + def get_divisors(n): """Returns a sorted list of all divisors of a positive integer n.""" @@ -158,15 +158,20 @@ def __init__( # setup python re-implementation of the RDP # only used to get the diagonal of the RDP Hessian for preconditioning! # (diag of RDP Hessian is not available in SIRF yet) + if "cupy" in xp.__name__: + self._dev = xp.cuda.Device(0) + else: + self._dev = "cpu" + self._python_prior = RDP( data.OSEM_image.shape, xp, - "cpu", - xp.asarray(data.OSEM_image.spacing, device="cpu"), + self._dev, + xp.asarray(data.OSEM_image.spacing, device=self._dev), eps=data.prior.get_epsilon(), gamma=data.prior.get_gamma(), ) - self._python_prior.kappa = data.kappa.as_array() + self._python_prior.kappa = xp.asarray(data.kappa.as_array(), device=self._dev) self._python_prior.scale = data.prior.get_penalisation_factor() self._precond_filter = STIR.SeparableGaussianImageFilter() @@ -198,7 +203,14 @@ def calc_precond( delta = delta_rel * x_sm.max() prior_diag_hess = x_sm.get_uniform_copy(0) - prior_diag_hess.fill(self._python_prior.diag_hessian(x_sm.as_array())) + prior_diag_hess.fill( + to_device( + self._python_prior.diag_hessian( + xp.asarray(x_sm.as_array(), device=self._dev) + ), + "cpu", + ) + ) precond = ( self._fov_mask diff --git a/rdp.py b/rdp.py index 8ad1fb6..7232f84 100644 --- a/rdp.py +++ b/rdp.py @@ -1,6 +1,6 @@ import abc import array_api_compat.numpy as np -from array_api_compat import get_namespace, device +from array_api_compat import device from types import ModuleType from typing import TypeAlias diff --git a/test_petric.py b/test_petric.py index 68f1129..7595ad0 100644 --- a/test_petric.py +++ b/test_petric.py @@ -418,4 +418,4 @@ def test_petric(ds: int, num_iter: int, **kwargs): ) else: for i in range(4): - test_petric(ds=i, num_iter=300, verbose=True) + test_petric(ds=i, num_iter=300)