Skip to content

Commit

Permalink
add test for diag of Hessian of RDP
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 20, 2024
1 parent 0f2806b commit 868f604
Showing 1 changed file with 31 additions and 61 deletions.
92 changes: 31 additions & 61 deletions simulations/validate_rdp_gradient_and_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def get_d_s(x):
"""get backward and forward differences and sums for each dimension of an array x
using "edge" padding
using "edge" padding to avoid boundary issues
"""
x_padded = xp.pad(x, 1, mode="edge")

Expand Down Expand Up @@ -52,78 +52,48 @@ def rdp_grad(x, gamma=2.0, eps=1e-1):
return 2 * tmp.sum(axis=0)


# def rdp_potential(xj, xk, gamma=2.0, eps=1e-1):
# d = xj - xk
# s = xj + xk
# phi = s + gamma * xp.abs(d) + eps
#
# c = (d**2) / phi
# g = 2 * d * (2 * phi - (d + gamma * xp.abs(d))) / (phi**2)
#
# return c, g
#
#
# def rdp2d(x, gamma=2.0, eps=1e-1):
#
# n0, n1 = x.shape
#
# val = 0.0
# grad = xp.zeros_like(x)
#
# for j in range(n0):
# for k in range(n1):
# jm = j - 1
# jp = j + 1
# km = k - 1
# kp = k + 1
#
# x0 = x[j, k]
#
# if jm >= 0:
# c, g = rdp_potential(x0, x[jm, k], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if jp < n0:
# c, g = rdp_potential(x0, x[jp, k], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if km >= 0:
# c, g = rdp_potential(x0, x[j, km], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if kp < n1:
# c, g = rdp_potential(x0, x[j, kp], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
#
# return val, grad
def rdp_diag_hess(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps

tmp = ((s - d + eps) ** 2) / (phi**3)

return 4 * tmp.sum(axis=0)


if __name__ == "__main__":
xx = xp.arange(2 * 3 * 4, dtype=xp.float64).reshape(2, 3, 4) + 1
xp.random.seed(0)
x = xp.random.rand(5, 6) + 1

gamma = 2.0
eps = 0.1

h = 1e-7
f = rdp(x, gamma=gamma, eps=eps)
g = rdp_grad(x, gamma=gamma, eps=eps)
h = rdp_diag_hess(x, gamma=gamma, eps=eps)

e = 1e-6
g_num = xp.zeros_like(x)
h_num = xp.zeros_like(x)

f = rdp(xx, gamma=gamma, eps=eps)
g = rdp_grad(xx, gamma=gamma, eps=eps)
for index in xp.ndindex(x.shape):
xxp = x.copy()
xxp[index] += e

g_num = xp.zeros_like(xx)
xxm = x.copy()
xxm[index] -= e

for i in range(xx.shape[0]):
for j in range(xx.shape[1]):
for k in range(xx.shape[2]):
xxp = xx.copy()
xxp[i, j, k] += h
fp = rdp(xxp, gamma=gamma, eps=eps)
fm = rdp(xxm, gamma=gamma, eps=eps)

xxm = xx.copy()
xxm[i, j, k] -= h
g_num[index] = (fp - fm) / (2 * e)

fp = rdp(xxp, gamma=gamma, eps=eps)
fm = rdp(xxm, gamma=gamma, eps=eps)
# numerical evaliation of the diagonal Hessian
h_num[index] = (fp - 2 * f + fm) / (e**2)

g_num[i, j, k] = (fp - fm) / (2 * h)
print("\ngradient / numerical gradient")
print(g / g_num)
print("\ndiag hess / numerical diag hess")
print(h / h_num)

assert xp.all(xp.isclose(g, g_num))

0 comments on commit 868f604

Please sign in to comment.