Skip to content

Commit

Permalink
add script to validate RDP gradient and hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 20, 2024
1 parent fed9aa0 commit 0f2806b
Showing 1 changed file with 129 additions and 0 deletions.
129 changes: 129 additions & 0 deletions simulations/validate_rdp_gradient_and_hessian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""validate gradient and diagonal Hessian of RDP"""

# TODO: add symmetric weights

import cupy as xp


def get_d_s(x):
"""get backward and forward differences and sums for each dimension of an array x
using "edge" padding
"""
x_padded = xp.pad(x, 1, mode="edge")

d = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)
s = xp.zeros((2 * x.ndim,) + x.shape, dtype=x.dtype)

for i in range(x.ndim):
# diff / sum with "backward" neighbor
sl = x.ndim * [slice(1, -1)]
sl[i] = slice(0, -2)
sl = tuple(sl)

d[2 * i, ...] = x - x_padded[sl]
s[2 * i, ...] = x + x_padded[sl]

# diff / sum with "forward" neighbor
sl = x.ndim * [slice(1, -1)]
sl[i] = slice(2, None)
sl = tuple(sl)

d[2 * i + 1, ...] = x - x_padded[sl]
s[2 * i + 1, ...] = x + x_padded[sl]

return d, s


def rdp(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps

tmp = (d**2) / phi

return float(tmp.sum())


def rdp_grad(x, gamma=2.0, eps=1e-1):
d, s = get_d_s(x)
phi = s + gamma * xp.abs(d) + eps

tmp = d * (2 * phi - (d + gamma * xp.abs(d))) / (phi**2)

return 2 * tmp.sum(axis=0)


# def rdp_potential(xj, xk, gamma=2.0, eps=1e-1):
# d = xj - xk
# s = xj + xk
# phi = s + gamma * xp.abs(d) + eps
#
# c = (d**2) / phi
# g = 2 * d * (2 * phi - (d + gamma * xp.abs(d))) / (phi**2)
#
# return c, g
#
#
# def rdp2d(x, gamma=2.0, eps=1e-1):
#
# n0, n1 = x.shape
#
# val = 0.0
# grad = xp.zeros_like(x)
#
# for j in range(n0):
# for k in range(n1):
# jm = j - 1
# jp = j + 1
# km = k - 1
# kp = k + 1
#
# x0 = x[j, k]
#
# if jm >= 0:
# c, g = rdp_potential(x0, x[jm, k], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if jp < n0:
# c, g = rdp_potential(x0, x[jp, k], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if km >= 0:
# c, g = rdp_potential(x0, x[j, km], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
# if kp < n1:
# c, g = rdp_potential(x0, x[j, kp], gamma=gamma, eps=eps)
# val += c
# grad[j, k] += g
#
# return val, grad


if __name__ == "__main__":
xx = xp.arange(2 * 3 * 4, dtype=xp.float64).reshape(2, 3, 4) + 1

gamma = 2.0
eps = 0.1

h = 1e-7

f = rdp(xx, gamma=gamma, eps=eps)
g = rdp_grad(xx, gamma=gamma, eps=eps)

g_num = xp.zeros_like(xx)

for i in range(xx.shape[0]):
for j in range(xx.shape[1]):
for k in range(xx.shape[2]):
xxp = xx.copy()
xxp[i, j, k] += h

xxm = xx.copy()
xxm[i, j, k] -= h

fp = rdp(xxp, gamma=gamma, eps=eps)
fm = rdp(xxm, gamma=gamma, eps=eps)

g_num[i, j, k] = (fp - fm) / (2 * h)

assert xp.all(xp.isclose(g, g_num))

0 comments on commit 0f2806b

Please sign in to comment.