diff --git a/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc b/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc index 74a9291a1a3f..adc2f57d5746 100644 --- a/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc +++ b/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc @@ -17,6 +17,13 @@ repository on GitHub. ==== Bug Fixes * Make `ConsoleLauncher` compatible with JDK 26 by avoiding final field mutations. +* Fix a concurrency bug in `NamespacedHierarchicalStore#computeIfAbsent(Object, Object, Function)` where + the `defaultCreator` function was executed while holding the store's internal + map lock. Under parallel execution, this could cause threads using the store to + block each other and temporarily see a missing or incorrectly initialized state + for values created via `computeIfAbsent`. The method now evaluates + `defaultCreator` outside the critical section using a memorizing supplier, + aligning its behavior with the deprecated `getOrComputeIfAbsent`. [[v6.0.2-junit-platform-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes diff --git a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java index baeeaed7f122..d9224308b6fa 100644 --- a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java +++ b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java @@ -214,7 +214,7 @@ public void close() { StoredValue storedValue = getStoredValue(compositeKey); if (storedValue == null) { storedValue = this.storedValues.computeIfAbsent(compositeKey, - __ -> newStoredValue(new MemoizingSupplier(() -> { + __ -> newStoredValue(new MemorizingSupplier(() -> { rejectIfClosed(); return defaultCreator.apply(key); }))); @@ -240,27 +240,35 @@ public void close() { @API(status = MAINTAINED, since = "6.0") public Object computeIfAbsent(N namespace, K key, Function defaultCreator) { Preconditions.notNull(defaultCreator, "defaultCreator must not be null"); - CompositeKey compositeKey = new CompositeKey<>(namespace, key); - StoredValue storedValue = getStoredValue(compositeKey); + var compositeKey = new CompositeKey<>(namespace, key); + var storedValue = getStoredValue(compositeKey); var result = StoredValue.evaluateIfNotNull(storedValue); if (result == null) { - StoredValue newStoredValue = this.storedValues.compute(compositeKey, (__, oldStoredValue) -> { - if (StoredValue.evaluateIfNotNull(oldStoredValue) == null) { - rejectIfClosed(); - var computedValue = Preconditions.notNull(defaultCreator.apply(key), - "defaultCreator must not return null"); - return newStoredValue(() -> { - rejectIfClosed(); - return computedValue; - }); + var value = storedValues.compute(compositeKey, + (__, currentValue) -> currentValue == null || currentValue.equals(storedValue) + ? storeNewValue(key, defaultCreator) + : currentValue); + try { + return requireNonNull(value.evaluate()); + } + catch (Throwable t) { // remove failed entry to allow retry. + if (value.equals(storedValues.get(compositeKey))) { + storedValues.remove(compositeKey, value); } - return oldStoredValue; - }); - return requireNonNull(newStoredValue.evaluate()); + throw t; + } } return result; } + private StoredValue storeNewValue(K key, Function defaultCreator) { + rejectIfClosed(); + return newStoredValue(new MemorizingSupplier(() -> { + rejectIfClosed(); + return requireNonNull(defaultCreator.apply(key)); + })); + } + /** * Get the value stored for the supplied namespace and key in this store or * the parent store, if present, or call the supplied function to compute it @@ -469,7 +477,7 @@ private void close(CloseAction closeAction) throws Throwable { * * @see StoredValue */ - private static class MemoizingSupplier implements Supplier<@Nullable Object> { + private static class MemorizingSupplier implements Supplier<@Nullable Object> { private static final Object NO_VALUE_SET = new Object(); @@ -478,7 +486,7 @@ private static class MemoizingSupplier implements Supplier<@Nullable Object> { @Nullable private volatile Object value = NO_VALUE_SET; - private MemoizingSupplier(Supplier<@Nullable Object> delegate) { + private MemorizingSupplier(Supplier<@Nullable Object> delegate) { this.delegate = delegate; } diff --git a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java index a239db0cae2e..38aafcab9066 100644 --- a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java +++ b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java @@ -10,9 +10,13 @@ package org.junit.platform.engine.support.store; +import static java.util.Collections.synchronizedList; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -25,7 +29,11 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -672,4 +680,882 @@ public String toString() { } }; } + + @Test + void computeIfAbsentOverridesParentNullValue() { + var parent = new NamespacedHierarchicalStore(null); + var child = parent.newChild(); + + // Store null in parent + parent.put("ns", "key", null); + + // Initially child should see null from parent + assertNull(child.get("ns", "key")); + + // computeIfAbsent should treat null as "logically absent" and compute a new value + var result = child.computeIfAbsent("ns", "key", __ -> "value"); + assertEquals("value", result); + + // Subsequent get should return the computed value + assertEquals("value", child.get("ns", "key")); + + // Parent should still have null + assertNull(parent.get("ns", "key")); + } + + @Test + void computeIfAbsentWithRecursiveStoreAccess() throws Exception { + // This test simulates the AssertJ scenario where computeIfAbsent + // calls functions that also access the store + var store = new NamespacedHierarchicalStore(null); + + var recursiveCounter = new AtomicInteger(); + + // This mimics AssertJ's SoftAssertionsExtension where the creator + // function also accesses the store + Function recursiveCreator = key -> { + recursiveCounter.incrementAndGet(); + // Access store while computing value (like AssertJ does) + store.put("other", "key", "nested"); + return "value"; + }; + + // Should not throw "Recursive update" exception + assertDoesNotThrow(() -> { + var result = store.computeIfAbsent("ns", "key", recursiveCreator); + assertEquals("value", result); + }); + + assertEquals(1, recursiveCounter.get()); + assertEquals("nested", store.get("other", "key")); + } + + @Test + void computeIfAbsentWithExceptionThrowingCreatorDoesNotLeaveCorruptState() { + var store = new NamespacedHierarchicalStore(null); + + RuntimeException exception = new RuntimeException("Boom!"); + + // First call fails + assertThrows(RuntimeException.class, () -> store.computeIfAbsent("ns", "key", __ -> { + throw exception; + })); + + // Subsequent calls should be able to retry + assertDoesNotThrow(() -> { + var result = store.computeIfAbsent("ns", "key", __ -> "success"); + assertEquals("success", result); + }); + + // Final state should be correct + assertEquals("success", store.get("ns", "key")); + } + + @Test + void computeIfAbsentMaintainsAtomicInitializationUnderConcurrency() throws Exception { + int threadCount = 10; + var store = new NamespacedHierarchicalStore(null); + var counter = new AtomicInteger(); + var values = synchronizedList(new ArrayList()); + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + threads.add(new Thread(() -> { + var value = store.computeIfAbsent("ns", "key", __ -> { + // Simulate expensive initialization + try { + Thread.sleep(10); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value-" + counter.incrementAndGet(); + }); + values.add((String) value); + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(); + } + + // Only one thread should have initialized the value + assertEquals(1, counter.get()); + + // All threads should get the same value + var expectedValue = "value-1"; + assertEquals(threadCount, values.size()); + for (String value : values) { + assertEquals(expectedValue, value); + } + } + + @Test + void computeIfAbsentPreservesValueWhenParentHasNonNullValue() { + var parent = new NamespacedHierarchicalStore(null); + var child = parent.newChild(); + + // Parent has non-null value + parent.put("ns", "key", "parent-value"); + + // computeIfAbsent should return parent's value, not compute new one + var counter = new AtomicInteger(); + var result = child.computeIfAbsent("ns", "key", __ -> { + counter.incrementAndGet(); + return "child-value"; + }); + + assertEquals("parent-value", result); + assertEquals(0, counter.get()); // Creator should not be called + } + + @Test + void computeIfAbsentClosedStoreThrowsException() { + var store = new NamespacedHierarchicalStore(null); + store.close(); + + assertThrows(NamespacedHierarchicalStoreException.class, + () -> store.computeIfAbsent("ns", "key", __ -> "value")); + } + + @Test + void computeIfAbsentWithTypeSafetyAndConcurrentAccess() throws Exception { + int threadCount = 5; + var store = new NamespacedHierarchicalStore(null); + var counter = new AtomicInteger(); + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + threads.add(new Thread(() -> { + var value = store.computeIfAbsent("ns", "key", __ -> counter.incrementAndGet(), Integer.class); + assertNotNull(value); + assertEquals(Integer.class, value.getClass()); + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(); + } + + assertEquals(1, counter.get()); + } + + /** + * #5171 + * #5209 + */ + @Nested + class ConcurrencyIssue5171 { + + /** + * Helper class that forces hash collisions in ConcurrentHashMap. + * This ensures different keys end up in the same bucket, exposing + * potential deadlocks when map locks are held. + */ + private static final class CollidingKey { + private final Object value; + + CollidingKey(Object value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + return o instanceof CollidingKey && this.value == ((CollidingKey) o).value; + } + + @Override + public int hashCode() { + return 42; // Force all CollidingKey instances to have the same hash code + } + } + + @SuppressWarnings("deprecation") + @Test + void getOrComputeIfAbsentDoesNotDeadlockWithCollidingKeys() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var latch1 = new CountDownLatch(1); + var latch2 = new CountDownLatch(1); + + var thread1 = new Thread(() -> store.getOrComputeIfAbsent("ns", new CollidingKey(1), key -> { + latch1.countDown(); + try { + // Wait for second thread to start its computation + latch2.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + return "value1"; + })); + + var thread2 = new Thread(() -> { + try { + // Wait for first thread to start its computation + latch1.await(); + store.getOrComputeIfAbsent("ns", new CollidingKey(2), key -> { + latch2.countDown(); + return "value2"; + }); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }); + + thread1.start(); + thread2.start(); + + // Wait with timeout to detect deadlocks + thread1.join(5000); + thread2.join(5000); + + assertFalse(thread1.isAlive(), "Thread1 should have completed (no deadlock)"); + assertFalse(thread2.isAlive(), "Thread2 should have completed (no deadlock)"); + } + + @SuppressWarnings("deprecation") + @Test + void getOrComputeIfAbsentDoesNotDeadlockWithCollidingKeys2() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var firstComputationStarted = new CountDownLatch(1); + var secondComputationAllowedToFinish = new CountDownLatch(1); + var firstThreadTimedOut = new AtomicBoolean(false); + + Thread first = new Thread( + () -> localStore.getOrComputeIfAbsent(namespace, new CollidingKey("k1"), __ -> { + firstComputationStarted.countDown(); + try { + if (!secondComputationAllowedToFinish.await(200, TimeUnit.MILLISECONDS)) { + firstThreadTimedOut.set(true); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value1"; + })); + + Thread second = new Thread(() -> { + try { + firstComputationStarted.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + localStore.getOrComputeIfAbsent(namespace, new CollidingKey("k2"), __ -> { + secondComputationAllowedToFinish.countDown(); + return "value2"; + }); + }); + + first.start(); + second.start(); + + first.join(1000); + second.join(1000); + + assertThat(firstThreadTimedOut).as( + "getOrComputeIfAbsent should not block subsequent computations on colliding keys").isFalse(); + } + } + + @Test + void computeIfAbsentCanDeadlockWithCollidingKeys() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var firstComputationStarted = new CountDownLatch(1); + var secondComputationAllowedToFinish = new CountDownLatch(1); + var firstThreadTimedOut = new AtomicBoolean(false); + + Thread first = new Thread(() -> localStore.computeIfAbsent(namespace, new CollidingKey("k1"), __ -> { + firstComputationStarted.countDown(); + try { + if (!secondComputationAllowedToFinish.await(200, TimeUnit.MILLISECONDS)) { + firstThreadTimedOut.set(true); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value1"; + })); + + Thread second = new Thread(() -> { + try { + firstComputationStarted.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + localStore.computeIfAbsent(namespace, new CollidingKey("k2"), __ -> { + secondComputationAllowedToFinish.countDown(); + return "value2"; + }); + }); + + first.start(); + second.start(); + + first.join(1000); + second.join(1000); + + assertThat(firstThreadTimedOut).as( + "computeIfAbsent should not block subsequent computations on colliding keys").isFalse(); + } + } + + @Test + void computeIfAbsentOverridesParentNullValue() { + // computeIfAbsent must treat a null value from the parent store as logically absent, + // so the child store can install and keep its own non-null value for the same key. + try (var parent = new NamespacedHierarchicalStore(null); + var child = new NamespacedHierarchicalStore(parent)) { + + parent.put(namespace, key, null); + + assertNull(parent.get(namespace, key)); + assertNull(child.get(namespace, key)); + + Object childValue = child.computeIfAbsent(namespace, key, __ -> "value"); + + assertEquals("value", childValue); + assertEquals("value", child.get(namespace, key)); + } + } + + @Test + void computeIfAbsentDoesNotDeadlockWithCollidingKeys() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var latch1 = new CountDownLatch(1); + var latch2 = new CountDownLatch(1); + + var thread1 = new Thread(() -> store.computeIfAbsent("ns", new CollidingKey(1), key -> { + latch1.countDown(); + try { + // Wait for second thread to start its computation + latch2.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + return "value1"; + })); + + var thread2 = new Thread(() -> { + try { + // Wait for first thread to start its computation + latch1.await(); + store.computeIfAbsent("ns", new CollidingKey(2), key -> { + latch2.countDown(); + return "value2"; + }); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }); + + thread1.start(); + thread2.start(); + + // Wait with timeout to detect deadlocks + thread1.join(5000); + thread2.join(5000); + + assertFalse(thread1.isAlive(), "Thread1 should have completed (no deadlock)"); + assertFalse(thread2.isAlive(), "Thread2 should have completed (no deadlock)"); + } + + @Test + void computeIfAbsentWithNullParentValueAndLocalNullValue() { + var parent = new NamespacedHierarchicalStore(null); + var child = parent.newChild(); + + // Parent has null value + parent.put("ns", "key", null); + + // Child also has null value (overrides parent null) + child.put("ns", "key", null); + + // computeIfAbsent should compute new value even though child has null + var counter = new AtomicInteger(); + var result = child.computeIfAbsent("ns", "key", __ -> { + counter.incrementAndGet(); + return "computed-value"; + }); + + assertEquals("computed-value", result); + assertEquals(1, counter.get()); + assertEquals("computed-value", child.get("ns", "key")); + // Parent should still have null + assertNull(parent.get("ns", "key")); + } + + @Test + void computeIfAbsentWithConcurrentNullValueResolution() throws Exception { + int threadCount = 5; + var store = new NamespacedHierarchicalStore(null); + var values = synchronizedList(new ArrayList()); + var nullCount = new AtomicInteger(); + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + threads.add(new Thread(() -> { + var value = store.computeIfAbsent("ns", "key", __ -> { + // Return null sometimes to test null handling + if (nullCount.incrementAndGet() == 1) { + return "null"; + } + return "non-null-value"; + }); + values.add(value.toString()); + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(); + } + + // First thread's creator might return null, causing exception + // Subsequent threads should retry and get a non-null value + assertThat(values).contains("null"); + assertThat(values).hasSize(5); + } + + @Test + void computeIfAbsentRemovesFailedEntryOnException() { + var store = new NamespacedHierarchicalStore(null); + var exceptionCount = new AtomicInteger(); + + // First call throws exception + assertThrows(RuntimeException.class, () -> store.computeIfAbsent("ns", "key", __ -> { + exceptionCount.incrementAndGet(); + throw new RuntimeException("First attempt fails"); + })); + + // Second call should succeed (failed entry was removed) + assertDoesNotThrow(() -> { + var result = store.computeIfAbsent("ns", "key", __ -> "success"); + assertEquals("success", result); + }); + + assertEquals(1, exceptionCount.get()); + assertEquals("success", store.get("ns", "key")); + } + + @Test + void computeIfAbsentWithInterruptedThreadDoesNotLeaveCorruptState() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var latch = new CountDownLatch(1); + + var thread = new Thread(() -> { + try { + store.computeIfAbsent("ns", "key", __ -> { + latch.countDown(); + try { + Thread.sleep(10000); // Sleep indefinitely + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during computation"); + } + return "value"; + }); + } + catch (RuntimeException e) { + // Expected interruption + } + }); + + thread.start(); + latch.await(); // Wait for thread to start computation + Thread.sleep(100); // Give it a moment to get into the sleep + thread.interrupt(); + thread.join(5000); + + // After interruption, another thread should be able to compute + assertDoesNotThrow(() -> { + var result = store.computeIfAbsent("ns", "key", __ -> "new-value"); + assertEquals("new-value", result); + }); + } + + @Test + void computeIfAbsentWithSameKeyDifferentNamespacesConcurrently() throws Exception { + int threadCount = 10; + var store = new NamespacedHierarchicalStore(null); + var counters = new AtomicInteger[2]; + counters[0] = new AtomicInteger(); + counters[1] = new AtomicInteger(); + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + final int namespaceIndex = i % 2; + threads.add(new Thread(() -> { + var value = store.computeIfAbsent("ns" + namespaceIndex, "key", + __ -> "value-" + namespaceIndex + "-" + counters[namespaceIndex].incrementAndGet()); + assertEquals("value-" + namespaceIndex + "-1", value); + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(); + } + + // Each namespace should have been initialized only once + assertEquals(1, counters[0].get()); + assertEquals(1, counters[1].get()); + assertEquals("value-0-1", store.get("ns0", "key")); + assertEquals("value-1-1", store.get("ns1", "key")); + } + + @Test + void computeIfAbsentWithHeavyContentionAndDifferentBuckets() throws Exception { + int threadCount = 20; + var store = new NamespacedHierarchicalStore(null); + var counters = new AtomicInteger[threadCount]; + for (int i = 0; i < threadCount; i++) { + counters[i] = new AtomicInteger(); + } + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + final int keyIndex = i; + threads.add(new Thread(() -> { + var value = store.computeIfAbsent("ns", "key" + keyIndex, + __ -> "value-" + keyIndex + "-" + counters[keyIndex].incrementAndGet()); + assertEquals("value-" + keyIndex + "-1", value); + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(); + } + + // Each key should have been initialized only once + for (int i = 0; i < threadCount; i++) { + assertEquals(1, counters[i].get()); + assertEquals("value-" + i + "-1", store.get("ns", "key" + i)); + } + } + + @Test + void computeIfAbsentWithRecursiveComputationInDifferentNamespace() { + var store = new NamespacedHierarchicalStore(null); + + // Test that computing in one namespace doesn't block computing in another + // even when the computations are recursive + Function recursiveCreator1 = key -> { + store.computeIfAbsent("ns2", key, k -> "ns2-value"); + return "ns1-value"; + }; + + Function recursiveCreator2 = key -> { + store.computeIfAbsent("ns1", key, k -> "ns1-value"); + return "ns2-value"; + }; + + // Should not deadlock + assertDoesNotThrow(() -> { + var result1 = store.computeIfAbsent("ns1", "key", recursiveCreator1); + var result2 = store.computeIfAbsent("ns2", "key", recursiveCreator2); + assertEquals("ns1-value", result1); + assertEquals("ns2-value", result2); + }); + } + + @Test + void computeIfAbsentPreservesOrderOfOperations() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var order = new ArrayList(); + var latch = new CountDownLatch(1); + + Thread thread1 = new Thread(() -> { + store.computeIfAbsent("ns", "key", __ -> { + order.add("thread1-compute-start"); + try { + latch.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + order.add("thread1-compute-end"); + return "value1"; + }); + order.add("thread1-done"); + }); + + Thread thread2 = new Thread(() -> { + order.add("thread2-start"); + latch.countDown(); + var result = store.computeIfAbsent("ns", "key", __ -> { + order.add("thread2-compute"); + return "value2"; + }); + order.add("thread2-done"); + }); + + thread1.start(); + Thread.sleep(50); // Ensure thread1 starts first + thread2.start(); + + thread1.join(5000); + thread2.join(5000); + + // Verify thread2 waits for thread1 to complete + // thread1-compute-start should happen first + // thread2-start can happen while thread1 is waiting + // thread1-compute-end should happen before thread2-done + assertThat(order.indexOf("thread1-compute-start")).isLessThan(order.indexOf("thread1-compute-end")); + assertThat(order.indexOf("thread1-compute-end")).isLessThan(order.indexOf("thread2-done")); + + // thread2 should not execute its compute function + assertThat(order).doesNotContain("thread2-compute"); + } + + @Test + void computeIfAbsentWithExceptionInMemorizingSupplierPropagation() { + var store = new NamespacedHierarchicalStore(null); + + // First call installs a MemorizingSupplier that throws, then removes it + assertThrows(RuntimeException.class, () -> store.computeIfAbsent("ns", "key", __ -> { + throw new RuntimeException("Boom!"); + })); + + // Subsequent calls should NOT get the same exception - the failed entry was removed + // They should be able to retry the computation + var exception2 = assertThrows(RuntimeException.class, () -> store.computeIfAbsent("ns", "key", __ -> { + throw new RuntimeException("Boom again!"); + })); + assertEquals("Boom again!", exception2.getMessage()); + + // Since the entry was removed, get should return null + assertNull(store.get("ns", "key")); + + // Remove should return null since nothing is stored + assertNull(store.remove("ns", "key")); + + // Now a successful computation should work + var result = store.computeIfAbsent("ns", "key", __ -> "success"); + assertEquals("success", result); + assertEquals("success", store.get("ns", "key")); + } + + @Test + void computeIfAbsentRemovesEntryWhenComputeThrowsAndEntryIsStillPresent() { + var store = new NamespacedHierarchicalStore(null); + var shouldThrow = new AtomicBoolean(true); + var computeCount = new AtomicInteger(0); + + // First attempt: throws exception, entry should be removed + assertThrows(RuntimeException.class, () -> store.computeIfAbsent("ns", "key", __ -> { + computeCount.incrementAndGet(); + if (shouldThrow.get()) { + throw new RuntimeException("First attempt fails"); + } + return "success"; + })); + + assertEquals(1, computeCount.get()); + assertNull(store.get("ns", "key")); + + // Second attempt: should succeed since entry was removed + shouldThrow.set(false); + assertDoesNotThrow(() -> { + var result = store.computeIfAbsent("ns", "key", __ -> { + computeCount.incrementAndGet(); + return "success"; + }); + assertEquals("success", result); + }); + + assertEquals(2, computeCount.get()); + assertEquals("success", store.get("ns", "key")); + } + + @Test + void computeIfAbsentDoesNotRemoveEntryWhenRaceConditionReplacesValue() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var latch = new CountDownLatch(1); + var secondThreadStarted = new CountDownLatch(1); + var successfulComputeCount = new AtomicInteger(0); + + // Thread 1: starts computation but gets interrupted + Thread thread1 = new Thread(() -> { + try { + store.computeIfAbsent("ns", "key", __ -> { + latch.countDown(); // Signal that thread1 started computation + try { + // Wait for thread2 to replace the value + secondThreadStarted.await(2, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during computation"); + } + return "thread1-value"; + }); + } + catch (RuntimeException e) { + // Expected to fail + } + }); + + // Thread 2: replaces the value while thread1 is waiting + Thread thread2 = new Thread(() -> { + try { + latch.await(); // Wait for thread1 to start + // Put a different value while thread1 is computing + store.put("ns", "key", "thread2-preemptive-put"); + secondThreadStarted.countDown(); // Allow thread1 to continue + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + thread1.start(); + thread2.start(); + + thread1.join(3000); + thread2.join(3000); + + // Even though thread1's computation threw an exception, + // the entry was already replaced by thread2's put(), + // so thread1 should NOT remove thread2's value + + assertEquals("thread2-preemptive-put", store.get("ns", "key")); + + // Now test that computeIfAbsent still works correctly with the existing value + var result = store.computeIfAbsent("ns", "key", __ -> { + successfulComputeCount.incrementAndGet(); + return "new-value"; + }); + + // Should return the existing value, not compute new one + assertEquals("thread2-preemptive-put", result); + assertEquals(0, successfulComputeCount.get()); + } + + @Test + void computeIfAbsentRemovalOnlyWhenCurrentValueMatchesFailedComputation() { + var store = new NamespacedHierarchicalStore(null); + + // Scenario 1: Value computed successfully by another thread while current thread's computation fails + store.put("ns", "key1", "initial-value"); + + // This computeIfAbsent should see the existing value and not compute + var result1 = store.computeIfAbsent("ns", "key1", __ -> { + // This should never be called since value already exists + throw new RuntimeException("Should not be called"); + }); + + assertEquals("initial-value", result1); + + // Scenario 2: Value gets replaced between compute() call and exception handling + var latch = new CountDownLatch(2); + var barrier = new CountDownLatch(1); + + Thread replacingThread = new Thread(() -> { + try { + barrier.await(); + // Replace the value while the other thread is handling exception + store.put("ns", "key2", "replaced-value"); + latch.countDown(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + Thread computingThread = new Thread(() -> { + try { + barrier.await(); + store.computeIfAbsent("ns", "key2", __ -> { + latch.countDown(); + throw new RuntimeException("Computation failed"); + }); + } + catch (RuntimeException e) { + // Expected + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + replacingThread.start(); + computingThread.start(); + + barrier.countDown(); + + try { + latch.await(2, TimeUnit.SECONDS); + replacingThread.join(1000); + computingThread.join(1000); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // The replaced value should remain even though computeIfAbsent failed + assertEquals("replaced-value", store.get("ns", "key2")); + } + + @Test + void computeIfAbsentWithMultipleExceptionsAndConcurrentModifications() throws Exception { + var store = new NamespacedHierarchicalStore(null); + var exceptionCount = new AtomicInteger(0); + var successfulComputeCount = new AtomicInteger(0); + int threadCount = 3; + + List threads = new ArrayList<>(); + for (int i = 0; i < threadCount; i++) { + threads.add(new Thread(() -> { + try { + store.computeIfAbsent("ns", "key", __ -> { + int attempts = exceptionCount.incrementAndGet(); + if (attempts <= 2) { + throw new RuntimeException("Attempt " + attempts + " failed"); + } + successfulComputeCount.incrementAndGet(); + return "finally-successful"; + }); + } + catch (RuntimeException e) { + // Expected for first two attempts + } + })); + } + + threads.forEach(Thread::start); + for (Thread thread : threads) { + thread.join(1000); + } + + // Wait a bit for any pending operations + Thread.sleep(100); + + // One thread should have succeeded eventually + assertEquals(1, successfulComputeCount.get()); + + // All threads should see the successful value + var finalResult = store.computeIfAbsent("ns", "key", __ -> { + throw new RuntimeException("Should not be called - value should exist"); + }); + + assertEquals("finally-successful", finalResult); + } + + } + }