Skip to content

Commit a2518b7

Browse files
authored
Merge pull request #2072 from devitocodes/tens-transp
symbolics: evaluate transpose on elements of tensors in case of deriv…
2 parents 1d4b9ca + 2edeac6 commit a2518b7

File tree

7 files changed

+103
-70
lines changed

7 files changed

+103
-70
lines changed

devito/types/basic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,16 @@ def __init_finalize__(self, *args, **kwargs):
677677
def doit(self, **hint):
678678
return self
679679

680+
def transpose(self, inner=True):
681+
new = super().transpose()
682+
if inner:
683+
return new.applyfunc(lambda x: getattr(x, 'T', x))
684+
return new
685+
686+
def adjoint(self, inner=True):
687+
# Real valued adjoint is transpose
688+
return self.transpose(inner=inner)
689+
680690
def _eval_matrix_mul(self, other):
681691
"""
682692
Copy paste from sympy to avoid explicit call to sympy.Add

examples/seismic/elastic/operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs):
6161
# Particle velocity
6262
eq_v = v.dt - b * div(tau)
6363
# Stress
64-
e = (grad(v.forward) + grad(v.forward).T)
64+
e = (grad(v.forward) + grad(v.forward).transpose(inner=False))
6565
eq_tau = tau.dt - lam * diag(div(v.forward)) - mu * e
6666

6767
u_v = Eq(v.forward, model.damp * solve(eq_v, v.forward))

examples/seismic/tutorials/06_elastic.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
"\n",
138138
"# First order elastic wave equation\n",
139139
"pde_v = v.dt - ro * div(tau)\n",
140-
"pde_tau = tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) + grad(v.forward).T)\n",
140+
"pde_tau = tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) + grad(v.forward).transpose(inner=False))\n",
141141
"# Time update\n",
142142
"u_v = Eq(v.forward, solve(pde_v, v.forward))\n",
143143
"u_t = Eq(tau.forward, solve(pde_tau, tau.forward))\n",
@@ -304,7 +304,7 @@
304304
"\n",
305305
"# First order elastic wave equation\n",
306306
"pde_v = v.dt - ro * div(tau)\n",
307-
"pde_tau = tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) + grad(v.forward).T)\n",
307+
"pde_tau = tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) + grad(v.forward).transpose(inner=False))\n",
308308
"# Time update\n",
309309
"u_v = Eq(v.forward, solve(pde_v, v.forward))\n",
310310
"u_t = Eq(tau.forward, solve(pde_tau, tau.forward))\n",
@@ -445,7 +445,7 @@
445445
},
446446
"hide_input": false,
447447
"kernelspec": {
448-
"display_name": "Python 3",
448+
"display_name": "Python 3 (ipykernel)",
449449
"language": "python",
450450
"name": "python3"
451451
},
@@ -459,7 +459,7 @@
459459
"name": "python",
460460
"nbconvert_exporter": "python",
461461
"pygments_lexer": "ipython3",
462-
"version": "3.9.9"
462+
"version": "3.9.16"
463463
},
464464
"latex_envs": {
465465
"LaTeX_envs_menu_present": true,
@@ -485,5 +485,5 @@
485485
}
486486
},
487487
"nbformat": 4,
488-
"nbformat_minor": 2
488+
"nbformat_minor": 4
489489
}

examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Lines changed: 39 additions & 42 deletions
Large diffs are not rendered by default.

examples/seismic/tutorials/09_viscoelastic.ipynb

Lines changed: 16 additions & 21 deletions
Large diffs are not rendered by default.

examples/seismic/viscoelastic/operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs):
4848
pde_v = v.dt - b * div(tau)
4949
u_v = Eq(v.forward, model.damp * solve(pde_v, v.forward))
5050
# Strain
51-
e = grad(v.forward) + grad(v.forward).T
51+
e = grad(v.forward) + grad(v.forward).transpose(inner=False)
5252

5353
# Stress equations
5454
pde_tau = tau.dt - r.forward - l * t_ep / t_s * diag(div(v.forward)) - \

tests/test_tensors.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,37 @@ def test_vector_transpose(func1):
149149
assert np.all([f1[i] == f2[i] for i in range(3)])
150150

151151

152+
@pytest.mark.parametrize('func1', [VectorFunction, VectorTimeFunction])
153+
def test_vector_transpose_deriv(func1):
154+
grid = Grid(tuple([5]*3))
155+
f1 = func1(name="f1", grid=grid)
156+
f2 = f1.dx.T
157+
assert all([f2[i] == f1[i].dx.T for i in range(3)])
158+
159+
160+
@pytest.mark.parametrize('func1', [TensorFunction, TensorTimeFunction])
161+
def test_tensor_transpose_deriv(func1):
162+
grid = Grid(tuple([5]*3))
163+
f1 = func1(name="f1", grid=grid)
164+
f2 = f1.dx.T
165+
assert np.all([f2[i, j] == f1[j, i].dx.T for i in range(3) for j in range(3)])
166+
167+
168+
@pytest.mark.parametrize('func1', [TensorFunction, TensorTimeFunction,
169+
VectorFunction, VectorTimeFunction])
170+
def test_transpose_vs_T(func1):
171+
grid = Grid(tuple([5]*3))
172+
f1 = func1(name="f1", grid=grid)
173+
f2 = f1.dx.T
174+
f3 = f1.dx.transpose(inner=True)
175+
f4 = f1.dx.transpose(inner=False)
176+
# inner=True is the same as T
177+
assert f3 == f2
178+
# inner=False doesn't tranpose inner derivatives
179+
for f4i, f2i in zip(f4, f2):
180+
assert f4i == f2i.T
181+
182+
152183
@pytest.mark.parametrize('func1', [TensorFunction, TensorTimeFunction,
153184
VectorFunction, VectorTimeFunction])
154185
def test_tensor_fd(func1):

0 commit comments

Comments
 (0)