@@ -305,147 +305,147 @@ def _approx_diag_hessian(self, x: Array) -> Array:
305
305
return diag_hes
306
306
307
307
308
- def neighbor_difference_and_sum (
309
- x : Array , xp : ModuleType , padding : str = "edge"
310
- ) -> tuple [Array , Array ]:
311
- """get differences and sums with nearest neighbors for an n-dimensional array x
312
- using padding (by default in edge mode)
313
- a x.ndim*(3,) neighborhood around each element is used
314
- """
315
- x_padded = xp .pad (x , 1 , mode = padding )
316
-
317
- # number of nearest neighbors
318
- num_neigh = 3 ** x .ndim - 1
319
-
320
- # array for differences and sums with nearest neighbors
321
- d = xp .zeros ((num_neigh ,) + x .shape , dtype = x .dtype )
322
- s = xp .zeros ((num_neigh ,) + x .shape , dtype = x .dtype )
323
-
324
- for i , ind in enumerate (xp .ndindex (x .ndim * (3 ,))):
325
- if i != (num_neigh // 2 ):
326
- sl = []
327
- for j in ind :
328
- if j - 2 < 0 :
329
- sl .append (slice (j , j - 2 ))
330
- else :
331
- sl .append (slice (j , None ))
332
- sl = tuple (sl )
333
-
334
- if i < num_neigh // 2 :
335
- d [i ] = x - x_padded [sl ]
336
- s [i ] = x + x_padded [sl ]
337
- else :
338
- d [i - 1 ] = x - x_padded [sl ]
339
- s [i - 1 ] = x + x_padded [sl ]
340
-
341
- return d , s
342
-
343
-
344
- def neighbor_product (x : Array , xp : ModuleType , padding : str = "edge" ) -> Array :
345
- """get backward and forward neighbor products for each dimension of an array x
346
- using padding (by default in edge mode)
347
- """
348
- x_padded = xp .pad (x , 1 , mode = padding )
349
-
350
- # number of nearest neighbors
351
- num_neigh = 3 ** x .ndim - 1
352
-
353
- # array for differences and sums with nearest neighbors
354
- p = xp .zeros ((num_neigh ,) + x .shape , dtype = x .dtype )
355
-
356
- for i , ind in enumerate (xp .ndindex (x .ndim * (3 ,))):
357
- if i != (num_neigh // 2 ):
358
- sl = []
359
- for j in ind :
360
- if j - 2 < 0 :
361
- sl .append (slice (j , j - 2 ))
362
- else :
363
- sl .append (slice (j , None ))
364
- sl = tuple (sl )
365
-
366
- if i < num_neigh // 2 :
367
- p [i ] = x * x_padded [sl ]
368
- else :
369
- p [i - 1 ] = x * x_padded [sl ]
370
-
371
- return p
372
-
373
-
374
- class RDP (SmoothFunctionWithApproxHessian ):
375
- def __init__ (
376
- self ,
377
- in_shape : tuple [int , ...],
378
- xp : ModuleType ,
379
- dev : str ,
380
- eps : float | None = None ,
381
- gamma : float = 2.0 ,
382
- padding : str = "edge" ,
383
- ) -> None :
384
- self ._gamma = gamma
385
-
386
- if eps is None :
387
- self ._eps = xp .finfo (xp .float32 ).eps
388
- else :
389
- self ._eps = eps
390
-
391
- self ._padding = padding
392
-
393
- self ._weights = None
394
-
395
- super ().__init__ (in_shape = in_shape , xp = xp , dev = dev )
396
-
397
- @property
398
- def gamma (self ) -> float :
399
- return self ._gamma
400
-
401
- @property
402
- def eps (self ) -> float :
403
- return self ._eps
404
-
405
- @property
406
- def weights (self ) -> Array | None :
407
- return self ._weights
408
-
409
- @weights .setter
410
- def weights (self , weights : Array ) -> None :
411
- self ._weights = weights
412
-
413
- def _call (self , x : Array ) -> float :
414
-
415
- if float (self .xp .min (x )) < 0 :
416
- return self .xp .inf
417
-
418
- d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
419
- phi = s + self .gamma * self .xp .abs (d ) + self .eps
420
-
421
- tmp = (d ** 2 ) / phi
422
-
423
- if self ._weights is not None :
424
- tmp *= self ._weights
425
-
426
- return float (self .xp .sum (tmp ))
427
-
428
- def _gradient (self , x : Array ) -> Array :
429
- d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
430
- phi = s + self .gamma * self .xp .abs (d ) + self .eps
431
-
432
- tmp = d * (2 * phi - (d + self .gamma * self .xp .abs (d ))) / (phi ** 2 )
433
-
434
- if self ._weights is not None :
435
- tmp *= self ._weights
436
-
437
- return 2 * tmp .sum (axis = 0 )
438
-
439
- def _approx_diag_hessian (self , x : Array ) -> Array :
440
- d , s = neighbor_difference_and_sum (x , self .xp , padding = self ._padding )
441
- phi = s + self .gamma * self .xp .abs (d ) + self .eps
442
-
443
- tmp = ((s - d + self .eps ) ** 2 ) / (phi ** 3 )
444
-
445
- if self ._weights is not None :
446
- tmp *= self ._weights
447
-
448
- return 4 * tmp .sum (axis = 0 )
308
+ # def neighbor_difference_and_sum(
309
+ # x: Array, xp: ModuleType, padding: str = "edge"
310
+ # ) -> tuple[Array, Array]:
311
+ # """get differences and sums with nearest neighbors for an n-dimensional array x
312
+ # using padding (by default in edge mode)
313
+ # a x.ndim*(3,) neighborhood around each element is used
314
+ # """
315
+ # x_padded = xp.pad(x, 1, mode=padding)
316
+ #
317
+ # # number of nearest neighbors
318
+ # num_neigh = 3**x.ndim - 1
319
+ #
320
+ # # array for differences and sums with nearest neighbors
321
+ # d = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
322
+ # s = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
323
+ #
324
+ # for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
325
+ # if i != (num_neigh // 2):
326
+ # sl = []
327
+ # for j in ind:
328
+ # if j - 2 < 0:
329
+ # sl.append(slice(j, j - 2))
330
+ # else:
331
+ # sl.append(slice(j, None))
332
+ # sl = tuple(sl)
333
+ #
334
+ # if i < num_neigh // 2:
335
+ # d[i] = x - x_padded[sl]
336
+ # s[i] = x + x_padded[sl]
337
+ # else:
338
+ # d[i - 1] = x - x_padded[sl]
339
+ # s[i - 1] = x + x_padded[sl]
340
+ #
341
+ # return d, s
342
+ #
343
+ #
344
+ # def neighbor_product(x: Array, xp: ModuleType, padding: str = "edge") -> Array:
345
+ # """get backward and forward neighbor products for each dimension of an array x
346
+ # using padding (by default in edge mode)
347
+ # """
348
+ # x_padded = xp.pad(x, 1, mode=padding)
349
+ #
350
+ # # number of nearest neighbors
351
+ # num_neigh = 3**x.ndim - 1
352
+ #
353
+ # # array for differences and sums with nearest neighbors
354
+ # p = xp.zeros((num_neigh,) + x.shape, dtype=x.dtype)
355
+ #
356
+ # for i, ind in enumerate(xp.ndindex(x.ndim * (3,))):
357
+ # if i != (num_neigh // 2):
358
+ # sl = []
359
+ # for j in ind:
360
+ # if j - 2 < 0:
361
+ # sl.append(slice(j, j - 2))
362
+ # else:
363
+ # sl.append(slice(j, None))
364
+ # sl = tuple(sl)
365
+ #
366
+ # if i < num_neigh // 2:
367
+ # p[i] = x * x_padded[sl]
368
+ # else:
369
+ # p[i - 1] = x * x_padded[sl]
370
+ #
371
+ # return p
372
+ #
373
+ #
374
+ # class RDP(SmoothFunctionWithApproxHessian):
375
+ # def __init__(
376
+ # self,
377
+ # in_shape: tuple[int, ...],
378
+ # xp: ModuleType,
379
+ # dev: str,
380
+ # eps: float | None = None,
381
+ # gamma: float = 2.0,
382
+ # padding: str = "edge",
383
+ # ) -> None:
384
+ # self._gamma = gamma
385
+ #
386
+ # if eps is None:
387
+ # self._eps = xp.finfo(xp.float32).eps
388
+ # else:
389
+ # self._eps = eps
390
+ #
391
+ # self._padding = padding
392
+ #
393
+ # self._weights = None
394
+ #
395
+ # super().__init__(in_shape=in_shape, xp=xp, dev=dev)
396
+ #
397
+ # @property
398
+ # def gamma(self) -> float:
399
+ # return self._gamma
400
+ #
401
+ # @property
402
+ # def eps(self) -> float:
403
+ # return self._eps
404
+ #
405
+ # @property
406
+ # def weights(self) -> Array | None:
407
+ # return self._weights
408
+ #
409
+ # @weights.setter
410
+ # def weights(self, weights: Array) -> None:
411
+ # self._weights = weights
412
+ #
413
+ # def _call(self, x: Array) -> float:
414
+ #
415
+ # if float(self.xp.min(x)) < 0:
416
+ # return self.xp.inf
417
+ #
418
+ # d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
419
+ # phi = s + self.gamma * self.xp.abs(d) + self.eps
420
+ #
421
+ # tmp = (d**2) / phi
422
+ #
423
+ # if self._weights is not None:
424
+ # tmp *= self._weights
425
+ #
426
+ # return float(self.xp.sum(tmp))
427
+ #
428
+ # def _gradient(self, x: Array) -> Array:
429
+ # d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
430
+ # phi = s + self.gamma * self.xp.abs(d) + self.eps
431
+ #
432
+ # tmp = d * (2 * phi - (d + self.gamma * self.xp.abs(d))) / (phi**2)
433
+ #
434
+ # if self._weights is not None:
435
+ # tmp *= self._weights
436
+ #
437
+ # return 2 * tmp.sum(axis=0)
438
+ #
439
+ # def _approx_diag_hessian(self, x: Array) -> Array:
440
+ # d, s = neighbor_difference_and_sum(x, self.xp, padding=self._padding)
441
+ # phi = s + self.gamma * self.xp.abs(d) + self.eps
442
+ #
443
+ # tmp = ((s - d + self.eps) ** 2) / (phi**3)
444
+ #
445
+ # if self._weights is not None:
446
+ # tmp *= self._weights
447
+ #
448
+ # return 4 * tmp.sum(axis=0)
449
449
450
450
451
451
class L2DataFidelity (SmoothFunction ):
@@ -1129,14 +1129,14 @@ def split_fwd_model(
1129
1129
def rdp_preconditioner (
1130
1130
x : Array ,
1131
1131
adjoint_ones : Array ,
1132
- prior : SmoothFunctionWithApproxHessian ,
1132
+ prior ,
1133
1133
version : int = 1 ,
1134
1134
delta : float = 1e-6 ,
1135
1135
) -> Array :
1136
1136
if version == 1 :
1137
1137
precond = (x + delta ) / adjoint_ones
1138
1138
elif version == 2 :
1139
- precond = (x + delta ) / (adjoint_ones + prior .approx_diag_hessian (x ) * x )
1139
+ precond = (x + delta ) / (adjoint_ones + prior .diag_hessian (x ) * x )
1140
1140
else :
1141
1141
raise ValueError ("precond_version must be 1 or 2" )
1142
1142
0 commit comments