
Commit f52ece5

vrozov authored and cloud-fan committed
[SPARK-51821][CORE] Call interrupt() without holding uninterruptibleLock to avoid possible deadlock
### What changes were proposed in this pull request?

Do not hold the `uninterruptibleLock` monitor while calling `super.interrupt()` in `UninterruptibleThread`; instead, use the newly introduced `awaitInterruptThread` flag and wait for `super.interrupt()` to be called.

### Why are the changes needed?

There is a potential deadlock: `UninterruptibleThread` may be blocked on an NIO operation, and interrupting the channel while holding the `uninterruptibleLock` monitor can deadlock, as in:

```
Found one Java-level deadlock:
=============================
"pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite":
  waiting to lock monitor 0x00006000036ee3c0 (object 0x000000070f3019d0, a java.lang.Object),
  which is held by "task thread"

"task thread":
  waiting to lock monitor 0x00006000036e75a0 (object 0x000000070f70fe80, a java.lang.Object),
  which is held by "pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite"

Java stack information for the threads listed above:
===================================================
"pool-1-thread-1-ScalaTest-running-UninterruptibleThreadSuite":
	at java.nio.channels.spi.AbstractInterruptibleChannel$1.interrupt(java.base@17.0.14/AbstractInterruptibleChannel.java:157)
	- waiting to lock <0x000000070f3019d0> (a java.lang.Object)
	at java.lang.Thread.interrupt(java.base@17.0.14/Thread.java:1004)
	- locked <0x000000070f70fc90> (a java.lang.Object)
	at org.apache.spark.util.UninterruptibleThread.interrupt(UninterruptibleThread.scala:99)
	- locked <0x000000070f70fe80> (a java.lang.Object)
	at org.apache.spark.util.UninterruptibleThreadSuite.$anonfun$new$5(UninterruptibleThreadSuite.scala:159)
	- locked <0x000000070f70f9f8> (a java.lang.Object)
	at org.apache.spark.util.UninterruptibleThreadSuite$$Lambda$216/0x000000700120d6c8.apply$mcV$sp(Unknown Source)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
	at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
	at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
	at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
	at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
	at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
	at org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
	at org.apache.spark.SparkFunSuite$$Lambda$205/0x0000007001207700.apply(Unknown Source)
	at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
	at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
	at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
	at org.scalatest.Transformer.apply(Transformer.scala:22)
	at org.scalatest.Transformer.apply(Transformer.scala:20)
	at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
	at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
	at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
	at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$343/0x00000070012867b0.apply(Unknown Source)
	at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
	at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
	at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
	at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
	at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:69)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
	at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$339/0x00000070012833e0.apply(Unknown Source)
	at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
	at org.scalatest.SuperEngine$$Lambda$340/0x0000007001283998.apply(Unknown Source)
	at scala.collection.immutable.List.foreach(List.scala:334)
	at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
	at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
	at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
	at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
	at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
	at org.scalatest.Suite.run(Suite.scala:1114)
	at org.scalatest.Suite.run$(Suite.scala:1096)
	at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
	at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
	at org.scalatest.funsuite.AnyFunSuiteLike$$Lambda$332/0x000000700127b000.apply(Unknown Source)
	at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
	at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
	at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
	at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
	at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
	at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
	at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
	at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:69)
	at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:321)
	at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:517)
	at sbt.ForkMain$Run.lambda$runTest$1(ForkMain.java:414)
	at sbt.ForkMain$Run$$Lambda$107/0x0000007001110000.call(Unknown Source)
	at java.util.concurrent.FutureTask.run(java.base@17.0.14/FutureTask.java:264)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(java.base@17.0.14/ThreadPoolExecutor.java:1136)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(java.base@17.0.14/ThreadPoolExecutor.java:635)
	at java.lang.Thread.run(java.base@17.0.14/Thread.java:840)

"task thread":
	at org.apache.spark.util.UninterruptibleThread.interrupt(UninterruptibleThread.scala:96)
	- waiting to lock <0x000000070f70fe80> (a java.lang.Object)
	at org.apache.spark.util.UninterruptibleThreadSuite$InterruptibleChannel.implCloseChannel(UninterruptibleThreadSuite.scala:143)
	at java.nio.channels.spi.AbstractInterruptibleChannel.close(java.base@17.0.14/AbstractInterruptibleChannel.java:112)
	- locked <0x000000070f3019d0> (a java.lang.Object)
	at org.apache.spark.util.UninterruptibleThreadSuite$InterruptibleChannel.<init>(UninterruptibleThreadSuite.scala:138)
	at org.apache.spark.util.UninterruptibleThreadSuite$$anon$5.run(UninterruptibleThreadSuite.scala:153)

Found 1 deadlock.
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added 2 new test cases to `UninterruptibleThreadSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50810 from vrozov/SPARK-51821.

Authored-by: Vlad Rozov <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
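For illustration only, a minimal Scala sketch of the lock-ordering inversion behind the deadlock above; the names (`DeadlockSketch`, `channelLock`) are hypothetical and this is not Spark source code. One thread acquires `uninterruptibleLock` and then needs the channel's monitor (via `Thread.interrupt()` running the channel's interruptor), while the task thread holds the channel's monitor inside `close()` and needs `uninterruptibleLock` to complete `UninterruptibleThread.interrupt()`.

```scala
// Hypothetical sketch of the SPARK-51821 lock ordering; illustrative names only.
object DeadlockSketch {
  // Stands in for UninterruptibleThread.uninterruptibleLock.
  private val uninterruptibleLock = new Object
  // Stands in for the NIO channel's close/interruptor monitor.
  private val channelLock = new Object

  // Interrupting thread (old behaviour): holds uninterruptibleLock across
  // super.interrupt(), which needs the channel's monitor to run its interruptor.
  def interruptTask(): Unit = uninterruptibleLock.synchronized {
    channelLock.synchronized {
      // java.lang.Thread.interrupt() -> AbstractInterruptibleChannel interruptor
    }
  }

  // Task thread: AbstractInterruptibleChannel.close() holds the channel's monitor,
  // and implCloseChannel() calls back into UninterruptibleThread.interrupt(),
  // which needs uninterruptibleLock -> circular wait.
  def closeChannel(): Unit = channelLock.synchronized {
    uninterruptibleLock.synchronized {
      // UninterruptibleThread.interrupt()
    }
  }
}
```

Calling `interruptTask()` and `closeChannel()` from two different threads at the right moment reproduces the circular wait; the patch breaks the cycle by never calling `super.interrupt()` while holding `uninterruptibleLock`.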
1 parent 38258a5 commit f52ece5

File tree

2 files changed: +150 −33 lines


core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala

Lines changed: 95 additions & 30 deletions
```diff
@@ -35,21 +35,90 @@ private[spark] class UninterruptibleThread(
     this(null, name)
   }
 
-  /** A monitor to protect "uninterruptible" and "interrupted" */
-  private val uninterruptibleLock = new Object
+  private class UninterruptibleLock {
+    /**
+     * Indicates if `this` thread are in the uninterruptible status. If so, interrupting
+     * "this" will be deferred until `this` enters into the interruptible status.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var uninterruptible = false
 
-  /**
-   * Indicates if `this` thread are in the uninterruptible status. If so, interrupting
-   * "this" will be deferred until `this` enters into the interruptible status.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var uninterruptible = false
+    /**
+     * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var shouldInterruptThread = false
 
-  /**
-   * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var shouldInterruptThread = false
+    /**
+     * Indicates that we should wait for interrupt() call before proceeding.
+     */
+    @GuardedBy("uninterruptibleLock")
+    private var awaitInterruptThread = false
+
+    /**
+     * Set [[uninterruptible]] to given value and returns the previous value.
+     */
+    def getAndSetUninterruptible(value: Boolean): Boolean = synchronized {
+      val uninterruptible = this.uninterruptible
+      this.uninterruptible = value
+      uninterruptible
+    }
+
+    def setShouldInterruptThread(value: Boolean): Unit = synchronized {
+      shouldInterruptThread = value
+    }
+
+    def setAwaitInterruptThread(value: Boolean): Unit = synchronized {
+      awaitInterruptThread = value
+    }
+
+    /**
+     * Is call to [[java.lang.Thread.interrupt()]] pending
+     */
+    def isInterruptPending: Boolean = synchronized {
+      // Clear the interrupted status if it's set.
+      shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
+      // wait for super.interrupt() to be called
+      !shouldInterruptThread && awaitInterruptThread
+    }
+
+    /**
+     * Set [[uninterruptible]] back to false and call [[java.lang.Thread.interrupt()]] to
+     * recover interrupt state if necessary
+     */
+    def recoverInterrupt(): Unit = synchronized {
+      uninterruptible = false
+      if (shouldInterruptThread) {
+        shouldInterruptThread = false
+        // Recover the interrupted status
+        UninterruptibleThread.super.interrupt()
+      }
+    }
+
+    /**
+     * Is it safe to call [[java.lang.Thread.interrupt()]] and interrupt the current thread
+     * @return true when there is no concurrent [[runUninterruptibly()]] call ([[uninterruptible]]
+     *         is true) and no concurrent [[interrupt()]] call, otherwise false
+     */
+    def isInterruptible: Boolean = synchronized {
+      shouldInterruptThread = uninterruptible
+      // as we are releasing uninterruptibleLock before calling super.interrupt() there is a
+      // possibility that runUninterruptibly() would be called after lock is released but before
+      // super.interrupt() is called. In this case to prevent runUninterruptibly() from being
+      // interrupted, we use awaitInterruptThread flag. We need to set it only if
+      // runUninterruptibly() is not yet set uninterruptible to true (!shouldInterruptThread) and
+      // there is no other threads that called interrupt (awaitInterruptThread is already true)
+      if (!shouldInterruptThread && !awaitInterruptThread) {
+        awaitInterruptThread = true
+        true
+      } else {
+        false
+      }
+    }
+  }
+
+  /** A monitor to protect "uninterruptible" and "interrupted" */
+  private val uninterruptibleLock = new UninterruptibleLock
 
   /**
    * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
@@ -63,27 +132,23 @@ private[spark] class UninterruptibleThread(
         s"Expected: $this but was ${Thread.currentThread()}")
     }
 
-    if (uninterruptibleLock.synchronized { uninterruptible }) {
+    if (uninterruptibleLock.getAndSetUninterruptible(true)) {
       // We are already in the uninterruptible status. So just run "f" and return
       return f
     }
 
-    uninterruptibleLock.synchronized {
-      // Clear the interrupted status if it's set.
-      shouldInterruptThread = Thread.interrupted() || shouldInterruptThread
-      uninterruptible = true
+    while (uninterruptibleLock.isInterruptPending) {
+      try {
+        Thread.sleep(100)
+      } catch {
+        case _: InterruptedException => uninterruptibleLock.setShouldInterruptThread(true)
+      }
     }
+
     try {
       f
     } finally {
-      uninterruptibleLock.synchronized {
-        uninterruptible = false
-        if (shouldInterruptThread) {
-          // Recover the interrupted status
-          super.interrupt()
-          shouldInterruptThread = false
-        }
-      }
+      uninterruptibleLock.recoverInterrupt()
    }
  }
 
@@ -92,11 +157,11 @@ private[spark] class UninterruptibleThread(
    * interrupted until it enters into the interruptible status.
    */
   override def interrupt(): Unit = {
-    uninterruptibleLock.synchronized {
-      if (uninterruptible) {
-        shouldInterruptThread = true
-      } else {
+    if (uninterruptibleLock.isInterruptible) {
+      try {
         super.interrupt()
+      } finally {
+        uninterruptibleLock.setAwaitInterruptThread(false)
      }
    }
  }
```
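As a usage-level sketch (illustrative only, not part of this patch): the externally visible contract is unchanged; an `interrupt()` that arrives while the thread is inside `runUninterruptibly` is still deferred until the block completes. The difference is that `super.interrupt()` is now issued outside the `uninterruptibleLock` monitor, with the `awaitInterruptThread` flag making a concurrent `runUninterruptibly` call wait until the pending `super.interrupt()` has actually been delivered. The object and method names below are hypothetical, and `UninterruptibleThread` is `private[spark]`, so this only compiles inside Spark's own packages.

```scala
import org.apache.spark.util.UninterruptibleThread

// Illustrative sketch of the deferred-interrupt behaviour of runUninterruptibly.
object RunUninterruptiblyExample {
  def main(args: Array[String]): Unit = {
    val t = new UninterruptibleThread("example") {
      override def run(): Unit = {
        runUninterruptibly {
          // interrupt() calls arriving here only mark shouldInterruptThread;
          // the real Thread.interrupt() is replayed by recoverInterrupt()
          // once this block returns.
          Thread.sleep(1000)
        }
        // A deferred interrupt, if any, is delivered right after the block.
      }
    }
    t.start()
    t.interrupt() // no longer holds uninterruptibleLock while calling super.interrupt()
    t.join()
  }
}
```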

core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala

Lines changed: 55 additions & 3 deletions
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.util
 
+import java.nio.channels.spi.AbstractInterruptibleChannel
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.util.Random
@@ -115,6 +116,46 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
     assert(interruptStatusBeforeExit)
   }
 
+  test("no runUninterruptibly") {
+    @volatile var hasInterruptedException = false
+    val latch = new CountDownLatch(1)
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        latch.countDown()
+        hasInterruptedException = sleep(1)
+      }
+    }
+    t.start()
+    latch.await(10, TimeUnit.SECONDS)
+    t.interrupt()
+    t.join()
+    assert(hasInterruptedException === true)
+  }
+
+  test("SPARK-51821 uninterruptibleLock deadlock") {
+    val latch = new CountDownLatch(1)
+    val task = new UninterruptibleThread("task thread") {
+      override def run(): Unit = {
+        val channel = new AbstractInterruptibleChannel() {
+          override def implCloseChannel(): Unit = {
+            begin()
+            latch.countDown()
+            try {
+              Thread.sleep(Long.MaxValue)
+            } catch {
+              case _: InterruptedException => Thread.currentThread().interrupt()
+            }
+          }
+        }
+        channel.close()
+      }
+    }
+    task.start()
+    assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
+    task.interrupt()
+    task.join()
+  }
+
   test("stress test") {
     @volatile var hasInterruptedException = false
     val t = new UninterruptibleThread("test") {
@@ -148,9 +189,20 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
       }
     }
     t.start()
-    for (i <- 0 until 400) {
-      Thread.sleep(Random.nextInt(10))
-      t.interrupt()
+    val threads = new Array[Thread](10)
+    for (j <- 0 until 10) {
+      threads(j) = new Thread() {
+        override def run(): Unit = {
+          for (i <- 0 until 400) {
+            Thread.sleep(Random.nextInt(10))
+            t.interrupt()
+          }
+        }
+      }
+      threads(j).start()
+    }
+    for (j <- 0 until 10) {
+      threads(j).join()
     }
     t.join()
     assert(hasInterruptedException === false)
```
