diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py index a29dc3111..f7d566e99 100644 --- a/torchrec/modules/hash_mc_modules.py +++ b/torchrec/modules/hash_mc_modules.py @@ -305,7 +305,9 @@ def __init__( self._max_probe = max_probe self._buckets = total_num_buckets - self.register_buffer("_hash_zch_bucket", torch.tensor([total_num_buckets])) + self.register_buffer( + "_hash_zch_bucket", torch.tensor([total_num_buckets]), persistent=False + ) # Do not need to store in buffer since this is created and consumed # at each step https://fburl.com/code/axzimmbx self._evicted_indices = []