Skip to content

Commit c207ded

Browse files
authored
Merge pull request #2754 from devitocodes/interp-order
api: add interpolation order api for staggered/off grid evaluation
2 parents ff62667 + 4c5272e commit c207ded

File tree

5 files changed

+74
-20
lines changed

5 files changed

+74
-20
lines changed

devito/finite_differences/differentiable.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Differentiable(sympy.Expr, Evaluable):
3636
# operators to be used
3737
_op_priority = sympy.Expr._op_priority + 1.
3838

39-
__rkwargs__ = ('space_order', 'time_order', 'indices')
39+
__rkwargs__ = ('space_order', 'interp_order', 'time_order', 'indices')
4040

4141
@cached_property
4242
def _functions(self):
@@ -54,6 +54,12 @@ def space_order(self):
5454
return min([getattr(i, 'space_order', 100) or 100 for i in self._args_diff],
5555
default=100)
5656

57+
@cached_property
58+
def interp_order(self):
59+
# Default 2 is a reasonable default for interpolation
60+
return min([getattr(i, 'interp_order', 2) or 2 for i in self._args_diff],
61+
default=2)
62+
5763
@cached_property
5864
def time_order(self):
5965
# Default 100 is for "infinitely" differentiable
@@ -1129,7 +1135,7 @@ def _(expr, x0, **kwargs):
11291135
x0_expr = {d: v for d, v in x0.items() if v is not expr.indices_ref[d]}
11301136
if x0_expr:
11311137
dims = tuple((d, 0) for d in x0_expr)
1132-
fd_o = tuple([2]*len(dims))
1138+
fd_o = tuple([expr.interp_order]*len(dims))
11331139
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)
11341140
else:
11351141
return expr

devito/types/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def _evaluate(self, **kwargs):
10231023
if not self._grid_map:
10241024
return self
10251025

1026+
io = self.interp_order
10261027
# Base function
10271028
if self._avg_mode == 'harmonic':
10281029
retval = 1 / self.function
@@ -1031,7 +1032,7 @@ def _evaluate(self, **kwargs):
10311032

10321033
# Apply interpolation from inner most dim
10331034
for d, i in self._grid_map.items():
1034-
retval = retval.diff(d, deriv_order=0, fd_order=2, x0={d: i})
1035+
retval = retval.diff(d, deriv_order=0, fd_order=io, x0={d: i})
10351036

10361037
# Evaluate. Since we used `self.function` it will be on the grid when
10371038
# evaluate is called again within FD

devito/types/dense.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,10 @@ class Function(DiscreteFunction):
941941
discretization order (`o`) as well as the number of points on
942942
the left/right sides of a generic point of interest for each
943943
SpaceDimension.
944+
interp_order: int, optional, default=2
945+
Order of the interpolation scheme used to evaluate the Function at
946+
non-grid points (e.g., when using a Function as a parameter to be
947+
evaluated at a staggered location).
944948
shape : tuple of ints, optional
945949
Shape of the domain region in grid points. Only necessary if `grid`
946950
isn't given.
@@ -1014,7 +1018,7 @@ class Function(DiscreteFunction):
10141018
is_autopaddable = True
10151019

10161020
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
1017-
('space_order', 'dimensions'))
1021+
('space_order', 'interp_order', 'dimensions'))
10181022

10191023
def _cache_meta(self):
10201024
# Attach additional metadata to self's cache entry
@@ -1036,6 +1040,16 @@ def __init_finalize__(self, *args, **kwargs):
10361040
else:
10371041
raise TypeError("Invalid `space_order`")
10381042

1043+
# Interpolation order
1044+
interp_order = kwargs.get('interp_order', 2)
1045+
if not is_integer(interp_order):
1046+
raise TypeError("`interp_order` must be an integer")
1047+
elif interp_order < 1:
1048+
raise ValueError("`interp_order` must be >= 2")
1049+
elif interp_order > self._space_order and self._space_order > 1:
1050+
raise ValueError("`interp_order` must be <= `space_order`")
1051+
self._interp_order = interp_order
1052+
10391053
# Acquire derivative shortcuts
10401054
if self is self.function:
10411055
self._fd = self.__fd_setup__()
@@ -1233,6 +1247,11 @@ def space_order(self):
12331247
"""The space order."""
12341248
return self._space_order
12351249

