Skip to content

Commit

Permalink
clean up + add env.
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 30, 2024
1 parent 7f7c288 commit 3caaaea
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
5 changes: 5 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: magez
channels: [conda-forge]
dependencies:
- array-api-compat
- cupy
28 changes: 20 additions & 8 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from sirf.contrib.partitioner.partitioner import partition_indices
from sirf.contrib.partitioner import partitioner

import numpy as np
import array_api_compat.numpy as xp

from scipy.ndimage import gaussian_filter
import array_api_compat.cupy as xp
from array_api_compat import to_device

# import pure python re-implementation of the RDP -> only used to get diagonal of the RDP Hessian!
from rdp import RDP

from petric import Dataset


def get_divisors(n):
"""Returns a sorted list of all divisors of a positive integer n."""
Expand Down Expand Up @@ -158,15 +158,20 @@ def __init__(
# setup python re-implementation of the RDP
# only used to get the diagonal of the RDP Hessian for preconditioning!
# (diag of RDP Hessian is not available in SIRF yet)
if "cupy" in xp.__name__:
self._dev = xp.cuda.Device(0)
else:
self._dev = "cpu"

self._python_prior = RDP(
data.OSEM_image.shape,
xp,
"cpu",
xp.asarray(data.OSEM_image.spacing, device="cpu"),
self._dev,
xp.asarray(data.OSEM_image.spacing, device=self._dev),
eps=data.prior.get_epsilon(),
gamma=data.prior.get_gamma(),
)
self._python_prior.kappa = data.kappa.as_array()
self._python_prior.kappa = xp.asarray(data.kappa.as_array(), device=self._dev)
self._python_prior.scale = data.prior.get_penalisation_factor()

self._precond_filter = STIR.SeparableGaussianImageFilter()
Expand Down Expand Up @@ -198,7 +203,14 @@ def calc_precond(
delta = delta_rel * x_sm.max()

prior_diag_hess = x_sm.get_uniform_copy(0)
prior_diag_hess.fill(self._python_prior.diag_hessian(x_sm.as_array()))
prior_diag_hess.fill(
to_device(
self._python_prior.diag_hessian(
xp.asarray(x_sm.as_array(), device=self._dev)
),
"cpu",
)
)

precond = (
self._fov_mask
Expand Down
2 changes: 1 addition & 1 deletion rdp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import array_api_compat.numpy as np
from array_api_compat import get_namespace, device
from array_api_compat import device

from types import ModuleType
from typing import TypeAlias
Expand Down
2 changes: 1 addition & 1 deletion test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,4 +418,4 @@ def test_petric(ds: int, num_iter: int, **kwargs):
)
else:
for i in range(4):
test_petric(ds=i, num_iter=300, verbose=True)
test_petric(ds=i, num_iter=300)

0 comments on commit 3caaaea

Please sign in to comment.