ThreadSafeVarInfo and threadid #924

penelopeysm opened this issue May 16, 2025 · 4 comments

penelopeysm commented May 16, 2025

Introduction

Currently, ThreadSafeVarInfo creates an array of length Threads.nthreads() to store logp values accumulated in each thread:

function ThreadSafeVarInfo(vi::AbstractVarInfo)
    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
end

It then adds to logps[Threads.threadid()]:

function acclogp!!(vi::ThreadSafeVarInfo, logp)
    vi.logps[Threads.threadid()] += logp
    return vi
end
function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp)
    vi.logps[Threads.threadid()][] += logp
    return vi
end

Although ThreadSafeVarInfo has been changed a bit by the accumulators PR (#885), the thread ID indexing behaviour described above still remains.

This worked fine up to and including Julia 1.11. In Julia 1.12, however, it breaks, because Threads.threadid() can return a value larger than Threads.nthreads(), as seen in CI of #921 (link to failing run) and demonstrated more clearly here:

Julia 1.12, 1 thread

julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 8 virtual cores)

julia> Threads.nthreads()
1

julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2

julia> Threads.maxthreadid()
2

Julia 1.12, 4 threads

julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 8 virtual cores)

julia> Threads.nthreads()
4

julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2
5
3
4

julia> Threads.maxthreadid()
8

Possible solutions

1. Use maxthreadid() instead of nthreads()

This would be the quickest, hackiest fix. It is not ideal, but it is no worse than the current situation, and it could tide us over while we figure out a proper solution.

(Actually, there is an even more hacky fix: in acclogp!!, we can index into the vector with threadid() - 1 instead of threadid(). I assume we don't want to go there.)
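A minimal sketch of option 1, assuming nothing about DynamicPPL internals: the `logps` buffer below is a hypothetical stand-in for `vi.logps`, sized with `Threads.maxthreadid()` (available from Julia 1.9) so that every ID `Threads.threadid()` returns at allocation time is a valid index.

```julia
# Hypothetical stand-in for the per-thread buffer in ThreadSafeVarInfo.
# Sizing by maxthreadid() rather than nthreads() also covers thread IDs
# that belong to other thread pools (e.g. the interactive pool).
logps = [Ref(0.0) for _ in 1:Threads.maxthreadid()]

Threads.@threads for i in 1:100
    # In bounds as long as no threads are added after `logps` was allocated.
    logps[Threads.threadid()][] += 1.0
end

# Merge the per-thread partial sums once the loop has finished.
total = sum(r[] for r in logps)
```

This keeps the `threadid()` indexing pattern, so it is a stopgap rather than a proper fix.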

2. Rewrite ThreadSafeVarInfo to use a lock

Probably the best, but lots of work. In my opinion, I don't think that this amount of work is worth it, unless it allowed us to extend the 'thread safety' to assume-statements (and not just observe-statements).
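For illustration only, a lock-based accumulator could look roughly like the sketch below. `LockedLogp` and `acclogp!` are hypothetical names, not DynamicPPL's API, and the sketch covers only log-probability accumulation, not the rest of the VarInfo state.

```julia
# Hypothetical lock-guarded accumulator; not DynamicPPL's actual type.
struct LockedLogp
    lk::ReentrantLock
    logp::Base.RefValue{Float64}
end
LockedLogp() = LockedLogp(ReentrantLock(), Ref(0.0))

function acclogp!(acc::LockedLogp, x::Real)
    # `lock(f, lk)` acquires the lock, runs `f`, and always releases it.
    lock(acc.lk) do
        acc.logp[] += x
    end
    return acc
end

acc = LockedLogp()
Threads.@threads for i in 1:1000
    acclogp!(acc, 1.0)
end
```

No indexing by thread ID is needed here, at the cost of contention on the single lock.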

3. Disallow tilde-statements inside Threads.@threads

Right now, we allow observe-statements to happen inside Threads.@threads (but not assume-statements). Observe-statements can, of course, be replaced with calls to @addlogprob!. For example, the following model breaks on Julia 1.12 (with any number of threads):

julia> @model function f(x)
           a ~ Normal()
           Threads.@threads for i in eachindex(x)
               x[i] ~ Normal(a)
           end
       end
f (generic function with 2 methods)

julia> model = f(Float64.(1:10))
Model{typeof(f), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(f, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model)
ERROR: TaskFailedException
[...]

The following model, however, is equivalent (and the use of Threads.@spawn is "officially correct", see https://julialang.org/blog/2023/07/PSA-dont-use-threadid/):

julia> @model function g(x)
           a ~ Normal()
           logps = map(x) do xi
               Threads.@spawn logpdf(Normal(a), xi)
           end
           @addlogprob! sum(fetch.(logps))
       end
g (generic function with 2 methods)

julia> model = g(Float64.(1:10))
Model{typeof(g), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(g, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model)
VarInfo{@NamedTuple{a::DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}}, Float64}((a = DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}(Dict(a => 1), [a], UnitRange{Int64}[1:1], [-0.011060756850626022], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0)], [0], Dict{String, BitVector}("del" => [0], "trans" => [0])),), Base.RefValue{Float64}(-203.2173383639174), Base.RefValue{Int64}(0))

And of course, people can use whatever threading library they like (e.g. FLoops.jl) too, as long as there are no tilde-statements in the parallelised code.
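As a rough sketch of the same idea with fewer tasks, the data can be split into chunks so that each task returns its own partial sum. `chunked_loglik` and `normal_logpdf` are hypothetical helpers, written without Distributions.jl so the example needs no packages; `normal_logpdf(mu, x)` is assumed to match `logpdf(Normal(mu), x)`.

```julia
# Log-density of a unit-variance normal with mean `mu`, evaluated at `x`.
normal_logpdf(mu, x) = -0.5 * (x - mu)^2 - 0.5 * log(2π)

# Hypothetical helper: split `x` into chunks, let each task compute and
# return its own partial sum, then combine the results on the caller.
function chunked_loglik(mu, x; ntasks = Threads.nthreads())
    chunksize = max(1, cld(length(x), ntasks))
    tasks = map(Iterators.partition(x, chunksize)) do chunk
        Threads.@spawn sum(xi -> normal_logpdf(mu, xi), chunk)
    end
    # `init` keeps the result well defined when `x` is empty.
    return sum(fetch, tasks; init = 0.0)
end

x = Float64.(1:10)
ll = chunked_loglik(0.0, x)
```

Inside a model, the result would presumably be passed to `@addlogprob!`, as in `g` above.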

Note that disallowing multithreaded tilde-statements would also mean that ThreadSafeVarInfo could be removed entirely.

mhauru commented May 19, 2025

> [Locks are] Probably the best, but lots of work. In my opinion, I don't think that this amount of work is worth it, unless it allowed us to extend the 'thread safety' to assume-statements (and not just observe-statements).

Agreed. I think locks are the long-term solution for proper thread-safety, but that's a biiiiig operation to implement.

I would be in favour of some quick fix with maxthreadid or some sort of task ID (is that a thing?).

penelopeysm commented

I would be in favour of removing TSVI (ThreadSafeVarInfo), actually. 😬

mhauru commented May 19, 2025

A shame to lose a feature that I think someone on Slack mentioned using just a couple of weeks ago, no?

penelopeysm commented May 19, 2025

I don't feel bad about it given that there are alternatives that work perfectly fine, and the Julia page linked above even suggests that Threads.@threads isn't a good way of writing concurrent code. (https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#quickfix_replace_threads_with_threads_static)

> If you want a recipe that can replace the above buggy one with something that can be written using only the Base.Threads module, we recommend moving away from @threads, and instead working directly with @spawn to create and manage tasks. The reason is that @threads does not have any builtin mechanisms for managing and merging the results of work from different threads, whereas tasks can manage and return their own state in a safe way.
> Tasks creating and returning their own state is inherently safer than the spawner of parallel tasks setting up state for spawned tasks to read from and write to.
