Skip to content

Commit

Permalink
add new RDP class
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 20, 2024
1 parent 868f604 commit b4cdb8b
Showing 1 changed file with 94 additions and 36 deletions.
130 changes: 94 additions & 36 deletions simulations/validate_rdp_gradient_and_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

# TODO: add symmetric weights

import cupy as xp
from utils import SmoothFunctionWithApproxHessian


def get_d_s(x):
"""get backward and forward differences and sums for each dimension of an array x
using "edge" padding to avoid boundary issues
def neighbor_difference_and_sum(x, xp, padding="edge"):
"""get backward and forward neighbor differences and sums for each dimension of an array x
using padding (by default in edge mode)
"""
x_padded = xp.pad(x, 1, mode="edge")
x_padded = xp.pad(x, 1, mode=padding)

d = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)
s = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)
Expand All @@ -34,66 +34,124 @@ def get_d_s(x):
return d, s


def rdp(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps
class RDP(SmoothFunctionWithApproxHessian):
def __init__(
self,
in_shape,
xp,
dev,
eps: float | None = None,
gamma: float = 2.0,
padding: str = "edge",
) -> None:
self._gamma = gamma

tmp = (d**2) / phi
if eps is None:
self._eps = xp.finfo(xp.float32).eps
else:
self._eps = eps

return float(tmp.sum())
self._padding = padding

self._weights = None

def rdp_grad(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps
super().__init__(in_shape=in_shape, xp=xp, dev=dev)

tmp = d * (2 * phi - (d + gamma * xp.abs(d))) / (phi**2)
@property
def gamma(self) -> float:
return self._gamma

return 2 * tmp.sum(axis=0)
@property
def eps(self) -> float:
return self._eps

@property
def weights(self):
return self._weights

def rdp_diag_hess(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps
@weights.setter
def weights(self, weights) -> None:
self._weights = weights

tmp = ((s - d + eps) ** 2) / (phi**3)
def _call(self, x) -> float:

return 4 * tmp.sum(axis=0)
if float(self.xp.min(x)) < 0:
return self.xp.inf

d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
phi = s + self.gamma * self.xp.abs(d) + self.eps

tmp = (d**2) / phi

if self._weights is not None:
tmp *= self._weights

return float(self.xp.sum(tmp))

def _gradient(self, x):
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
phi = s + self.gamma * self.xp.abs(d) + self.eps

tmp = d * (2 * phi - (d + self.gamma * self.xp.abs(d))) / (phi**2)

if self._weights is not None:
tmp *= self._weights

return 2 * tmp.sum(axis=0)

def _approx_diag_hessian(self, x):
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
phi = s + self.gamma * self.xp.abs(d) + self.eps

tmp = ((s - d + self.eps) ** 2) / (phi**3)

if self._weights is not None:
tmp *= self._weights

return 4 * tmp.sum(axis=0)


if __name__ == "__main__":
xp.random.seed(0)
x = xp.random.rand(5, 6) + 1
import array_api_compat.numpy as np

np.set_printoptions(precision=4)

np.random.seed(0)
x = np.random.rand(7, 8) + 1

pad_mode = "edge"

weight_image = np.random.rand(*x.shape)
_, weights = neighbor_difference_and_sum(weight_image, np, padding=pad_mode)

gamma = 2.0
eps = 0.1
rdp = RDP(in_shape=x.shape, xp=np, dev="cpu", gamma=5.0, eps=0.01)

f = rdp(x, gamma=gamma, eps=eps)
g = rdp_grad(x, gamma=gamma, eps=eps)
h = rdp_diag_hess(x, gamma=gamma, eps=eps)
f = rdp(x)
g = rdp.gradient(x)
h = rdp.approx_diag_hessian(x)

e = 1e-6
g_num = xp.zeros_like(x)
h_num = xp.zeros_like(x)
e = 1e-5
g_num = np.zeros_like(x)
h_num = np.zeros_like(x)

for index in xp.ndindex(x.shape):
for index in np.ndindex(x.shape):
xxp = x.copy()
xxp[index] += e

xxm = x.copy()
xxm[index] -= e

fp = rdp(xxp, gamma=gamma, eps=eps)
fm = rdp(xxm, gamma=gamma, eps=eps)
fp = rdp(xxp)
fm = rdp(xxm)

g_num[index] = (fp - fm) / (2 * e)

# numerical evaliation of the diagonal Hessian
h_num[index] = (fp - 2 * f + fm) / (e**2)

print("\ngradient / numerical gradient")
print("\ngradient / numerical gradient - should be 1 for all voxels")
print(g / g_num)
print("\ndiag hess / numerical diag hess")
print("\ndiag hess / numerical diag hess - should be 1 for all but edge voxels")
print(h / h_num)

assert xp.all(xp.isclose(g, g_num))
assert np.all(np.isclose(g, g_num))

0 comments on commit b4cdb8b

Please sign in to comment.