diff --git a/HISTORY.md b/HISTORY.md index ebc29dba0..c6585ec49 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.36.8 + +Make `ThreadSafeVarInfo` hold a total of `Threads.nthreads() * 2` logp values, instead of just `Threads.nthreads()`. +This fix helps to paper over the cracks in using `threadid()` to index into the `ThreadSafeVarInfo` object. + ## 0.36.7 Added compatibility with MCMCChains 7.0. diff --git a/Project.toml b/Project.toml index 00a4f8f93..2fc1d984c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.7" +version = "0.36.8" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abf14b8fc..3ae425896 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -648,10 +648,12 @@ end # Threadsafe stuff. # For `SimpleVarInfo` we don't really need `Ref` so let's not use it. function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads())) + return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2)) end function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) + return ThreadSafeVarInfo( + vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)] + ) end has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2dc2645de..458e5bca3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -9,7 +9,9 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo logps::L end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) + return ThreadSafeVarInfo( + vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)] + ) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 72c439db8..ededf78b0 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -5,7 +5,7 @@ @test threadsafe_vi.varinfo === vi @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() + @test length(threadsafe_vi.logps) == Threads.nthreads() * 2 @test all(iszero(x[]) for x in threadsafe_vi.logps) end