Skip to content

Commit b35aab5

Browse files
angus-gdhamconnorjward
authored
Remove reference cycle in VecAccessMixin (#4033)
* Remove reference cycle in VecAccessMixin With an associated PETSc Vec, VecAccessMixin deferred its version property to a lambda to avoid allocating the storage until necessary. Unfortunately, this lambda creates a reference cycle to self for all users of the VecAccessMixin. Given that counter accesses should be relatively infrequent, it seems fine to look up the counter type within the method itself. * Don't cache topological on CoordinatelessFunction Doesn't make sense to cache a reference to self, just return self. * Convert increment_dat_version to abstract on DataCarrier The inheritance chain for Dat and Global puts VecAccessMixin (rightly) behind DataCarrier. This means that by MRO, the increment_dat_version method provided on DataCarrier will be used, which is a null operation. I think this makes sense, given that not all classes use this. However, because we're providing increment_dat_version as an override through VecAccessMixin, we need to explicitly refer to it in the inheriting classes. * Update pyop2/types/dat.py Otherwise updating a dat through a view doesn't increment the version. Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk> --------- Co-authored-by: David A. Ham <david.ham@imperial.ac.uk> Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
1 parent a4787d1 commit b35aab5

6 files changed

Lines changed: 32 additions & 16 deletions

File tree

firedrake/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType):
8484
else:
8585
self.dat = function_space.make_dat(val, dtype, self.name())
8686

87-
@utils.cached_property
87+
@property
8888
def topological(self):
8989
r"""The underlying coordinateless function."""
9090
return self

firedrake/preconditioners/patch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def _wrapper_cache_key_(self):
127127
def __call__(self, access, map_=None):
128128
return LocalDatLegacyArg(self, map_, access)
129129

130+
def increment_dat_version(self):
131+
pass
132+
130133

131134
register_petsc_function("MatSetValues")
132135

pyop2/types/dat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,9 @@ def __init__(self, dat, index):
725725
dtype=dat.dtype,
726726
name="view[%s](%s)" % (index, dat.name))
727727

728+
def increment_dat_version(self):
729+
self._parent.increment_dat_version()
730+
728731
@utils.cached_property
729732
def _kernel_args_(self):
730733
return self._parent._kernel_args_
@@ -841,6 +844,9 @@ def vec_context(self, access):
841844
if access is not Access.READ:
842845
self.halo_valid = False
843846

847+
def increment_dat_version(self):
848+
VecAccessMixin.increment_dat_version(self)
849+
844850

845851
class MixedDat(AbstractDat, VecAccessMixin):
846852
r"""A container for a bag of :class:`Dat`\s.

pyop2/types/data_carrier.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def cdim(self):
4444
the product of the dim tuple."""
4545
return self._cdim
4646

47+
@abc.abstractmethod
4748
def increment_dat_version(self):
4849
pass
4950

@@ -80,24 +81,21 @@ def _is_allocated(self):
8081
class VecAccessMixin(abc.ABC):
8182

8283
def __init__(self, petsc_counter=None):
83-
if petsc_counter:
84-
# Use lambda since `_vec` allocates the data buffer
85-
# -> Dat/Global should not allocate storage until accessed
86-
self._dat_version = lambda: self._vec.stateGet()
87-
self.increment_dat_version = lambda: self._vec.stateIncrease()
88-
else:
89-
# No associated PETSc Vec if incompatible type:
90-
# -> Equip Dat/Global with their own counter.
91-
self._version = 0
92-
self._dat_version = lambda: self._version
93-
94-
def _inc():
95-
self._version += 1
96-
self.increment_dat_version = _inc
84+
self._petsc_counter = petsc_counter
85+
self._version = 0
9786

9887
@property
9988
def dat_version(self):
100-
return self._dat_version()
89+
if self._petsc_counter:
90+
return self._vec.stateGet()
91+
92+
return self._version
93+
94+
def increment_dat_version(self):
95+
if self._petsc_counter:
96+
self._vec.stateIncrease()
97+
else:
98+
self._version += 1
10199

102100
@abc.abstractmethod
103101
def vec_context(self, access):

pyop2/types/glob.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ def vec_context(self, access):
399399
data = self._data
400400
self.comm.Bcast(data, 0)
401401

402+
def increment_dat_version(self):
403+
VecAccessMixin.increment_dat_version(self)
404+
402405

403406
# has no comm, can only be READ
404407
class Constant(SetFreeDataCarrier):
@@ -464,3 +467,6 @@ def duplicate(self):
464467
dtype=self.dtype,
465468
name=self.name
466469
)
470+
471+
def increment_dat_version(self):
472+
pass

pyop2/types/mat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ def __repr__(self):
550550
return "Mat(%r, %r, %r)" \
551551
% (self._sparsity, self._datatype, self._name)
552552

553+
def increment_dat_version(self):
554+
pass
555+
553556

554557
class Mat(AbstractMat):
555558
"""OP2 matrix data. A Mat is defined on a sparsity pattern and holds a value

0 commit comments

Comments
 (0)