Skip to content

Commit f3c38bb

Browse files
Fix workflow freezing thread safety.
1 parent daff58b commit f3c38bb

File tree

7 files changed

+179
-9
lines changed

7 files changed

+179
-9
lines changed

workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package com.squareup.workflow1.internal
22

3+
import kotlinx.cinterop.CPointer
4+
import kotlinx.cinterop.ExperimentalForeignApi
5+
import platform.Foundation.NSCopyingProtocol
36
import platform.Foundation.NSLock
7+
import platform.Foundation.NSThread
8+
import platform.Foundation.NSZone
9+
import platform.darwin.NSObject
410

511
internal actual typealias Lock = NSLock
612

@@ -12,3 +18,35 @@ internal actual inline fun <R> Lock.withLock(block: () -> R): R {
1218
unlock()
1319
}
1420
}
21+
22+
/**
23+
* Implementation of [ThreadLocal] that works in a similar way to Java's, based on a thread-specific
24+
* map/dictionary.
25+
*/
26+
internal actual class ThreadLocal<T>(
27+
private val initialValue: () -> T
28+
) : NSObject(), NSCopyingProtocol {
29+
30+
private val threadDictionary
31+
get() = NSThread.currentThread().threadDictionary
32+
33+
actual fun get(): T {
34+
@Suppress("UNCHECKED_CAST")
35+
return (threadDictionary.objectForKey(aKey = this) as T?)
36+
?: initialValue().also(::set)
37+
}
38+
39+
actual fun set(value: T) {
40+
threadDictionary.setObject(value, forKey = this)
41+
}
42+
43+
/**
44+
* [Docs](https://developer.apple.com/documentation/foundation/nscopying/copy(with:)) say [zone]
45+
* is unused.
46+
*/
47+
@OptIn(ExperimentalForeignApi::class)
48+
override fun copyWithZone(zone: CPointer<NSZone>?): Any = this
49+
}
50+
51+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
52+
ThreadLocal(initialValue)

workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
5151
* Used to:
5252
* - prevent modifications to this object after [freeze] is called.
5353
* - prevent sending to sinks before render returns.
54+
*
55+
* This is a [ThreadLocal] since we only care about preventing calls during rendering from the
56+
* thread that is actually doing the rendering. If a background thread happens to send something
57+
* into the sink, for example, while the main thread is rendering, it's not a violation.
5458
*/
55-
private var frozen = false
59+
private var frozen by threadLocalOf { true }
5660

5761
override val actionSink: Sink<WorkflowAction<PropsT, StateT, OutputT>> get() = this
5862

@@ -100,7 +104,6 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
100104
* Freezes this context so that any further calls to this context will throw.
101105
*/
102106
fun freeze() {
103-
checkNotFrozen("freeze") { "freeze" }
104107
frozen = true
105108
}
106109

@@ -117,8 +120,10 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
117120
*
118121
* @see checkWithKey
119122
*/
120-
private inline fun checkNotFrozen(stackTraceKey: Any, lazyMessage: () -> Any) =
121-
checkWithKey(!frozen, stackTraceKey) {
122-
"RenderContext cannot be used after render method returns: ${lazyMessage()}"
123-
}
123+
private inline fun checkNotFrozen(
124+
stackTraceKey: Any,
125+
lazyMessage: () -> Any
126+
) = checkWithKey(!frozen, stackTraceKey) {
127+
"RenderContext cannot be used after render method returns: ${lazyMessage()}"
128+
}
124129
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
package com.squareup.workflow1.internal
22

3+
import kotlin.reflect.KProperty
4+
35
internal expect class Lock()
46

57
internal expect inline fun <R> Lock.withLock(block: () -> R): R
8+
9+
internal expect class ThreadLocal<T> {
10+
fun get(): T
11+
fun set(value: T)
12+
}
13+
14+
internal expect fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T>
15+
16+
@Suppress("NOTHING_TO_INLINE")
17+
internal inline operator fun <T> ThreadLocal<T>.getValue(
18+
receiver: Any?,
19+
property: KProperty<*>
20+
): T = get()
21+
22+
@Suppress("NOTHING_TO_INLINE")
23+
internal inline operator fun <T> ThreadLocal<T>.setValue(
24+
receiver: Any?,
25+
property: KProperty<*>,
26+
value: T
27+
) {
28+
set(value)
29+
}

workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,10 @@ internal class RealRenderContextTest {
220220

221221
val child = Workflow.stateless<Unit, Nothing, Unit> { fail() }
222222
assertFailsWith<IllegalStateException> { context.renderChild(child) }
223-
assertFailsWith<IllegalStateException> { context.freeze() }
224223
assertFailsWith<IllegalStateException> { context.remember("key", typeOf<String>()) {} }
224+
225+
// Freeze is the exception, it's idempotent and can be called again.
226+
context.freeze()
225227
}
226228

