2
2
3
3
# TODO: add symmetric weights
4
4
5
- import cupy as xp
5
+ from utils import SmoothFunctionWithApproxHessian
6
6
7
7
8
- def get_d_s ( x ):
9
- """get backward and forward differences and sums for each dimension of an array x
10
- using "edge" padding to avoid boundary issues
8
+ def neighbor_difference_and_sum ( x , xp , padding = "edge" ):
9
+ """get backward and forward neighbor differences and sums for each dimension of an array x
10
+ using padding (by default in edge mode)
11
11
"""
12
- x_padded = xp .pad (x , 1 , mode = "edge" )
12
+ x_padded = xp .pad (x , 1 , mode = padding )
13
13
14
14
d = xp .zeros ((2 * x .ndim ,) + x .shape , dtype = x .dtype )
15
15
s = xp .zeros ((2 * x .ndim ,) + x .shape , dtype = x .dtype )
@@ -34,66 +34,124 @@ def get_d_s(x):
34
34
return d , s
35
35
36
36
37
- def rdp (x , gamma = 2.0 , eps = 1e-1 ):
38
- d , s = get_d_s (x )
39
- phi = s + gamma * xp .abs (d ) + eps
37
+ class RDP (SmoothFunctionWithApproxHessian ):
38
+ def __init__ (
39
+ self ,
40
+ in_shape ,
41
+ xp ,
42
+ dev ,
43
+ eps : float | None = None ,
44
+ gamma : float = 2.0 ,
45
+ padding : str = "edge" ,
46
+ ) -> None :
47
+ self ._gamma = gamma
40
48
41
- tmp = (d ** 2 ) / phi
49
+ if eps is None :
50
+ self ._eps = xp .finfo (xp .float32 ).eps
51
+ else :
52
+ self ._eps = eps
42
53
43
- return float ( tmp . sum ())
54
+ self . _padding = padding
44
55
56
+ self ._weights = None
45
57
46
- def rdp_grad (x , gamma = 2.0 , eps = 1e-1 ):
47
- d , s = get_d_s (x )
48
- phi = s + gamma * xp .abs (d ) + eps
58
+ super ().__init__ (in_shape = in_shape , xp = xp , dev = dev )
49
59
50
- tmp = d * (2 * phi - (d + gamma * xp .abs (d ))) / (phi ** 2 )
60
+ @property
61
+ def gamma (self ) -> float :
62
+ return self ._gamma
51
63
52
- return 2 * tmp .sum (axis = 0 )
64
+ @property
65
+ def eps (self ) -> float :
66
+ return self ._eps
53
67
68
+ @property
69
+ def weights (self ):
70
+ return self ._weights
54
71
55
- def rdp_diag_hess ( x , gamma = 2.0 , eps = 1e-1 ):
56
- d , s = get_d_s ( x )
57
- phi = s + gamma * xp . abs ( d ) + eps
72
+ @ weights . setter
73
+ def weights ( self , weights ) -> None :
74
+ self . _weights = weights
58
75
59
- tmp = (( s - d + eps ) ** 2 ) / ( phi ** 3 )
76
+ def _call ( self , x ) -> float :
60
77
61
- return 4 * tmp .sum (axis = 0 )
78
+ if float (self .xp .min (x )) < 0 :
79
+ return self .xp .inf
80
+
81
+ d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
82
+ phi = s + self .gamma * self .xp .abs (d ) + self .eps
83
+
84
+ tmp = (d ** 2 ) / phi
85
+
86
+ if self ._weights is not None :
87
+ tmp *= self ._weights
88
+
89
+ return float (self .xp .sum (tmp ))
90
+
91
+ def _gradient (self , x ):
92
+ d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
93
+ phi = s + self .gamma * self .xp .abs (d ) + self .eps
94
+
95
+ tmp = d * (2 * phi - (d + self .gamma * self .xp .abs (d ))) / (phi ** 2 )
96
+
97
+ if self ._weights is not None :
98
+ tmp *= self ._weights
99
+
100
+ return 2 * tmp .sum (axis = 0 )
101
+
102
+ def _approx_diag_hessian (self , x ):
103
+ d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
104
+ phi = s + self .gamma * self .xp .abs (d ) + self .eps
105
+
106
+ tmp = ((s - d + self .eps ) ** 2 ) / (phi ** 3 )
107
+
108
+ if self ._weights is not None :
109
+ tmp *= self ._weights
110
+
111
+ return 4 * tmp .sum (axis = 0 )
62
112
63
113
64
114
if __name__ == "__main__" :
65
- xp .random .seed (0 )
66
- x = xp .random .rand (5 , 6 ) + 1
115
+ import array_api_compat .numpy as np
116
+
117
+ np .set_printoptions (precision = 4 )
118
+
119
+ np .random .seed (0 )
120
+ x = np .random .rand (7 , 8 ) + 1
121
+
122
+ pad_mode = "edge"
123
+
124
+ weight_image = np .random .rand (* x .shape )
125
+ _ , weights = neighbor_difference_and_sum (weight_image , np , padding = pad_mode )
67
126
68
- gamma = 2.0
69
- eps = 0.1
127
+ rdp = RDP (in_shape = x .shape , xp = np , dev = "cpu" , gamma = 5.0 , eps = 0.01 )
70
128
71
- f = rdp (x , gamma = gamma , eps = eps )
72
- g = rdp_grad ( x , gamma = gamma , eps = eps )
73
- h = rdp_diag_hess ( x , gamma = gamma , eps = eps )
129
+ f = rdp (x )
130
+ g = rdp . gradient ( x )
131
+ h = rdp . approx_diag_hessian ( x )
74
132
75
- e = 1e-6
76
- g_num = xp .zeros_like (x )
77
- h_num = xp .zeros_like (x )
133
+ e = 1e-5
134
+ g_num = np .zeros_like (x )
135
+ h_num = np .zeros_like (x )
78
136
79
- for index in xp .ndindex (x .shape ):
137
+ for index in np .ndindex (x .shape ):
80
138
xxp = x .copy ()
81
139
xxp [index ] += e
82
140
83
141
xxm = x .copy ()
84
142
xxm [index ] -= e
85
143
86
- fp = rdp (xxp , gamma = gamma , eps = eps )
87
- fm = rdp (xxm , gamma = gamma , eps = eps )
144
+ fp = rdp (xxp )
145
+ fm = rdp (xxm )
88
146
89
147
g_num [index ] = (fp - fm ) / (2 * e )
90
148
91
149
# numerical evaliation of the diagonal Hessian
92
150
h_num [index ] = (fp - 2 * f + fm ) / (e ** 2 )
93
151
94
- print ("\n gradient / numerical gradient" )
152
+ print ("\n gradient / numerical gradient - should be 1 for all voxels " )
95
153
print (g / g_num )
96
- print ("\n diag hess / numerical diag hess" )
154
+ print ("\n diag hess / numerical diag hess - should be 1 for all but edge voxels " )
97
155
print (h / h_num )
98
156
99
- assert xp .all (xp .isclose (g , g_num ))
157
+ assert np .all (np .isclose (g , g_num ))
0 commit comments