Skip to content

Commit

Permalink
add "simple" step size decay
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 30, 2024
1 parent 3caaaea commit 2cac36a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 20 additions & 4 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from sirf.contrib.partitioner.partitioner import partition_indices
from sirf.contrib.partitioner import partitioner

import numpy as np
Expand Down Expand Up @@ -63,12 +62,12 @@ class Submission(Algorithm):
def __init__(
self,
data: Dataset,
initial_step_size: float = 1.0,
step_size_factor: float = 1.0, # multiplicative factor to increase / decrease default step sizes
num_subsets: int | None = None,
update_objective_interval: int | None = None,
complete_gradient_epochs: None | list[int] = None,
precond_update_epochs: None | list[int] = None,
precond_hessian_factor: float = 16.0,
precond_hessian_factor: float = 32.0,
verbose: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -104,7 +103,8 @@ def __init__(
self.x = data.OSEM_image.clone()

self._update = 0
self._step_size = initial_step_size
self._step_size_factor = step_size_factor
self._step_size = self._step_size_factor * 2.0
self._subset_number_list = []
self._precond_hessian_factor = precond_hessian_factor

Expand Down Expand Up @@ -191,6 +191,19 @@ def __init__(
def epoch(self):
return self._update // self._num_subsets

def update_step_size(self):
if self.epoch <= 3:
self._step_size = self._step_size_factor * 2.0
elif self.epoch > 3 and self.epoch <= 8:
self._step_size = self._step_size_factor * 1.5
elif self.epoch > 8 and self.epoch <= 12:
self._step_size = self._step_size_factor * 1.0
else:
self._step_size = self._step_size_factor * 0.5

if self._verbose:
print(self._update, self.epoch, self._step_size)

def calc_precond(
self,
x: STIR.ImageData,
Expand Down Expand Up @@ -242,6 +255,9 @@ def update(self):
self._update % self._num_subsets == 0
) and self.epoch in self._precond_update_epochs

if self._update % self._num_subsets == 0:
self.update_step_size()

if update_precond:
if self._verbose:
print(f" {self._update}, updating preconditioner")
Expand Down
2 changes: 1 addition & 1 deletion test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,4 +418,4 @@ def test_petric(ds: int, num_iter: int, **kwargs):
)
else:
for i in range(4):
test_petric(ds=i, num_iter=300)
test_petric(ds=i, num_iter=200)

0 comments on commit 2cac36a

Please sign in to comment.