diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f71a08445e..0b23a4822a 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -7,6 +7,7 @@ from typing import Hashable, Literal, Callable, Iterable from dataclasses import asdict, dataclass from numbers import Number +from weakref import WeakValueDictionary from ufl.algorithms import extract_arguments, replace from ufl.domain import extract_unique_domain @@ -64,6 +65,8 @@ "Interpolator" ) +_vom_cache = WeakValueDictionary() + @dataclass(kw_only=True) class InterpolateOptions: @@ -462,15 +465,45 @@ def __init__(self, expr: Interpolate): # scalar fiat/finat element self.dest_element = dest_element - def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: - """Return the symbolic ``Interpolate`` expressions for point evaluation and - re-ordering into the input-ordering VertexOnlyMesh. + @staticmethod + def _vom_cache_key(target_space, source_mesh, allow_missing_dofs): + # The VOM used for cross-mesh interpolation depends on only these + return ( + target_space, + source_mesh, + allow_missing_dofs, + ) + + @cached_property + def vom(self) -> MeshGeometry: + """Access the VertexOnlyMesh consisting of the target space's dofs immersed in the + source space's mesh. Returns ------- - tuple[Interpolate, Interpolate] - A tuple containing the point evaluation interpolation and the - input-ordering interpolation. + MeshGeometry + The VertexOnlyMesh. + """ + key = self._vom_cache_key( + self.target_space, + self.source_mesh, + self.allow_missing_dofs, + ) + try: + return _vom_cache[key] + except KeyError: + vom = self._create_vom() + _vom_cache[key] = vom + return vom + + def _create_vom(self) -> MeshGeometry: + """Create the VertexOnlyMesh consisting of the target space's dofs immersed in the + source space's mesh. + + Returns + ------- + MeshGeometry + The VertexOnlyMesh. Raises ------ @@ -483,7 +516,7 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec)) dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension) try: - vom = VertexOnlyMesh( + source_vom = VertexOnlyMesh( self.source_mesh, dest_node_coords, redundant=False, @@ -495,7 +528,18 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: f"source function space on domain {self.source_mesh}. " "This may be because the target mesh covers a larger domain than the " "source mesh. To disable this error, set allow_missing_dofs=True.") + return source_vom + + def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: + """Return the symbolic ``Interpolate`` expressions for point evaluation and + re-ordering into the input-ordering VertexOnlyMesh. + Returns + ------- + tuple[Interpolate, Interpolate] + A tuple containing the point evaluation interpolation and the + input-ordering interpolation. + """ # Get the correct type of function space shape = self.target_space.ufl_function_space().value_shape if len(shape) == 0: @@ -507,11 +551,11 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: fs_type = partial(TensorFunctionSpace, shape=shape, symmetry=symmetry) # Get expression for point evaluation at the dest_node_coords - P0DG_vom = fs_type(vom, "DG", 0) + P0DG_vom = fs_type(self.vom, "DG", 0) point_eval = interpolate(self.operand, P0DG_vom) # Interpolate into the input-ordering VOM - P0DG_vom_input_ordering = fs_type(vom.input_ordering, "DG", 0) + P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) arg = Argument(P0DG_vom, 0 if self.ufl_interpolate.is_adjoint else 1) point_eval_input_ordering = interpolate(arg, P0DG_vom_input_ordering)