diff --git a/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc b/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc index 74a9291a1a3f..6bf0deb5d10f 100644 --- a/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc +++ b/documentation/modules/ROOT/partials/release-notes/release-notes-6.0.2.adoc @@ -17,6 +17,14 @@ repository on GitHub. ==== Bug Fixes * Make `ConsoleLauncher` compatible with JDK 26 by avoiding final field mutations. +* Fix a concurrency issue in `NamespacedHierarchicalStore.computeIfAbsent` where +`defaultCreator` was executed while holding the internal map's bucket lock, +causing threads accessing different keys in the same hash bucket to block each +other during parallel test execution +(see link:{junit-framework-repo}+/issues/5171+[issue #5171] and +link:https://github.com/assertj/assertj/issues/1996[assertj/assertj#1996]). +* `NamespacedHierarchicalStore.computeIfAbsent` no longer deadlocks when +`defaultCreator` accesses other keys that collide in the same hash bucket. [[v6.0.2-junit-platform-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes @@ -36,7 +44,7 @@ repository on GitHub. ==== Bug Fixes * Allow using `@ResourceLock` on classes annotated with `@ClassTemplate` (or - `@ParameterizedClass`). +`@ParameterizedClass`). [[v6.0.2-junit-jupiter-deprecations-and-breaking-changes]] ==== Deprecations and Breaking Changes diff --git a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java index baeeaed7f122..3dffc500a4eb 100644 --- a/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java +++ b/junit-platform-engine/src/main/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStore.java @@ -25,6 +25,8 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Supplier; @@ -241,24 +243,78 @@ public void close() { public Object computeIfAbsent(N namespace, K key, Function defaultCreator) { Preconditions.notNull(defaultCreator, "defaultCreator must not be null"); CompositeKey compositeKey = new CompositeKey<>(namespace, key); - StoredValue storedValue = getStoredValue(compositeKey); - var result = StoredValue.evaluateIfNotNull(storedValue); - if (result == null) { - StoredValue newStoredValue = this.storedValues.compute(compositeKey, (__, oldStoredValue) -> { - if (StoredValue.evaluateIfNotNull(oldStoredValue) == null) { - rejectIfClosed(); - var computedValue = Preconditions.notNull(defaultCreator.apply(key), - "defaultCreator must not return null"); - return newStoredValue(() -> { - rejectIfClosed(); - return computedValue; - }); + + // CAS-retry loop: retry if another thread concurrently modifies the entry + for (;;) { + StoredValue localStoredValue = this.storedValues.get(compositeKey); + if (localStoredValue != null) { + Object localValue = evaluateForComputeIfAbsent(compositeKey, localStoredValue); + if (localValue != null) { + return localValue; + } + + Object computed = computeAndInstall(compositeKey, localStoredValue, key, defaultCreator); + if (computed != null) { + return computed; + } + continue; + } + + // No local mapping: consult parent first. + if (this.parentStore != null) { + StoredValue parentStoredValue = this.parentStore.getStoredValue(compositeKey); + Object parentValue = StoredValue.evaluateIfNotNull(parentStoredValue); + if (parentValue != null) { + return parentValue; } - return oldStoredValue; - }); - return requireNonNull(newStoredValue.evaluate()); + } + + Object computed = computeAndInstall(compositeKey, null, key, defaultCreator); + if (computed != null) { + return computed; + } + } + } + + private @Nullable Object evaluateForComputeIfAbsent(CompositeKey compositeKey, StoredValue storedValue) { + Supplier<@Nullable Object> supplier = storedValue.supplier(); + if (supplier instanceof DeferredSupplier deferred) { + deferred.run(); + try { + return deferred.getOrThrow(); + } + catch (Throwable t) { + this.storedValues.remove(compositeKey, storedValue); + throw t; + } + } + return storedValue.evaluate(); + } + + private @Nullable Object computeAndInstall(CompositeKey compositeKey, @Nullable StoredValue expectedOld, + K key, Function defaultCreator) { + + var deferred = new DeferredSupplier(() -> { + rejectIfClosed(); + return Preconditions.notNull(defaultCreator.apply(key), "defaultCreator must not return null"); + }); + StoredValue newStoredValue = newStoredValue(deferred); + + boolean installed = (expectedOld == null ? this.storedValues.putIfAbsent(compositeKey, newStoredValue) == null + : this.storedValues.replace(compositeKey, expectedOld, newStoredValue)); + + if (!installed) { + return null; + } + + deferred.run(); + try { + return requireNonNull(deferred.getOrThrow()); + } + catch (Throwable t) { + this.storedValues.remove(compositeKey, newStoredValue); + throw t; } - return result; } /** @@ -460,6 +516,66 @@ private void close(CloseAction closeAction) throws Throwable { } + /** + * Deferred computation that can be installed into the store without executing + * user code while holding internal map locks. + * + *

