Fix workflow freezing thread safety. #1375

Merged
merged 2 commits into from
Jul 23, 2025
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
package com.squareup.workflow1.internal

import kotlinx.cinterop.CPointer
import kotlinx.cinterop.ExperimentalForeignApi
import platform.Foundation.NSCopyingProtocol
import platform.Foundation.NSLock
import platform.Foundation.NSThread
import platform.Foundation.NSZone
import platform.darwin.NSObject

/**
* Creates a lock that, after locking, must only be unlocked by the thread that acquired the lock.
*
* See the docs: https://developer.apple.com/documentation/foundation/nslock#overview
*/
internal actual typealias Lock = NSLock

internal actual inline fun <R> Lock.withLock(block: () -> R): R {
@@ -12,3 +23,35 @@ internal actual inline fun <R> Lock.withLock(block: () -> R): R {
unlock()
}
}

/**
* Implementation of [ThreadLocal] that works in a similar way to Java's, based on a thread-specific
* map/dictionary.
*/
internal actual class ThreadLocal<T>(
private val initialValue: () -> T
) : NSObject(), NSCopyingProtocol {
Contributor

@jamieQ maybe you could give this a sanity check?

Contributor

yes please.

private val threadDictionary
get() = NSThread.currentThread().threadDictionary

actual fun get(): T {
@Suppress("UNCHECKED_CAST")
return (threadDictionary.objectForKey(aKey = this) as T?)
?: initialValue().also(::set)
}

actual fun set(value: T) {
threadDictionary.setObject(value, forKey = this)
}

/**
* [Docs](https://developer.apple.com/documentation/foundation/nscopying/copy(with:)) say [zone]
* is unused.
*/
@OptIn(ExperimentalForeignApi::class)
override fun copyWithZone(zone: CPointer<NSZone>?): Any = this
}

internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
ThreadLocal(initialValue)
Original file line number Diff line number Diff line change
@@ -46,18 +46,24 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
}

/**
- * False during the current render call, set to true once this node is finished rendering.
+ * False except while this [WorkflowNode] is running the workflow's `render` method.
*
* Used to:
- * - prevent modifications to this object after [freeze] is called.
- * - prevent sending to sinks before render returns.
+ * - Prevent modifications to this object after [freeze] is called (e.g. [renderChild] calls).
+ * Only allowed when this flag is true.
+ * - Prevent sending to sinks before render returns. Only allowed when this flag is false.
+ *
+ * This is a [ThreadLocal] since we only care about preventing calls during rendering from the
Contributor

And we're confident that we only ever call freeze() from the render thread? I can't think of any violations.

Contributor

Ya I can't think of RenderContext methods that would be called off main thread? will keep thinking.

Collaborator Author

I can add another comment to freeze saying that's a requirement, but yea we only call it from one place.

Collaborator Author

@zach-klippenstein, Jul 21, 2025

And to @steve-the-edwards's reply to this comment (which doesn't show up in the thread for some reason, fucking GitHub): it doesn't matter whether we call it on the main thread or not, or even (to Ray's question) whether they're called only from render. We just need to make sure freeze and unfreeze are always called symmetrically from the same thread (i.e. when one thread unfreezes, the same thread later freezes).

If arbitrary RenderContext methods are called from various threads, and that code isn't already thread-safe, then it's already broken since (1) frozen is not a thread synchronization mechanism and (2) even if it were, we mutate internal data structures outside of the unfrozen section. actionSink.send is thread-safe since it just sends to a channel. But I think we were catching that case with the default value of frozen before, which we're not now, so I need to rethink that. Maybe a separate flag for whether sending to a sink is allowed.

+ * thread that is actually doing the rendering. If a background thread happens to send something
+ * into the sink, for example, while the main thread is rendering, it's not a violation.
*/
- private var frozen = false
+ private var performingRender by threadLocalOf { false }

override val actionSink: Sink<WorkflowAction<PropsT, StateT, OutputT>> get() = this

override fun send(value: WorkflowAction<PropsT, StateT, OutputT>) {
- if (!frozen) {
+ // Can't send actions from render thread during render pass.
+ if (performingRender) {
throw UnsupportedOperationException(
"Expected sink to not be sent to until after the render pass. " +
"Received action: ${value.debuggingName}"
@@ -72,7 +78,7 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
key: String,
handler: (ChildOutputT) -> WorkflowAction<PropsT, StateT, OutputT>
): ChildRenderingT {
- checkNotFrozen(child.identifier) {
+ checkPerformingRender(child.identifier) {
"renderChild(${child.identifier})"
}
return renderer.render(child, props, key, handler)
@@ -82,7 +88,7 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
key: String,
sideEffect: suspend CoroutineScope.() -> Unit
) {
- checkNotFrozen(key) { "runningSideEffect($key)" }
+ checkPerformingRender(key) { "runningSideEffect($key)" }
sideEffectRunner.runningSideEffect(key, sideEffect)
}

@@ -92,23 +98,22 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
vararg inputs: Any?,
calculation: () -> ResultT
): ResultT {
- checkNotFrozen(key) { "remember($key)" }
+ checkPerformingRender(key) { "remember($key)" }
return rememberStore.remember(key, resultType, inputs = inputs, calculation)
}

/**
* Freezes this context so that any further calls to this context will throw.
*/
fun freeze() {
- checkNotFrozen("freeze") { "freeze" }
Collaborator Author

This was just a defensive programming check, since this method is only called by internal code. It's more useful for it to be idempotent now.

- frozen = true
+ performingRender = false
}

/**
* Unfreezes when the node is about to render() again.
*/
fun unfreeze() {
- frozen = false
+ performingRender = true
}

/**
@@ -117,8 +122,10 @@ internal class RealRenderContext<PropsT, StateT, OutputT>(
*
* @see checkWithKey
*/
- private inline fun checkNotFrozen(stackTraceKey: Any, lazyMessage: () -> Any) =
- checkWithKey(!frozen, stackTraceKey) {
- "RenderContext cannot be used after render method returns: ${lazyMessage()}"
- }
+ private inline fun checkPerformingRender(
+ stackTraceKey: Any,
+ lazyMessage: () -> Any
+ ) = checkWithKey(performingRender, stackTraceKey) {
+ "RenderContext cannot be used after render method returns: ${lazyMessage()}"
+ }
}
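
As a rough illustration of the behavior the KDoc in this file describes, the following JVM-only sketch models a per-thread "performing render" flag. `RenderGuard` and `succeedsOnOtherThread` are hypothetical names, not part of the PR; the sketch uses `java.lang.ThreadLocal` directly and only models the flag as it appears in this diff, not the rest of `RealRenderContext`.

```kotlin
import kotlin.concurrent.thread

// Hypothetical stand-in for RealRenderContext; it models only the flag.
class RenderGuard {
  // Each thread sees its own copy of the flag, initially false.
  private val performingRender: ThreadLocal<Boolean> = ThreadLocal.withInitial { false }

  fun unfreeze() = performingRender.set(true)

  // Idempotent: calling it twice in a row is harmless.
  fun freeze() = performingRender.set(false)

  // Rejected only on a thread that is currently between unfreeze() and freeze().
  fun send(action: String): Boolean {
    if (performingRender.get()) {
      throw UnsupportedOperationException(
        "Expected sink to not be sent to until after the render pass. Received action: $action"
      )
    }
    return true
  }
}

// Runs `block` on a new thread and reports whether it completed without throwing.
fun succeedsOnOtherThread(block: () -> Unit): Boolean {
  var ok = false
  thread { ok = runCatching(block).isSuccess }.join()
  return ok
}

fun main() {
  val guard = RenderGuard()
  guard.unfreeze() // this thread is now "rendering"

  // Same thread, mid-render: the send is rejected.
  println(runCatching { guard.send("a") }.isFailure) // true
  // Different thread while this one is still rendering: the send is allowed.
  println(succeedsOnOtherThread { guard.send("b") }) // true

  guard.freeze()
  guard.freeze() // second call is a no-op
  println(guard.send("c")) // true
}
```

The sketch relies on `freeze` and `unfreeze` being called from the same thread, which is the symmetry requirement raised in the review discussion.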
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
package com.squareup.workflow1.internal

import kotlin.reflect.KProperty

internal expect class Lock()

internal expect inline fun <R> Lock.withLock(block: () -> R): R

internal expect class ThreadLocal<T> {
fun get(): T
fun set(value: T)
}

internal expect fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T>

@Suppress("NOTHING_TO_INLINE")
internal inline operator fun <T> ThreadLocal<T>.getValue(
receiver: Any?,
property: KProperty<*>
): T = get()

@Suppress("NOTHING_TO_INLINE")
internal inline operator fun <T> ThreadLocal<T>.setValue(
receiver: Any?,
property: KProperty<*>,
value: T
) {
set(value)
}
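
The two operator functions above are what allow a `ThreadLocal` to back a delegated property, as in `var performingRender by threadLocalOf { false }`. A minimal JVM-only sketch of that mechanism follows; it uses `java.lang.ThreadLocal` in place of the `expect` declarations, and `Counter` and `readFromOtherThread` are hypothetical names used only for illustration.

```kotlin
import kotlin.concurrent.thread
import kotlin.reflect.KProperty

// Same shape as the JVM `actual` in this PR, minus the `internal actual` modifiers.
fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
  ThreadLocal.withInitial(initialValue)

// Same shape as the delegate operators in the diff.
operator fun <T> ThreadLocal<T>.getValue(receiver: Any?, property: KProperty<*>): T = get()

operator fun <T> ThreadLocal<T>.setValue(receiver: Any?, property: KProperty<*>, value: T) {
  set(value)
}

// Hypothetical holder class, for illustration only.
class Counter {
  // Reads and writes go through ThreadLocal.get()/set(), so each thread has its own value.
  var count: Int by threadLocalOf { 0 }
}

// Reads `count` on a fresh thread and returns what that thread observed.
fun readFromOtherThread(counter: Counter): Int {
  var observed = -1
  thread { observed = counter.count }.join()
  return observed
}

fun main() {
  val counter = Counter()
  counter.count = 5
  println(counter.count)                // 5: the value set on this thread
  println(readFromOtherThread(counter)) // 0: a fresh thread gets the initial value
}
```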
Original file line number Diff line number Diff line change
@@ -220,8 +220,10 @@ internal class RealRenderContextTest {

val child = Workflow.stateless<Unit, Nothing, Unit> { fail() }
assertFailsWith<IllegalStateException> { context.renderChild(child) }
- assertFailsWith<IllegalStateException> { context.freeze() }
assertFailsWith<IllegalStateException> { context.remember("key", typeOf<String>()) {} }
+
+ // Freeze is the exception, it's idempotent and can be called again.
+ context.freeze()
}

private fun createdPoisonedContext(): RealRenderContext<String, String, String> {
@@ -234,7 +236,9 @@ internal class RealRenderContextTest {
eventActionsChannel,
workflowTracer = null,
runtimeConfig = emptySet(),
- )
+ ).apply {
+ unfreeze()
+ }
}

private fun createTestContext(): RealRenderContext<String, String, String> {
@@ -247,6 +251,8 @@ internal class RealRenderContextTest {
eventActionsChannel,
workflowTracer = null,
runtimeConfig = emptySet(),
- )
+ ).apply {
+ unfreeze()
+ }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.squareup.workflow1.internal

import platform.Foundation.NSCondition
import platform.Foundation.NSThread
import kotlin.concurrent.Volatile
import kotlin.test.Test
import kotlin.test.assertEquals

class ThreadLocalTest {
Contributor

@jamieQ and whoever else to review this as well then


@Volatile
private var valueFromThread: Int = -1

@Test fun initialValue() {
val threadLocal = ThreadLocal(initialValue = { 42 })
assertEquals(42, threadLocal.get())
}

@Test fun settingValue() {
val threadLocal = ThreadLocal(initialValue = { 42 })
threadLocal.set(0)
assertEquals(0, threadLocal.get())
}

@Test fun initialValue_inSeparateThread_afterChanging() {
val threadLocal = ThreadLocal(initialValue = { 42 })
threadLocal.set(0)

val thread = NSThread {
valueFromThread = threadLocal.get()
}
thread.start()
thread.join()

assertEquals(42, valueFromThread)
}

@Test fun set_fromDifferentThreads_doNotConflict() {
val threadLocal = ThreadLocal(initialValue = { 0 })
// threadStartedLatch and firstReadLatch together form a barrier: they allow the background
// thread to start up and get to the same point as the test thread, just before writing to the
// ThreadLocal, before allowing both threads to perform the write as close to the same time as
// possible.
val threadStartedLatch = NSCondition()
val firstReadLatch = NSCondition()
val firstReadDoneLatch = NSCondition()
val secondReadLatch = NSCondition()

val thread = NSThread {
// Wait on the barrier to sync with the test thread.
threadStartedLatch.signal()
firstReadLatch.wait()
threadLocal.set(1)

// Ensure we can see our write immediately, then wait for the test thread to verify. This races
// with the set(2) in the test thread, but that's fine. We'll double-check the value later.
valueFromThread = threadLocal.get()
firstReadDoneLatch.signal()
secondReadLatch.wait()

// Read one last time since now the test thread's second write is done.
valueFromThread = threadLocal.get()
}
thread.start()

// Wait for the other thread to start, then both threads set the value to something different
// at the same time.
threadStartedLatch.wait()
firstReadLatch.signal()
threadLocal.set(2)

// Wait for the background thread to finish setting value, then ensure that both threads see
// independent values.
firstReadDoneLatch.wait()
assertEquals(1, valueFromThread)
assertEquals(2, threadLocal.get())

// Change the value in this thread then read it again from the background thread.
threadLocal.set(3)
secondReadLatch.signal()
thread.join()
assertEquals(1, valueFromThread)
}

private fun NSThread.join() {
while (!isFinished()) {
// Avoid being optimized out.
// Time interval is in seconds.
NSThread.sleepForTimeInterval(1.0 / 1000)
}
}
}
Original file line number Diff line number Diff line change
@@ -5,3 +5,13 @@ package com.squareup.workflow1.internal
internal actual typealias Lock = Any

internal actual inline fun <R> Lock.withLock(block: () -> R): R = block()

internal actual class ThreadLocal<T>(private var value: T) {
actual fun get(): T = value
actual fun set(value: T) {
this.value = value
}
}

internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
ThreadLocal(initialValue())
Original file line number Diff line number Diff line change
@@ -3,3 +3,8 @@ package com.squareup.workflow1.internal
internal actual typealias Lock = Any

internal actual inline fun <R> Lock.withLock(block: () -> R): R = synchronized(this, block)

internal actual typealias ThreadLocal<T> = java.lang.ThreadLocal<T>

internal actual fun <T> threadLocalOf(initialValue: () -> T): ThreadLocal<T> =
ThreadLocal.withInitial(initialValue)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.squareup.workflow1

import java.util.concurrent.CountDownLatch

/**
* Returns the maximum number of threads that can be run in parallel on the host system, rounded
* down to the nearest even number, and at least [minThreads].
*/
internal fun calculateSaturatingTestThreadCount(minThreads: Int) =
Runtime.getRuntime().availableProcessors().let {
if (it.mod(2) != 0) it - 1 else it
}.coerceAtLeast(minThreads)

/**
* Calls [CountDownLatch.await] in a loop until count is zero, even if the thread gets
* interrupted.
*/
@Suppress("CheckResult")
internal fun CountDownLatch.awaitUntilDone() {
while (count > 0) {
try {
await()
} catch (e: InterruptedException) {
// Continue
}
}
}
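
One way these two helpers might be used in a JVM test is to start a fixed number of worker threads and block until every one has counted down. In the sketch below the helpers are repeated without `internal` so the snippet compiles on its own; `runWorkers` is a hypothetical name and does not appear in the PR.

```kotlin
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread

// Copied from the diff above so the example is self-contained.
fun calculateSaturatingTestThreadCount(minThreads: Int) =
  Runtime.getRuntime().availableProcessors().let {
    if (it.mod(2) != 0) it - 1 else it
  }.coerceAtLeast(minThreads)

// Copied from the diff above: keeps waiting even if the thread is interrupted.
fun CountDownLatch.awaitUntilDone() {
  while (count > 0) {
    try {
      await()
    } catch (e: InterruptedException) {
      // Continue
    }
  }
}

// Hypothetical helper: each thread increments once, then the total is returned.
fun runWorkers(threadCount: Int): Int {
  val done = CountDownLatch(threadCount)
  val total = AtomicInteger(0)
  repeat(threadCount) {
    thread {
      total.incrementAndGet()
      done.countDown()
    }
  }
  // Blocks until every worker has called countDown().
  done.awaitUntilDone()
  return total.get()
}

fun main() {
  val threads = calculateSaturatingTestThreadCount(minThreads = 2)
  println(runWorkers(threads) == threads) // true
}
```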