Skip to content

Commit b4cdb8b

Browse files
committed
add new RDP class
1 parent 868f604 commit b4cdb8b

File tree

1 file changed

+94
-36
lines changed

1 file changed

+94
-36
lines changed

simulations/validate_rdp_gradient_and_hessian.py

+94-36
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
# TODO: add symmetric weights
44

5-
import cupy as xp
5+
from utils import SmoothFunctionWithApproxHessian
66

77

8-
def get_d_s(x):
9-
"""get backward and forward differences and sums for each dimension of an array x
10-
using "edge" padding to avoid boundary issues
8+
def neighbor_difference_and_sum(x, xp, padding="edge"):
9+
"""get backward and forward neighbor differences and sums for each dimension of an array x
10+
using padding (by default in edge mode)
1111
"""
12-
x_padded = xp.pad(x, 1, mode="edge")
12+
x_padded = xp.pad(x, 1, mode=padding)
1313

1414
d = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)
1515
s = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)
@@ -34,66 +34,124 @@ def get_d_s(x):
3434
return d, s
3535

3636

37-
def rdp(x, gamma=2.0, eps=1e-1):
38-
d, s = get_d_s(x)
39-
phi = s + gamma * xp.abs(d) + eps
37+
class RDP(SmoothFunctionWithApproxHessian):
38+
def __init__(
39+
self,
40+
in_shape,
41+
xp,
42+
dev,
43+
eps: float | None = None,
44+
gamma: float = 2.0,
45+
padding: str = "edge",
46+
) -> None:
47+
self._gamma = gamma
4048

41-
tmp = (d**2) / phi
49+
if eps is None:
50+
self._eps = xp.finfo(xp.float32).eps
51+
else:
52+
self._eps = eps
4253

43-
return float(tmp.sum())
54+
self._padding = padding
4455

56+
self._weights = None
4557

46-
def rdp_grad(x, gamma=2.0, eps=1e-1):
47-
d, s = get_d_s(x)
48-
phi = s + gamma * xp.abs(d) + eps
58+
super().__init__(in_shape=in_shape, xp=xp, dev=dev)
4959

50-
tmp = d * (2 * phi - (d + gamma * xp.abs(d))) / (phi**2)
60+
@property
61+
def gamma(self) -> float:
62+
return self._gamma
5163

52-
return 2 * tmp.sum(axis=0)
64+
@property
65+
def eps(self) -> float:
66+
return self._eps
5367

68+
@property
69+
def weights(self):
70+
return self._weights
5471

55-
def rdp_diag_hess(x, gamma=2.0, eps=1e-1):
56-
d, s = get_d_s(x)
57-
phi = s + gamma * xp.abs(d) + eps
72+
@weights.setter
73+
def weights(self, weights) -> None:
74+
self._weights = weights
5875

59-
tmp = ((s - d + eps) ** 2) / (phi**3)
76+
def _call(self, x) -> float:
6077

61-
return 4 * tmp.sum(axis=0)
78+
if float(self.xp.min(x)) < 0:
79+
return self.xp.inf
80+
81+
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
82+
phi = s + self.gamma * self.xp.abs(d) + self.eps
83+
84+
tmp = (d**2) / phi
85+
86+
if self._weights is not None:
87+
tmp *= self._weights
88+
89+
return float(self.xp.sum(tmp))
90+
91+
def _gradient(self, x):
92+
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
93+
phi = s + self.gamma * self.xp.abs(d) + self.eps
94+
95+
tmp = d * (2 * phi - (d + self.gamma * self.xp.abs(d))) / (phi**2)
96+
97+
if self._weights is not None:
98+
tmp *= self._weights
99+
100+
return 2 * tmp.sum(axis=0)
101+
102+
def _approx_diag_hessian(self, x):
103+
d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
104+
phi = s + self.gamma * self.xp.abs(d) + self.eps
105+
106+
tmp = ((s - d + self.eps) ** 2) / (phi**3)
107+
108+
if self._weights is not None:
109+
tmp *= self._weights
110+
111+
return 4 * tmp.sum(axis=0)
62112

63113

64114
if __name__ == "__main__":
65-
xp.random.seed(0)
66-
x = xp.random.rand(5, 6) + 1
115+
import array_api_compat.numpy as np
116+
117+
np.set_printoptions(precision=4)
118+
119+
np.random.seed(0)
120+
x = np.random.rand(7, 8) + 1
121+
122+
pad_mode = "edge"
123+
124+
weight_image = np.random.rand(*x.shape)
125+
_, weights = neighbor_difference_and_sum(weight_image, np, padding=pad_mode)
67126

68-
gamma = 2.0
69-
eps = 0.1
127+
rdp = RDP(in_shape=x.shape, xp=np, dev="cpu", gamma=5.0, eps=0.01)
70128

71-
f = rdp(x, gamma=gamma, eps=eps)
72-
g = rdp_grad(x, gamma=gamma, eps=eps)
73-
h = rdp_diag_hess(x, gamma=gamma, eps=eps)
129+
f = rdp(x)
130+
g = rdp.gradient(x)
131+
h = rdp.approx_diag_hessian(x)
74132

75-
e = 1e-6
76-
g_num = xp.zeros_like(x)
77-
h_num = xp.zeros_like(x)
133+
e = 1e-5
134+
g_num = np.zeros_like(x)
135+
h_num = np.zeros_like(x)
78136

79-
for index in xp.ndindex(x.shape):
137+
for index in np.ndindex(x.shape):
80138
xxp = x.copy()
81139
xxp[index] += e
82140

83141
xxm = x.copy()
84142
xxm[index] -= e
85143

86-
fp = rdp(xxp, gamma=gamma, eps=eps)
87-
fm = rdp(xxm, gamma=gamma, eps=eps)
144+
fp = rdp(xxp)
145+
fm = rdp(xxm)
88146

89147
g_num[index] = (fp - fm) / (2 * e)
90148

91149
# numerical evaliation of the diagonal Hessian
92150
h_num[index] = (fp - 2 * f + fm) / (e**2)
93151

94-
print("\ngradient / numerical gradient")
152+
print("\ngradient / numerical gradient - should be 1 for all voxels")
95153
print(g / g_num)
96-
print("\ndiag hess / numerical diag hess")
154+
print("\ndiag hess / numerical diag hess - should be 1 for all but edge voxels")
97155
print(h / h_num)
98156

99-
assert xp.all(xp.isclose(g, g_num))
157+
assert np.all(np.isclose(g, g_num))

0 commit comments

Comments
 (0)