`ThreadSafeVarInfo` and threadid #924
Comments
Agreed. I think locks are the long-term solution for proper thread-safety, but that's a biiiiig operation to implement. I would be in favour of some quick fix with
I would be in favour of removing TSVI, actually. 😬
A shame to lose a feature that I think someone on Slack mentioned using just a couple of weeks ago, no?
I don't feel bad about it given that there are alternatives that work perfectly fine, and the Julia page linked above even suggests that
Introduction
Currently, `ThreadSafeVarInfo` creates an array of length `Threads.nthreads()` to store the logp values accumulated in each thread:
DynamicPPL.jl/src/threadsafe.jl, lines 11 to 13 in cdeb657
It then adds to `logps[Threads.threadid()]`:
DynamicPPL.jl/src/threadsafe.jl, lines 24 to 31 in cdeb657
Although `ThreadSafeVarInfo` has been changed a bit by the accumulators PR (#885), the thread ID indexing behaviour described above still remains.
Now, this has worked fine up until Julia 1.11. However, in Julia 1.12, this breaks, because `Threads.threadid()` can return a value that is larger than `Threads.nthreads()`, as seen in the CI of #921 (link to failing run) and more clearly demonstrated here:
Julia 1.12, 1 thread
Julia 1.12, 4 threads
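A quick way to check this on a given Julia version and thread count is sketched below. It is not the original reproduction: it simply prints the default pool size, the largest thread ID, and the thread IDs observed inside `Threads.@threads`. The variable `ids` is introduced only for illustration, and `Threads.maxthreadid()` assumes Julia 1.9 or later.

```julia
# Compare the size of the default thread pool with the thread IDs
# that loop iterations actually run on.
println("nthreads()    = ", Threads.nthreads())
println("maxthreadid() = ", Threads.maxthreadid())

# Each iteration writes to its own slot, so there is no data race here.
ids = zeros(Int, 8)
Threads.@threads for i in 1:8
    ids[i] = Threads.threadid()
end

println("thread IDs seen inside @threads: ", sort(unique(ids)))
```

If any of the printed thread IDs exceeds `nthreads()`, indexing a vector of length `nthreads()` with `threadid()` will go out of bounds.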
Possible solutions
1. Use `maxthreadid()` instead of `nthreads()`
This would be the quickest, hackiest, fix. It is not ideal, but it is not really any worse than the current situation, and could tide us over for some time while we figure out a proper solution.
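A minimal sketch of what this quick fix amounts to, using a plain vector as a hypothetical stand-in for the per-thread storage (the names `logps` and `total` are illustrative and are not DynamicPPL's actual fields). It only addresses the out-of-bounds index; it still indexes by `threadid()`, so the usual caveats about that pattern continue to apply.

```julia
# Hypothetical stand-in for the per-thread logp storage.
# Sizing by maxthreadid() (Julia >= 1.9) covers every thread pool,
# so threadid() is a valid index for any thread existing at this point.
logps = zeros(Float64, Threads.maxthreadid())

Threads.@threads for i in 1:100
    # Still relies on threadid(); this is the part that remains "hacky".
    logps[Threads.threadid()] += 1.0
end

# Combine the per-thread partial results.
total = sum(logps)
```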
(Actually, there is an even more hacky fix: in `acclogp`, we can index into the vector with `threadid() - 1` instead of `threadid()`. I assume we don't want to go there.)
2. Rewrite `ThreadSafeVarInfo` to use a lock
Probably the best, but lots of work. In my opinion, this amount of work is not worth it unless it allows us to extend the 'thread safety' to assume-statements (and not just observe-statements).
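For a sense of what the lock-based approach looks like at its simplest, here is a sketch of a single accumulator guarded by a `ReentrantLock`. The type `LockedLogp` and the function `locked_acclogp!` are hypothetical names used for illustration and are not part of DynamicPPL; a real rewrite would have to cover far more state than one scalar.

```julia
# Hypothetical lock-protected accumulator (not DynamicPPL's API).
mutable struct LockedLogp
    lk::ReentrantLock
    logp::Float64
end
LockedLogp() = LockedLogp(ReentrantLock(), 0.0)

# Add `x` to the running total while holding the lock, so that
# concurrent updates from different tasks cannot interleave.
function locked_acclogp!(acc::LockedLogp, x::Real)
    lock(acc.lk) do
        acc.logp += x
    end
    return acc
end

acc = LockedLogp()
Threads.@threads for i in 1:1000
    locked_acclogp!(acc, 0.5)
end
```

No thread ID appears anywhere, so the result does not depend on how tasks are scheduled, at the cost of contention on the lock.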
3. Disallow tilde-statements inside `Threads.@threads`
Right now, we allow observe-statements to happen inside `Threads.@threads` (but not assume-statements). Observe-statements can, of course, be replaced with calls to `@addlogprob!`. For example, the following model breaks on Julia 1.12 (with any number of threads):
The following model, however, is equivalent (and the use of `Threads.@spawn` is "officially correct", see https://julialang.org/blog/2023/07/PSA-dont-use-threadid/):
And of course, people can use whatever threading library they like (e.g. FLoops.jl) too, as long as there are no tilde-statements in the parallelised code.
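Outside of any model, the `Threads.@spawn` pattern can be sketched in plain Julia, assuming a Normal likelihood with a hand-written log-density. The helpers `normal_logpdf` and `threaded_loglik` and the keyword `ntasks` are hypothetical. Each spawned task returns its own partial sum, so nothing is indexed by `threadid()`.

```julia
# Hand-written Normal log-density so the sketch needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Split the data into chunks, sum each chunk in its own task,
# then combine the partial sums by fetching the task results.
function threaded_loglik(y::AbstractVector{<:Real}, μ::Real, σ::Real;
                         ntasks::Int=Threads.nthreads())
    chunksize = max(1, cld(length(y), ntasks))
    chunks = Iterators.partition(eachindex(y), chunksize)
    tasks = map(chunks) do idxs
        Threads.@spawn sum(normal_logpdf(y[i], μ, σ) for i in idxs)
    end
    return sum(fetch, tasks; init=0.0)
end

y = randn(1_000)
total = threaded_loglik(y, 0.0, 1.0)
```

Inside a model, a value computed this way would be passed to `@addlogprob!` in place of the threaded observe-statements.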
Note that if we disallowed multithreaded tilde-statements, this also implies that `ThreadSafeVarInfo` could be entirely removed.