-
Notifications
You must be signed in to change notification settings - Fork 178
Cache VertexOnlyMesh for cross-mesh interpolation #4860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from typing import Hashable, Literal, Callable, Iterable | ||
| from dataclasses import asdict, dataclass | ||
| from numbers import Number | ||
| from weakref import WeakValueDictionary | ||
|
|
||
| from ufl.algorithms import extract_arguments, replace | ||
| from ufl.domain import extract_unique_domain | ||
|
|
@@ -64,6 +65,8 @@ | |
| "Interpolator" | ||
| ) | ||
|
|
||
| _vom_cache = WeakValueDictionary() | ||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class InterpolateOptions: | ||
|
|
@@ -462,15 +465,45 @@ def __init__(self, expr: Interpolate): | |
| # scalar fiat/finat element | ||
| self.dest_element = dest_element | ||
|
|
||
| def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: | ||
| """Return the symbolic ``Interpolate`` expressions for point evaluation and | ||
| re-ordering into the input-ordering VertexOnlyMesh. | ||
| @staticmethod | ||
| def _vom_cache_key(target_space, source_mesh, allow_missing_dofs): | ||
| # The VOM used for cross-mesh interpolation depends on only these | ||
| return ( | ||
| target_space, | ||
| source_mesh, | ||
| allow_missing_dofs, | ||
| ) | ||
|
|
||
| @cached_property | ||
| def vom(self) -> MeshGeometry: | ||
| """Access the VertexOnlyMesh consisting of the target space's dofs immersed in the | ||
| source space's mesh. | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple[Interpolate, Interpolate] | ||
| A tuple containing the point evaluation interpolation and the | ||
| input-ordering interpolation. | ||
| MeshGeometry | ||
| The VertexOnlyMesh. | ||
| """ | ||
| key = self._vom_cache_key( | ||
| self.target_space, | ||
| self.source_mesh, | ||
| self.allow_missing_dofs, | ||
| ) | ||
| try: | ||
| return _vom_cache[key] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really like this. The communicator is only implicitly part of this cache key so I think deadlocks could well happen.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the correct way to do this then?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Step 1 is choosing an object to tie the lifetime of the VoM to. It's probably best to use I have implemented a fair amount of caching support that would easily handle this in pyop3, so you may want to wait for that or just port the (small) amount of code over that you need. |
||
| except KeyError: | ||
| vom = self._create_vom() | ||
| _vom_cache[key] = vom | ||
| return vom | ||
|
|
||
| def _create_vom(self) -> MeshGeometry: | ||
| """Create the VertexOnlyMesh consisting of the target space's dofs immersed in the | ||
| source space's mesh. | ||
|
|
||
| Returns | ||
| ------- | ||
| MeshGeometry | ||
| The VertexOnlyMesh. | ||
|
|
||
| Raises | ||
| ------ | ||
|
|
@@ -483,7 +516,7 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: | |
| f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec)) | ||
| dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension) | ||
| try: | ||
| vom = VertexOnlyMesh( | ||
| source_vom = VertexOnlyMesh( | ||
| self.source_mesh, | ||
| dest_node_coords, | ||
| redundant=False, | ||
|
|
@@ -495,7 +528,18 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: | |
| f"source function space on domain {self.source_mesh}. " | ||
| "This may be because the target mesh covers a larger domain than the " | ||
| "source mesh. To disable this error, set allow_missing_dofs=True.") | ||
| return source_vom | ||
|
|
||
| def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: | ||
| """Return the symbolic ``Interpolate`` expressions for point evaluation and | ||
| re-ordering into the input-ordering VertexOnlyMesh. | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple[Interpolate, Interpolate] | ||
| A tuple containing the point evaluation interpolation and the | ||
| input-ordering interpolation. | ||
| """ | ||
| # Get the correct type of function space | ||
| shape = self.target_space.ufl_function_space().value_shape | ||
| if len(shape) == 0: | ||
|
|
@@ -507,11 +551,11 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]: | |
| fs_type = partial(TensorFunctionSpace, shape=shape, symmetry=symmetry) | ||
|
|
||
| # Get expression for point evaluation at the dest_node_coords | ||
| P0DG_vom = fs_type(vom, "DG", 0) | ||
| P0DG_vom = fs_type(self.vom, "DG", 0) | ||
| point_eval = interpolate(self.operand, P0DG_vom) | ||
|
|
||
| # Interpolate into the input-ordering VOM | ||
| P0DG_vom_input_ordering = fs_type(vom.input_ordering, "DG", 0) | ||
| P0DG_vom_input_ordering = fs_type(self.vom.input_ordering, "DG", 0) | ||
|
|
||
| arg = Argument(P0DG_vom, 0 if self.ufl_interpolate.is_adjoint else 1) | ||
| point_eval_input_ordering = interpolate(arg, P0DG_vom_input_ordering) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't this just be a cached property?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes