From 89735a23814f11ed0e07f4778ef9231bfe1102c1 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 3 Aug 2025 09:42:53 +0200 Subject: [PATCH 1/2] Simplify `WaitGroup` implementation --- src/storage.rs | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index f63981e4f..309ea7cc5 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,5 +1,6 @@ //! Public API facades for the implementation details of [`Zalsa`] and [`ZalsaLocal`]. use std::marker::PhantomData; +use std::mem; use std::panic::RefUnwindSafe; use crate::database::RawDatabase; @@ -25,8 +26,6 @@ pub struct StorageHandle { impl Clone for StorageHandle { fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), @@ -53,7 +52,7 @@ impl StorageHandle { Self { zalsa_impl: Arc::new(Zalsa::new::(event_callback, jars)), coordinate: CoordinateDrop(Arc::new(Coordinate { - clones: Mutex::new(1), + coordinate_lock: Mutex::default(), cvar: Default::default(), })), phantom: PhantomData, @@ -95,17 +94,6 @@ impl Drop for Storage { } } -struct Coordinate { - /// Counter of the number of clones of actor. Begins at 1. - /// Incremented when cloned, decremented when dropped. - clones: Mutex, - cvar: Condvar, -} - -// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an -// inconsistent state. -impl RefUnwindSafe for Coordinate {} - impl Default for Storage { fn default() -> Self { Self::new(None) @@ -168,12 +156,14 @@ impl Storage { .zalsa_impl .event(&|| Event::new(EventKind::DidSetCancellationFlag)); - let mut clones = self.handle.coordinate.clones.lock(); - while *clones != 1 { - clones = self.handle.coordinate.cvar.wait(clones); - } - // The ref count on the `Arc` should now be 1 - let zalsa = Arc::get_mut(&mut self.handle.zalsa_impl).unwrap(); + let mut coordinate_lock = self.handle.coordinate.coordinate_lock.lock(); + let zalsa = loop { + if let Some(zalsa) = Arc::get_mut(&mut self.handle.zalsa_impl) { + // SAFETY: Polonius when ... https://github.com/rust-lang/rfcs/blob/master/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions + break unsafe { mem::transmute::<&mut Zalsa, &mut Zalsa>(zalsa) }; + } + coordinate_lock = self.handle.coordinate.cvar.wait(coordinate_lock); + }; // cancellation is done, so reset the flag zalsa.runtime_mut().reset_cancellation_flag(); zalsa @@ -260,6 +250,16 @@ impl Clone for Storage { } } +/// A simplified `WaitGroup`, this is used together with `Arc` as the actual counter +struct Coordinate { + coordinate_lock: Mutex<()>, + cvar: Condvar, +} + +// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an +// inconsistent state. +impl RefUnwindSafe for Coordinate {} + struct CoordinateDrop(Arc); impl std::ops::Deref for CoordinateDrop { @@ -272,7 +272,6 @@ impl std::ops::Deref for CoordinateDrop { impl Drop for CoordinateDrop { fn drop(&mut self) { - *self.0.clones.lock() -= 1; self.0.cvar.notify_all(); } } From 9b4da363a4930e8070b95df981420b579c07c0e3 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 20 Oct 2025 11:48:26 +0200 Subject: [PATCH 2/2] Slightly cheaper `get_mut` Co-authored-by: Ibraheem Ahmed --- src/storage.rs | 8 ++++---- src/table.rs | 5 ++++- src/views.rs | 6 +++++- src/zalsa_local.rs | 4 ++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 309ea7cc5..443b53221 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,6 +1,5 @@ //! Public API facades for the implementation details of [`Zalsa`] and [`ZalsaLocal`]. use std::marker::PhantomData; -use std::mem; use std::panic::RefUnwindSafe; use crate::database::RawDatabase; @@ -158,9 +157,10 @@ impl Storage { let mut coordinate_lock = self.handle.coordinate.coordinate_lock.lock(); let zalsa = loop { - if let Some(zalsa) = Arc::get_mut(&mut self.handle.zalsa_impl) { - // SAFETY: Polonius when ... https://github.com/rust-lang/rfcs/blob/master/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions - break unsafe { mem::transmute::<&mut Zalsa, &mut Zalsa>(zalsa) }; + if Arc::strong_count(&self.handle.zalsa_impl) == 1 { + // SAFETY: The strong count is 1, and we never create any weak pointers, + // so we have a unique reference. + break unsafe { &mut *(Arc::as_ptr(&self.handle.zalsa_impl).cast_mut()) }; } coordinate_lock = self.handle.coordinate.cvar.wait(coordinate_lock); }; diff --git a/src/table.rs b/src/table.rs index 53cf10cce..5505c1c05 100644 --- a/src/table.rs +++ b/src/table.rs @@ -252,7 +252,10 @@ impl Table { } let allocated_idx = self.push_page::(ingredient, memo_types.clone()); - assert_eq!(allocated_idx, page_idx); + assert_eq!( + allocated_idx, page_idx, + "allocated index does not match requested index" + ); } }; } diff --git a/src/views.rs b/src/views.rs index d449779c3..d58f349f0 100644 --- a/src/views.rs +++ b/src/views.rs @@ -108,7 +108,11 @@ impl Views { &self, func: fn(NonNull) -> NonNull, ) -> &DatabaseDownCaster { - assert_eq!(self.source_type_id, TypeId::of::()); + assert_eq!( + self.source_type_id, + TypeId::of::(), + "mismatched source type" + ); let target_type_id = TypeId::of::(); if let Some((_, caster)) = self .view_casters diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 39d0c489c..7b0399178 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1173,7 +1173,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) @@ -1195,7 +1195,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); })