1250+
@property
1251+
def interp_order(self):
1252+
"""The interpolation order."""
1253+
return self._interp_order
1254+
12361255
def sum(self, p=None, dims=None):
12371256
"""
12381257
Generate a symbolic expression computing the sum of `p` points

examples/userapi/01_dsl.ipynb

Lines changed: 9 additions & 5 deletions
Large diffs are not rendered by default.

tests/test_differentiable.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sympy
44
import pytest
5+
import numpy as np
56

67
from devito import Function, Grid, Differentiable, NODE
78
from devito.finite_differences.differentiable import (Add, Mul, Pow, diffify,
@@ -96,28 +97,51 @@ def sp_diff(a, b):
9697

9798

9899
@pytest.mark.parametrize('ndim', [1, 2, 3])
99-
def test_avg_mode(ndim):
100+
@pytest.mark.parametrize('io', [None, 2, 4])
101+
def test_avg_mode(ndim, io):
100102
grid = Grid([11]*ndim)
101-
v = Function(name='v', grid=grid, staggered=grid.dimensions)
102-
a0 = Function(name="a0", grid=grid)
103-
a = Function(name="a", grid=grid, parameter=True)
104-
b = Function(name="b", grid=grid, parameter=True, avg_mode='harmonic')
103+
v = Function(name='v', grid=grid, staggered=grid.dimensions, space_order=4)
104+
kw = {'space_order': 4}
105+
if io is not None:
106+
kw['interp_order'] = io
107+
else:
108+
io = 2 # Default value
109+
110+
with pytest.raises(ValueError):
111+
# interp_order > space_order
112+
Function(name="a", grid=grid, parameter=True, interp_order=8, space_order=4)
113+
with pytest.raises(ValueError):
114+
# interp_order < 1
115+
Function(name="a", grid=grid, parameter=True, interp_order=0, space_order=4)
116+
with pytest.raises(TypeError):
117+
# interp_order not int
118+
Function(name="a", grid=grid, parameter=True, interp_order=2.5, space_order=4)
119+
120+
a0 = Function(name="a0", grid=grid, **kw)
121+
a = Function(name="a", grid=grid, parameter=True, **kw)
122+
b = Function(name="b", grid=grid, parameter=True, avg_mode='harmonic', **kw)
105123

106124
a0_avg = a0._eval_at(v)
107-
a_avg = a._eval_at(v).evaluate
108-
b_avg = b._eval_at(v).evaluate
125+
a_avg = a._eval_at(v).evaluate.simplify()
126+
b_avg = b._eval_at(v).evaluate.simplify()
109127

110128
assert a0_avg == a0
111129

112130
# Indices around the point at the center of a cell
113-
all_shift = tuple(product(*[[0, 1] for _ in range(ndim)]))
131+
idx = list(range(-io//2 + 1, io//2 + 1))
132+
all_shift = tuple(product(*[idx for _ in range(ndim)]))
133+
coeffs = {2: [0.5, 0.5], 4: [-1/16, 9/16, 9/16, -1/16]}[io]
134+
vars = ['i', 'j', 'k'][:ndim]
135+
rule = ','.join(vars) + '->' + ''.join(vars)
136+
ndcoeffs = np.einsum(rule, *([coeffs]*ndim))
114137
args = [{d: d + i * d.spacing for d, i in zip(grid.dimensions, s)} for s in all_shift]
115138

116139
# Default is arithmetic average
117-
assert sympy.simplify(a_avg - 0.5**ndim * sum(a.subs(arg) for arg in args)) == 0
140+
expected = sum(c * a.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args))
141+
assert sympy.simplify(a_avg - expected) == 0
118142

119143
# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
120-
expected = 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))
121-
assert sympy.simplify(1/b_avg.args[0] - expected) == 0
144+
expected = (sum(c / b.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args)))
145+
assert sympy.simplify(b_avg.args[0] - expected) == 0
122146
assert isinstance(b_avg, SafeInv)
123147
assert b_avg.base == b

0 commit comments

Comments
 (0)