Skip to content

Commit

Permalink
refactor + cleanup code
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Sep 27, 2024
1 parent d6d6cd0 commit e6360c8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
51 changes: 29 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from cil.optimisation.utilities import callbacks
from sirf.contrib.partitioner import partitioner

from collections.abc import Callable

import numpy as np
import array_api_compat.cupy as xp
from array_api_compat import to_device
Expand All @@ -34,6 +36,21 @@ def get_divisors(n):
return sorted(divisors)


def step_size_rule_1(update: int) -> float:
if update <= 10:
new_step_size = 3.0
elif update > 10 and update <= (4 * 25):
new_step_size = 2.0
elif update > (4 * 25) and update <= (8 * 25):
new_step_size = 1.5
elif update > (8 * 25) and update <= (12 * 25):
new_step_size = 1.0
else:
new_step_size = 0.5

return new_step_size


class MaxIteration(callbacks.Callback):
"""
The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout).
Expand Down Expand Up @@ -62,12 +79,12 @@ class Submission(Algorithm):
def __init__(
self,
data: Dataset,
step_size_factor: float = 1.0, # multiplicative factor to increase / decrease default step sizes
approx_num_subsets: int = 25, # approximate number of subsets, closest divisor of num_views will be used
update_objective_interval: int | None = None,
complete_gradient_epochs: None | list[int] = None,
complete_gradient_epochs: list[int] = [x for x in range(0, 1000, 2)],
step_size_update_function: Callable[[int], float] = step_size_rule_1,
precond_update_epochs: None | list[int] = None,
precond_hessian_factor: float = 0.75,
precond_hessian_factor: float = 1.5,
precond_filter_fwhm_mm: float = 5.0,
verbose: bool = False,
seed: int = 1,
Expand Down Expand Up @@ -104,8 +121,8 @@ def __init__(
self.x = data.OSEM_image.clone()

self._update = 0
self._step_size_factor = step_size_factor
self._step_size = self._step_size_factor * 2.0
self._step_size_update_function = step_size_update_function
self._step_size = self._step_size_update_function(self._update)
self._subset_number_list = []
self._precond_hessian_factor = precond_hessian_factor

Expand Down Expand Up @@ -202,19 +219,6 @@ def __init__(
def epoch(self):
return self._update // self._num_subsets

def update_step_size(self):
if self.epoch <= 4:
self._step_size = self._step_size_factor * 2.0
elif self.epoch > 4 and self.epoch <= 8:
self._step_size = self._step_size_factor * 1.5
elif self.epoch > 8 and self.epoch <= 12:
self._step_size = self._step_size_factor * 1.0
else:
self._step_size = self._step_size_factor * 0.5

if self._verbose:
print(self._update, self.epoch, self._step_size)

def calc_precond(
self,
x: STIR.ImageData,
Expand Down Expand Up @@ -242,7 +246,7 @@ def calc_precond(
* x_sm
/ (
self._adjoint_ones
+ (self._precond_hessian_factor * 2) * prior_diag_hess * x_sm
+ self._precond_hessian_factor * prior_diag_hess * x_sm
)
)

Expand Down Expand Up @@ -274,8 +278,11 @@ def update(self):
self._update % self._num_subsets == 0
) and self.epoch in self._precond_update_epochs

if self._update % self._num_subsets == 0:
self.update_step_size()
# update the step size based on the current update number and the current step size
self._step_size = self._step_size_update_function(self._update)

if self._verbose:
print(self._update, self._step_size)

if update_precond:
if self._verbose:
Expand Down Expand Up @@ -335,4 +342,4 @@ def create_subset_number_list(self):
self._subset_number_list = tmp.tolist()


submission_callbacks = [MaxIteration(300)]
submission_callbacks = [MaxIteration(660)]
5 changes: 4 additions & 1 deletion test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def test_petric(ds: int, num_iter: int, suffix: str = "", **kwargs):
elif ds == 5:
srcdir = SRCDIR / "Siemens_mMR_NEMA_IQ_lowcounts"
outdir = OUTDIR / "mMR_NEMA_lowcounts" / sdir_name
elif ds == 6:
srcdir = SRCDIR / "GE_DMI3_Torso"
outdir = OUTDIR / "GE_DMI3_Torso" / sdir_name
else:
raise ValueError(f"Unknown data set {ds}")

Expand Down Expand Up @@ -424,7 +427,7 @@ def test_petric(ds: int, num_iter: int, suffix: str = "", **kwargs):
)
else:
for ns in [25]:
for i in [5]:
for i in range(7):
test_petric(
ds=i,
num_iter=200,
Expand Down

0 comments on commit e6360c8

Please sign in to comment.