Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 27, 2024
1 parent e02695b commit ceafe95
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 26 deletions.
9 changes: 4 additions & 5 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(

# setup the initial image as a slightly smoothed version of the OSEM image
self.x = data.OSEM_image.clone()
# sig = 5.0 / (2.35 * np.array(data.OSEM_image.spacing))
# sig = 3.0 / (2.35 * np.array(data.OSEM_image.spacing))
# self.x.fill(gaussian_filter(self.x.as_array(), sig))

self._verbose = verbose
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
self._complete_gradient_epochs = complete_gradient_epochs

if precond_update_epochs is None:
self._precond_update_epochs: list[int] = [1]
self._precond_update_epochs: list[int] = [1, 2, 3]
else:
self._precond_update_epochs = precond_update_epochs

Expand Down Expand Up @@ -163,10 +163,9 @@ def calc_precond(
) -> STIR.ImageData:

delta = delta_rel * x.max()
prior_diag_hess = x.get_uniform_copy(0)

prior_diag_hess = 0.0 * x
tmp = self._python_prior.diag_hessian(x.as_array())
prior_diag_hess.fill(1 / (tmp + delta_rel))
prior_diag_hess.fill(self._python_prior.diag_hessian(x.as_array()))

return (x + delta) / (self._adjoint_ones + prior_diag_hess * x)

Expand Down
48 changes: 28 additions & 20 deletions test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,14 @@ def test_petric(
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
metrics = [MetricsWithTimeout(outdir=outdir)]
elif ds == 3:
srcdir = SRCDIR / "Siemens_mMR_ACR"
outdir = (
OUTDIR
/ "Siemens_mMR_ACR"
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
metrics = [MetricsWithTimeout(outdir=outdir)]
else:
raise ValueError(f"Unknown data set {ds}")

Expand Down Expand Up @@ -432,10 +440,10 @@ def test_petric(
#######################################################

algo = Submission(
data,
data=data,
initial_step_size=step_size,
num_subsets=num_subsets,
update_objective_interval=metric_period,
num_subsets=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)
Expand Down Expand Up @@ -481,30 +489,30 @@ def test_petric(
precond_update_epochs=precond_update_epochs,
)
else:
for step_size in [0.3, 0.5, 1.0]:
for step_size in [0.3, 1.0]:
# data set 0 "mMR_NEMA_IQ" - num views 252
for num_subsets in [9]:
for num_subsets in [14, 28]:
test_petric(
step_size,
0,
200,
num_subsets,
int(args["--metric_period"]),
step_size=step_size,
ds=0,
num_iter=200,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)

## data set 1 "neuro LF" - num views 128
# for num_subsets in [8, 32, 64]:
# test_petric(
# step_size,
# 1,
# 200,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )
## data set 1 "neuro LF" - num views 128
# for num_subsets in [8, 16]:
# test_petric(
# step_size,
# 1,
# 200,
# num_subsets,
# int(args["--metric_period"]),
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# )

## data set 2 "vision" - num views 50
# for num_subsets in [5, 10, 25]:
Expand Down
7 changes: 6 additions & 1 deletion test_rdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ def get_image(fname):
OUTDIR / "Vision600_thorax",
[MetricsWithTimeout(outdir=OUTDIR / "Vision600_thorax")],
),
(
SRCDIR / "Siemens_mMR_ACR",
OUTDIR / "Siemens_mMR_ACR",
[MetricsWithTimeout(outdir=OUTDIR / "Siemens_mMR_ACR")],
),
]
else:
log.warning("Source directory does not exist: %s", SRCDIR)
Expand All @@ -375,7 +380,7 @@ def get_image(fname):
import array_api_compat.numpy as xp
from scipy.ndimage import gaussian_filter

for srcdir, outdir, metrics in data_dirs_metrics[0:1]:
for srcdir, outdir, metrics in data_dirs_metrics[3:4]:
print(srcdir)
data = get_data(srcdir=srcdir, outdir=outdir)

Expand Down

0 comments on commit ceafe95

Please sign in to comment.