227229
private fun createdPoisonedContext(): RealRenderContext<String, String, String> {
@@ -234,7 +236,9 @@ internal class RealRenderContextTest {
234236
eventActionsChannel,
235237
workflowTracer = null,
236238
runtimeConfig = emptySet(),
237-
)
239+
).apply {
240+
unfreeze()
241+
}
238242
}
239243

240244
private fun createTestContext(): RealRenderContext<String, String, String> {
@@ -247,6 +251,8 @@ internal class RealRenderContextTest {
247251
eventActionsChannel,
248252
workflowTracer = null,
249253
runtimeConfig = emptySet(),
250-
)
254+
).apply {
255+
unfreeze()
256+
}
251257
}
252258
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package com.squareup.workflow1.internal
2+
3+
import platform.Foundation.NSLock
4+
import platform.Foundation.NSThread
5+
import kotlin.concurrent.Volatile
6+
import kotlin.test.Test
7+
import kotlin.test.assertEquals
8+
9+
class ThreadLocalTest {
10+
11+
@Volatile
12+
private var valueFromThread: Int = -1
13+
14+
@Test fun initialValue() {
15+
val threadLocal = ThreadLocal(initialValue = { 42 })
16+
assertEquals(42, threadLocal.get())
17+
}
18+
19+
@Test fun settingValue() {
20+
val threadLocal = ThreadLocal(initialValue = { 42 })
21+
threadLocal.set(0)
22+
assertEquals(0, threadLocal.get())
23+
}
24+
25+
@Test fun initialValue_inSeparateThread_afterChanging() {
26+
val threadLocal = ThreadLocal(initialValue = { 42 })
27+
threadLocal.set(0)
28+
29+
val thread = NSThread {
30+
valueFromThread = threadLocal.get()
31+
}
32+
thread.start()
33+
thread.join()
34+
35+
assertEquals(42, valueFromThread)
36+
}
37+
38+
@Test fun set_fromDifferentThreads_doNotConflict() {
39+
val threadLocal = ThreadLocal(initialValue = { 0 })
40+
val threadStartedLatch = NSLock().apply { lock() }
41+
val firstReadLatch = NSLock().apply { lock() }
42+
val firstReadDoneLatch = NSLock().apply { lock() }
43+
val secondReadLatch = NSLock().apply { lock() }
44+
45+
val thread = NSThread {
46+
threadStartedLatch.unlock()
47+
firstReadLatch.lock()
48+
threadLocal.set(1)
49+
valueFromThread = threadLocal.get()
50+
firstReadDoneLatch.unlock()
51+
secondReadLatch.lock()
52+
valueFromThread = threadLocal.get()
53+
}
54+
thread.start()
55+
56+
// Wait for the other thread to start, then both threads set the value to something different
57+
// at the same time.
58+
threadStartedLatch.lock()
59+
firstReadLatch.unlock()
60+
threadLocal.set(2)
61+
62+
// Wait for the background thread to finish setting value, then ensure that both threads see
63+
// independent values.
64+
firstReadDoneLatch.lock()
65+
assertEquals(1, valueFromThread)
66+
assertEquals(2, threadLocal.get())
67+
68+
// Change the value in this thread then read it again from the background thread.
69+
threadLocal.set(3)
70+
secondReadLatch.unlock()
71+
thread.join()
72+
assertEquals(1, valueFromThread)
73+
}
74+
75+
private fun NSThread.join() {
76+
while (!isFinished()) {
77+
// Avoid being optimized out.
78+
// Time interval is in seconds.
79+
NSThread.sleepForTimeInterval(1.0 / 1000)
80+
}
81+
}
82+
}

workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,13 @@ package com.squareup.workflow1.internal
55
internal actual typealias Lock = Any
66

77
internal actual inline fun <R> Lock.withLock(block: () -> R): R = block()
8+
9+
internal actual class ThreadLocal<T>(private var value: T) {
10+
actual fun get(): T = value
11+
actual fun set(value: T) {
12+
this.value = value
13+
}
14+
}
15+
16+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
17+
ThreadLocal(initialValue())

workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@ package com.squareup.workflow1.internal
33
internal actual typealias Lock = Any
44

55
internal actual inline fun <R> Lock.withLock(block: () -> R): R = synchronized(this, block)
6+
7+
internal actual typealias ThreadLocal<T> = java.lang.ThreadLocal<T>
8+
9+
internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
10+
ThreadLocal.withInitial(initialValue)

0 commit comments

Comments
 (0)