Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Hashable, Literal, Callable, Iterable
from dataclasses import asdict, dataclass
from numbers import Number
from weakref import WeakValueDictionary

from ufl.algorithms import extract_arguments, replace
from ufl.domain import extract_unique_domain
Expand Down Expand Up @@ -64,6 +65,8 @@
"Interpolator"
)

_vom_cache = WeakValueDictionary()


@dataclass(kw_only=True)
class InterpolateOptions:
Expand Down Expand Up @@ -462,15 +465,45 @@ def __init__(self, expr: Interpolate):
# scalar fiat/finat element
self.dest_element = dest_element

def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
"""Return the symbolic ``Interpolate`` expressions for point evaluation and
re-ordering into the input-ordering VertexOnlyMesh.
@staticmethod
def _vom_cache_key(target_space, source_mesh, allow_missing_dofs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this just be a cached property?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

# The VOM used for cross-mesh interpolation depends on only these
return (
target_space,
source_mesh,
allow_missing_dofs,
)

@cached_property
def vom(self) -> MeshGeometry:
"""Access the VertexOnlyMesh consisting of the target space's dofs immersed in the
source space's mesh.

Returns
-------
tuple[Interpolate, Interpolate]
A tuple containing the point evaluation interpolation and the
input-ordering interpolation.
MeshGeometry
The VertexOnlyMesh.
"""
key = self._vom_cache_key(
self.target_space,
self.source_mesh,
self.allow_missing_dofs,
)
try:
return _vom_cache[key]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like this. The communicator is only implicitly part of this cache key so I think deadlocks could well happen.
Also you are using source_mesh as the key, creating a reference to it and meaning that it will never be cleared from the memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the correct way to do this then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Step 1 is choosing an object to tie the lifetime of the VoM to. It's probably best to use source_mesh here.

I have implemented a fair amount of caching support that would easily handle this in pyop3, so you may want to wait for that or just port the (small) amount of code over that you need.

except KeyError:
vom = self._create_vom()
_vom_cache[key] = vom
return vom

def _create_vom(self) -> MeshGeometry:
"""Create the VertexOnlyMesh consisting of the target space's dofs immersed in the
source space's mesh.

Returns
-------
MeshGeometry
The VertexOnlyMesh.

Raises
------
Expand All @@ -483,7 +516,7 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec))
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension)
try:
vom = VertexOnlyMesh(
source_vom = VertexOnlyMesh(
self.source_mesh,
dest_node_coords,
redundant=False,
Expand All @@ -495,7 +528,18 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
f"source function space on domain {self.source_mesh}. "
"This may be because the target mesh covers a larger domain than the "
"source mesh. To disable this error, set allow_missing_dofs=True.")
return source_vom

def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
"""Return the symbolic ``Interpolate`` expressions for point evaluation and
re-ordering into the input-ordering VertexOnlyMesh.

Returns
-------
tuple[Interpolate, Interpolate]
A tuple containing the point evaluation interpolation and the
input-ordering interpolation.
"""
# Get the correct type of function space
shape = self.target_space.ufl_function_space().value_shape
if len(shape) == 0:
Expand All @@ -507,11 +551,11 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
fs_type = partial(TensorFunctionSpace, shape=shape, symmetry=symmetry)

# Get expression for point evaluation at the dest_node_coords
P0DG_vom = fs_type(vom, "DG", 0)
P0DG_vom = fs_type(self.vom, "DG", 0)
point_eval = interpolate(self.operand, P0DG_vom)

# Interpolate into the input-ordering VOM
P0DG_vom_input_ordering = fs_type(vom.input_ordering, "DG", 0)
P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0)

arg = Argument(P0DG_vom, 0 if self.ufl_interpolate.is_adjoint else 1)
point_eval_input_ordering = interpolate(arg, P0DG_vom_input_ordering)
Expand Down