|
26 | 26 | __all__ = ["VTKFile", "XDMFFile", "cell_perm_gmsh", "cell_perm_vtk", "distribute_entity_data"] |
27 | 27 |
|
28 | 28 |
|
29 | | -def _extract_cpp_objects(functions: typing.Union[Mesh, Function, tuple[Function], list[Function]]): |
30 | | - """Extract C++ objects""" |
31 | | - if isinstance(functions, (list, tuple)): |
32 | | - return [getattr(u, "_cpp_object", u) for u in functions] |
33 | | - else: |
34 | | - return [getattr(functions, "_cpp_object", functions)] |
35 | | - |
36 | | - |
37 | 29 | # VTXWriter requires ADIOS2 |
38 | 30 | if _cpp.common.has_adios2: |
39 | 31 | from dolfinx.cpp.io import VTXMeshPolicy # F401 |
@@ -81,27 +73,29 @@ def __init__( |
81 | 73 | have the same element type. |
82 | 74 | """ |
83 | 75 | # Get geometry type |
84 | | - try: |
85 | | - dtype = output.geometry.x.dtype # type: ignore |
86 | | - except AttributeError: |
87 | | - try: |
88 | | - dtype = output.function_space.mesh.geometry.x.dtype # type: ignore |
89 | | - except AttributeError: |
90 | | - dtype = output[0].function_space.mesh.geometry.x.dtype # type: ignore |
| 76 | + if isinstance(output, Mesh): |
| 77 | + dtype = output.geometry.x.dtype |
| 78 | + elif isinstance(output, Function): |
| 79 | + dtype = output.function_space.mesh.geometry.x.dtype |
| 80 | + else: |
| 81 | + dtype = output[0].function_space.mesh.geometry.x.dtype |
91 | 82 |
|
92 | 83 | if np.issubdtype(dtype, np.float32): |
93 | 84 | _vtxwriter = _cpp.io.VTXWriter_float32 |
94 | 85 | elif np.issubdtype(dtype, np.float64): |
95 | 86 | _vtxwriter = _cpp.io.VTXWriter_float64 |
| 87 | + else: |
| 88 | + raise RuntimeError(f"VTXWriter does not support dtype={dtype}.") |
96 | 89 |
|
97 | | - try: |
98 | | - # Input is a mesh |
| 90 | + if isinstance(output, Mesh): |
99 | 91 | self._cpp_object = _vtxwriter(comm, filename, output._cpp_object, engine) # type: ignore[union-attr] |
100 | | - except (NotImplementedError, TypeError, AttributeError): |
101 | | - # Input is a single function or a list of functions |
102 | | - self._cpp_object = _vtxwriter( |
103 | | - comm, filename, _extract_cpp_objects(output), engine, mesh_policy |
104 | | - ) # type: ignore[arg-type] |
| 92 | + else: |
| 93 | + cpp_objects = ( |
| 94 | + [output._cpp_object] |
| 95 | + if isinstance(output, Function) |
| 96 | + else [o._cpp_object for o in output] |
| 97 | + ) |
| 98 | + self._cpp_object = _vtxwriter(comm, filename, cpp_objects, engine, mesh_policy) |
105 | 99 |
|
106 | 100 | def __enter__(self): |
107 | 101 | return self |
@@ -137,7 +131,8 @@ def write_mesh(self, mesh: Mesh, t: float = 0.0) -> None: |
137 | 131 | def write_function(self, u: typing.Union[list[Function], Function], t: float = 0.0) -> None: |
138 | 132 | """Write a single function or a list of functions to file for a |
139 | 133 | given time (default 0.0)""" |
140 | | - super().write(_extract_cpp_objects(u), t) |
| 134 | + cpp_objects = [u._cpp_object] if isinstance(u, Function) else [_u._cpp_object for _u in u] |
| 135 | + super().write(cpp_objects, t) |
141 | 136 |
|
142 | 137 |
|
143 | 138 | class XDMFFile(_cpp.io.XDMFFile): |
|
0 commit comments