For {@link #get(Object, Object)}, failures are treated as logically absent + * (returning {@code null}) so exceptions are not observable via {@code get()}. + * + *

For {@link #computeIfAbsent(Object, Object, Function)}, + * {@link #getOrThrow()} rethrows the original failure. + */ + private static final class DeferredSupplier implements Supplier<@Nullable Object> { + + private final FutureTask<@Nullable Object> task; + + private DeferredSupplier(Supplier<@Nullable Object> delegate) { + this.task = new FutureTask<>(delegate::get); + } + + private void run() { + this.task.run(); + } + + @Override + public @Nullable Object get() { + try { + return this.task.get(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw throwAsUncheckedException(e); + } + catch (ExecutionException e) { + Throwable t = e.getCause(); + if (t == null) { + t = e; + } + UnrecoverableExceptions.rethrowIfUnrecoverable(t); + return null; + } + } + + private @Nullable Object getOrThrow() { + try { + return this.task.get(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw throwAsUncheckedException(e); + } + catch (ExecutionException e) { + Throwable t = e.getCause(); + if (t == null) { + t = e; + } + UnrecoverableExceptions.rethrowIfUnrecoverable(t); + throw throwAsUncheckedException(t); + } + } + } + /** * Thread-safe {@link Supplier} that memoizes the result of calling its * delegate and ensures it is called at most once. diff --git a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java index a239db0cae2e..21e3fe567a1f 100644 --- a/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java +++ b/platform-tests/src/test/java/org/junit/platform/engine/support/store/NamespacedHierarchicalStoreTests.java @@ -26,7 +26,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; @@ -416,6 +420,331 @@ void simulateRaceConditionInComputeIfAbsent() throws Exception { assertEquals(1, counter.get()); assertThat(values).hasSize(threads).containsOnly(1); } + + @Test + void simulateRaceConditionInComputeIfAbsentWithCollidingKeys() throws Exception { + // 20 threads: 10 will access key1, 10 will access key2 + int threads = 20; + int threadsPerKey = threads / 2; + + // Both keys have the same hashCode, forcing them into the same bucket + var key1 = new CollidingKey("k1"); + var key2 = new CollidingKey("k2"); + var chooser = new AtomicInteger(); + + // Track how many times each key's defaultCreator is invoked + var creatorCallsForKey1 = new AtomicInteger(); + var creatorCallsForKey2 = new AtomicInteger(); + + try (var localStore = new NamespacedHierarchicalStore(null)) { + executeConcurrently(threads, () -> { + // Alternate between key1 and key2 + CollidingKey key = (chooser.getAndIncrement() % 2 == 0 ? key1 : key2); + + // Each key's value is an AtomicInteger counter + AtomicInteger counter = (AtomicInteger) localStore.computeIfAbsent(namespace, key, __ -> { + if (key.equals(key1)) { + creatorCallsForKey1.incrementAndGet(); + } + else { + creatorCallsForKey2.incrementAndGet(); + } + return new AtomicInteger(); + }); + + // Each thread increments the shared counter for its key + counter.incrementAndGet(); + return 1; + }); + + assertThat(creatorCallsForKey1.get()).as( + "defaultCreator for key1 should be called exactly once").isEqualTo(1); + assertThat(creatorCallsForKey2.get()).as( + "defaultCreator for key2 should be called exactly once").isEqualTo(1); + + AtomicInteger counter1 = (AtomicInteger) requireNonNull(localStore.get(namespace, key1)); + AtomicInteger counter2 = (AtomicInteger) requireNonNull(localStore.get(namespace, key2)); + assertThat(counter1.get()).as("all %d threads for key1 should have incremented the same counter", + threadsPerKey).isEqualTo(threadsPerKey); + assertThat(counter2.get()).as("all %d threads for key2 should have incremented the same counter", + threadsPerKey).isEqualTo(threadsPerKey); + } + } + + @Test + void computeIfAbsentWithCollidingKeysDoesNotBlockConcurrentAccess() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var key1ComputationStarted = new CountDownLatch(1); + var key2ComputationStarted = new CountDownLatch(1); + var key1Result = new AtomicReference(); + var key2Result = new AtomicReference(); + var key2WasBlocked = new AtomicBoolean(false); + + Thread thread1 = new Thread(() -> { + Object result = localStore.computeIfAbsent(namespace, new CollidingKey("key1"), __ -> { + key1ComputationStarted.countDown(); + try { + // Wait to ensure thread2 has a chance to start its computation + if (!key2ComputationStarted.await(500, TimeUnit.MILLISECONDS)) { + key2WasBlocked.set(true); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value1"; + }); + key1Result.set(result); + }); + + Thread thread2 = new Thread(() -> { + try { + key1ComputationStarted.await(1, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + Object result = localStore.computeIfAbsent(namespace, new CollidingKey("key2"), __ -> { + key2ComputationStarted.countDown(); + return "value2"; + }); + key2Result.set(result); + }); + + thread1.start(); + thread2.start(); + thread1.join(2000); + thread2.join(2000); + + assertThat(key1Result.get()).as("key1 result").isEqualTo("value1"); + assertThat(key2Result.get()).as("key2 result").isEqualTo("value2"); + assertThat(key2WasBlocked).as( + "computeIfAbsent for key2 should not be blocked by key1's defaultCreator").isFalse(); + } + } + + @SuppressWarnings("deprecation") + @Test + void getOrComputeIfAbsentDoesNotDeadlockWithCollidingKeys() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var firstComputationStarted = new CountDownLatch(1); + var secondComputationAllowedToFinish = new CountDownLatch(1); + var firstThreadTimedOut = new AtomicBoolean(false); + + Thread first = new Thread( + () -> localStore.getOrComputeIfAbsent(namespace, new CollidingKey("k1"), __ -> { + firstComputationStarted.countDown(); + try { + if (!secondComputationAllowedToFinish.await(200, TimeUnit.MILLISECONDS)) { + firstThreadTimedOut.set(true); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value1"; + })); + + Thread second = new Thread(() -> { + try { + firstComputationStarted.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + localStore.getOrComputeIfAbsent(namespace, new CollidingKey("k2"), __ -> { + secondComputationAllowedToFinish.countDown(); + return "value2"; + }); + }); + + first.start(); + second.start(); + + first.join(1000); + second.join(1000); + + assertThat(firstThreadTimedOut).as( + "getOrComputeIfAbsent should not block subsequent computations on colliding keys").isFalse(); + } + } + + @Test + void computeIfAbsentDoesNotDeadlockWithCollidingKeys() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var firstComputationStarted = new CountDownLatch(1); + var secondComputationAllowedToFinish = new CountDownLatch(1); + var firstThreadTimedOut = new AtomicBoolean(false); + + Thread first = new Thread(() -> localStore.computeIfAbsent(namespace, new CollidingKey("k1"), __ -> { + firstComputationStarted.countDown(); + try { + if (!secondComputationAllowedToFinish.await(200, TimeUnit.MILLISECONDS)) { + firstThreadTimedOut.set(true); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "value1"; + })); + + Thread second = new Thread(() -> { + try { + firstComputationStarted.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + localStore.computeIfAbsent(namespace, new CollidingKey("k2"), __ -> { + secondComputationAllowedToFinish.countDown(); + return "value2"; + }); + }); + + first.start(); + second.start(); + + first.join(1000); + second.join(1000); + + assertThat(firstThreadTimedOut).as( + "computeIfAbsent should not block subsequent computations on colliding keys").isFalse(); + } + } + + @Test + void getDoesNotSeeTransientExceptionFromComputeIfAbsent() throws Exception { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var computeStarted = new CountDownLatch(1); + var getCanProceed = new CountDownLatch(1); + var computeCanThrow = new CountDownLatch(1); + var exceptionSeenByGet = new AtomicBoolean(false); + var getReturnedNull = new AtomicBoolean(false); + + Thread computeThread = new Thread(() -> { + try { + localStore.computeIfAbsent(namespace, key, __ -> { + computeStarted.countDown(); + try { + // Wait for the get thread to be ready + computeCanThrow.await(1, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("boom"); + }); + } + catch (RuntimeException expected) { + // Expected - the exception should propagate back to this thread + } + finally { + getCanProceed.countDown(); + } + }); + + Thread getThread = new Thread(() -> { + try { + computeStarted.await(1, TimeUnit.SECONDS); + // Signal compute thread to throw + computeCanThrow.countDown(); + // Wait a brief moment for compute to throw and remove the entry + getCanProceed.await(1, TimeUnit.SECONDS); + // Now try to get the value + Object result = localStore.get(namespace, key); + if (result == null) { + getReturnedNull.set(true); + } + } + catch (RuntimeException e) { + // If we see the exception, that's the bug we're testing for + exceptionSeenByGet.set(true); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + computeThread.start(); + getThread.start(); + + computeThread.join(2000); + getThread.join(2000); + + assertThat(exceptionSeenByGet).as( + "get() should not see transient exception from failed computeIfAbsent").isFalse(); + assertThat(getReturnedNull).as( + "get() should return null after computeIfAbsent fails and removes entry").isTrue(); + } + } + + @Test + void getConcurrentWithFailingComputeIfAbsentDoesNotSeeException() throws Exception { + int iterations = 100; + for (int i = 0; i < iterations; i++) { + try (var localStore = new NamespacedHierarchicalStore(null)) { + var computeStarted = new CountDownLatch(1); + var exceptionSeenByGet = new AtomicBoolean(false); + + Thread computeThread = new Thread(() -> { + try { + localStore.computeIfAbsent(namespace, key, __ -> { + computeStarted.countDown(); + throw new RuntimeException("boom"); + }); + } + catch (RuntimeException expected) { + // Expected + } + }); + + Thread getThread = new Thread(() -> { + try { + computeStarted.await(100, TimeUnit.MILLISECONDS); + // Try to observe the transient state + localStore.get(namespace, key); + } + catch (RuntimeException e) { + exceptionSeenByGet.set(true); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + + computeThread.start(); + getThread.start(); + + computeThread.join(500); + getThread.join(500); + + assertThat(exceptionSeenByGet).as( + "get() should not see transient exception from failed computeIfAbsent (iteration %d)", + i).isFalse(); + } + } + } + + @Test + void computeIfAbsentOverridesParentNullValue() { + // computeIfAbsent must treat a null value from the parent store as logically absent, + // so the child store can install and keep its own non-null value for the same key. + try (var parent = new NamespacedHierarchicalStore(null); + var child = new NamespacedHierarchicalStore(parent)) { + + parent.put(namespace, key, null); + + assertNull(parent.get(namespace, key)); + assertNull(child.get(namespace, key)); + + Object childValue = child.computeIfAbsent(namespace, key, __ -> "value"); + + assertEquals("value", childValue); + assertEquals("value", child.get(namespace, key)); + } + } } @Nested @@ -663,6 +992,36 @@ private void assertClosed() { } + private static final class CollidingKey { + + private final String value; + + private CollidingKey(String value) { + this.value = value; + } + + @Override + public int hashCode() { + return 42; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof CollidingKey other)) { + return false; + } + return this.value.equals(other.value); + } + + @Override + public String toString() { + return this.value; + } + } + private static Object createObject(String display) { return new Object() {