From 9fdf18d195bc424884f06002061ffe3875609e3d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 27 Apr 2025 23:05:20 +0100 Subject: [PATCH 1/7] DynamicPPL 0.36 --- HISTORY.md | 20 ++++++++++++++++++++ Project.toml | 4 ++-- src/mcmc/Inference.jl | 5 ++--- test/Project.toml | 4 ++-- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 70a6bf824..926649949 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,23 @@ +# Release 0.37.2 + +DynamicPPL compatibility has been bumped to 0.36. +This brings with it a number of changes: the ones most likely to affect you are submodel prefixing and conditioning. +Variables in submodels are now represented correctly with field accessors. +For example: + +```julia +using Turing +@model inner() = x ~ Normal() +@model outer() = a ~ to_submodel(inner()) +``` + +`keys(VarInfo(outer()))` now returns `[@varname(a.x)]` instead of `[@varname(var"a.x")]` + +Furthermore, you can now either condition on the outer model like `outer() | (@varname(a.x) => 1.0)`, or the inner model like `inner() | (@varname(x) => 1.0)`. +If you use the conditioned inner model as a submodel, the conditioning will still apply correctly. + +Please see [the DynamicPPL release notes](https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.36.0) for fuller details. + # Release 0.37.1 `maximum_a_posteriori` and `maximum_likelihood` now perform sanity checks on the model before running the optimisation. diff --git a/Project.toml b/Project.toml index 9702cfde1..234a44293 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.37.1" +version = "0.37.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -62,7 +62,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.35" +DynamicPPL = "0.36" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index de87a5b39..0cbb45b48 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -4,9 +4,8 @@ using ..Essential using DynamicPPL: Metadata, VarInfo, - TypedVarInfo, # TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only - # implemented for TypedVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it + # implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it # or implement it for other VarInfo types and export it from DPPL. all_varnames_grouped_by_symbol, syms, @@ -161,7 +160,7 @@ function externalsampler( end # TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL. -function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) +function DynamicPPL.unflatten(vi::DynamicPPL.NTVarInfo, θ::NamedTuple) set_namedtuple!(deepcopy(vi), θ) return vi end diff --git a/test/Project.toml b/test/Project.toml index 36b7ebdec..df0af4c97 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,7 +40,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [compat] AbstractMCMC = "5" -AbstractPPL = "0.9, 0.10" +AbstractPPL = "0.9, 0.10, 0.11" AdvancedMH = "0.6, 0.7, 0.8" AdvancedPS = "=0.6.0" AdvancedVI = "0.2" @@ -52,7 +52,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.35" +DynamicPPL = "0.36" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" From 953b5fd4e0d0d632091cd6cbe749ec3fc6281ef9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 27 Apr 2025 23:43:10 +0100 Subject: [PATCH 2/7] Fix prefixing test and docs --- docs/src/api.md | 2 +- test/optimisation/Optimisation.jl | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 55b09a9d1..01f022e7e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -40,7 +40,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu | `@model` | [`DynamicPPL.@model`](@extref) | Define a probabilistic model | | `@varname` | [`AbstractPPL.@varname`](@extref) | Generate a `VarName` from a Julia expression | | `to_submodel` | [`DynamicPPL.to_submodel`](@extref) | Define a submodel | -| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given symbol | +| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given VarName | | `LogDensityFunction` | [`DynamicPPL.LogDensityFunction`](@extref) | A struct containing all information about how to evaluate a model. Mostly for advanced users | ### Inference diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 64ab6a6fe..9894d621c 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -71,13 +71,9 @@ using Turing end @testset "With prefixes" begin - function prefix_μ(model) - return DynamicPPL.contextualize( - model, DynamicPPL.PrefixContext{:inner}(model.context) - ) - end - m1 = prefix_μ(model1(x)) - m2 = prefix_μ(model2() | (var"inner.x"=x,)) + vn = @varname(inner) + m1 = prefix(model1(x), vn) + m2 = prefix((model2() | (x=x,)), vn) ctx = Turing.Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext()) @test Turing.Optimisation.OptimLogDensity(m1, ctx)(w) == Turing.Optimisation.OptimLogDensity(m2, ctx)(w) From 0fb5e2ab2704c8f919cc56ba34dd7712ef7bea85 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 28 Apr 2025 01:09:57 +0100 Subject: [PATCH 3/7] Fix deprecation warning for VarName(::Symbol) --- src/mcmc/gibbs.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index db133abc8..09016765a 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -232,9 +232,11 @@ end wrap_in_sampler(x::AbstractMCMC.AbstractSampler) = x wrap_in_sampler(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) -to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)] +to_varname(x::VarName) = x +to_varname(x::Symbol) = VarName{x}() +to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)] # Any other value is assumed to be an iterable of VarNames and Symbols. -to_varname_list(t) = collect(map(VarName, t)) +to_varname_list(t) = collect(map(to_varname, t)) """ Gibbs From 22c510c5f155a04a62a35d698ac9c0d0ddf6334a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 28 Apr 2025 01:39:07 +0100 Subject: [PATCH 4/7] Allow GibbsContext to wrap PrefixContext (but only PrefixContext) --- src/mcmc/gibbs.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 09016765a..3355e77c5 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -21,6 +21,11 @@ isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler) isgibbscomponent(::AdvancedHMC.HMC) = true isgibbscomponent(::AdvancedMH.MetropolisHastings) = true +function can_be_wrapped(ctx::DynamicPPL.AbstractContext) + return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf +end +can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context) + # Basically like a `DynamicPPL.FixedContext` but # 1. Hijacks the tilde pipeline to fix variables. # 2. Computes the log-probability of the fixed variables. @@ -68,14 +73,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont context::Ctx function GibbsContext{VNs}(global_varinfo, context) where {VNs} - if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + if !can_be_wrapped(context) error("GibbsContext can only wrap a leaf context, not a $(context).") end return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) end function GibbsContext(target_varnames, global_varinfo, context) - if !(DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf) + if !can_be_wrapped(context) error("GibbsContext can only wrap a leaf context, not a $(context).") end if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames) From ab0dd4b75084121dcc21cf3dbc230d14a951f65a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 28 Apr 2025 01:42:49 +0100 Subject: [PATCH 5/7] Bump minor version instead --- HISTORY.md | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 926649949..22ef03471 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,4 +1,4 @@ -# Release 0.37.2 +# Release 0.38.0 DynamicPPL compatibility has been bumped to 0.36. This brings with it a number of changes: the ones most likely to affect you are submodel prefixing and conditioning. diff --git a/Project.toml b/Project.toml index 234a44293..82d32ed82 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.37.2" +version = "0.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 89bf31af4cb7fd4f9b10c85197779dad23ac05bd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 4 May 2025 00:06:38 +0100 Subject: [PATCH 6/7] Enable non-identity VarNames in Gibbs Closes #2403 --- src/mcmc/gibbs.jl | 72 ++++++++++++++++++++++++++++++++-------------- test/mcmc/gibbs.jl | 22 ++++++-------- 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 3355e77c5..9af0427d4 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -59,8 +59,13 @@ for type stability of `tilde_assume`. # Fields $(FIELDS) """ -struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext} <: - DynamicPPL.AbstractContext +struct GibbsContext{ + VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext +} <: DynamicPPL.AbstractContext + """ + the VarNames being sampled + """ + target_varnames::VNs """ a `Ref` to the global `AbstractVarInfo` object that holds values for all variables, both those fixed and those being sampled. We use a `Ref` because this field may need to be @@ -72,26 +77,14 @@ struct GibbsContext{VNs,GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractCont """ context::Ctx - function GibbsContext{VNs}(global_varinfo, context) where {VNs} - if !can_be_wrapped(context) - error("GibbsContext can only wrap a leaf context, not a $(context).") - end - return new{VNs,typeof(global_varinfo),typeof(context)}(global_varinfo, context) - end - function GibbsContext(target_varnames, global_varinfo, context) if !can_be_wrapped(context) error("GibbsContext can only wrap a leaf context, not a $(context).") end - if any(vn -> DynamicPPL.getoptic(vn) != identity, target_varnames) - msg = - "All Gibbs target variables must have identity lenses. " * - "For example, you can't have `@varname(x.a[1])` as a target variable, " * - "only `@varname(x)`." - error(msg) - end - vn_sym = tuple(unique((DynamicPPL.getsym(vn) for vn in target_varnames))...) - return new{vn_sym,typeof(global_varinfo),typeof(context)}(global_varinfo, context) + target_varnames = tuple(target_varnames...) # Allow vectors. + return new{typeof(target_varnames),typeof(global_varinfo),typeof(context)}( + target_varnames, global_varinfo, context + ) end end @@ -101,8 +94,10 @@ end DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context -function DynamicPPL.setchildcontext(context::GibbsContext{VNs}, childcontext) where {VNs} - return GibbsContext{VNs}(Ref(context.global_varinfo[]), childcontext) +function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) + return GibbsContext( + context.target_varnames, Ref(context.global_varinfo[]), childcontext + ) end get_global_varinfo(context::GibbsContext) = context.global_varinfo[] @@ -134,7 +129,9 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa return map(Base.Fix1(get_conditioned_gibbs, context), vns) end -is_target_varname(::GibbsContext{VNs}, ::VarName{sym}) where {VNs,sym} = sym in VNs +function is_target_varname(ctx::GibbsContext, vn::VarName) + return any(Base.Fix2(subsumes, vn), ctx.target_varnames) +end function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName}) num_target = count(Iterators.map(Base.Fix1(is_target_varname, context), vns)) @@ -150,6 +147,37 @@ end # Tilde pipeline function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) child_context = DynamicPPL.childcontext(context) + + # Note that `child_context` may contain `PrefixContext`s -- in which case + # we need to make sure that vn is appropriately prefixed before we handle + # the `GibbsContext` behaviour below. For example, consider the following: + # @model inner() = x ~ Normal() + # @model outer() = a ~ to_submodel(inner()) + # If we run this with `Gibbs(@varname(a.x) => MH())`, then when we are + # executing the submodel, the `context` will contain the `@varname(a.x)` + # variable; `child_context` will contain `PrefixContext(@varname(a))`; and + # `vn` will just be `@varname(x)`. If we just simply run + # `is_target_varname(context, vn)`, it will return false, and everything + # will be messed up. + # TODO(penelopeysm): This 'problem' could be solved if we made GibbsContext a + # leaf context and wrapped the PrefixContext _above_ the GibbsContext, so + # that the prefixing would be handled by tilde_assume(::PrefixContext, ...) + # _before_ we hit this method. + # In the current state of GibbsContext, doing this would require + # special-casing the way PrefixContext is used to wrap the leaf context. + # This is very inconvenient because PrefixContext's behaviour is defined in + # DynamicPPL, and we would basically have to create a new method in Turing + # and override it for GibbsContext. Indeed, a better way to do this would + # be to make GibbsContext a leaf context. In this case, we would be able to + # rely on the existing behaviour of DynamicPPL.make_evaluate_args_and_kwargs + # to correctly wrap the PrefixContext around the GibbsContext. This is very + # tricky to correctly do now, but once we remove the other leaf contexts + # (i.e. PriorContext and LikelihoodContext), we should be able to do this. + # This is already implemented in + # https://github.com/TuringLang/DynamicPPL.jl/pull/885/ but not yet + # released. Exciting! + vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) + return if is_target_varname(context, vn) # Fall back to the default behavior. DynamicPPL.tilde_assume(child_context, right, vn, vi) @@ -182,6 +210,8 @@ function DynamicPPL.tilde_assume( ) # See comment in the above, rng-less version of this method for an explanation. child_context = DynamicPPL.childcontext(context) + vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn) + return if is_target_varname(context, vn) DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi) elseif has_conditioned_gibbs(context, vn) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 84c8b7f57..a18a75c98 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -160,10 +160,6 @@ end return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params) end - function target_vns(::Inference.GibbsContext{VNs}) where {VNs} - return VNs - end - # targets_and_algs will be a list of tuples, where the first element is the target_vns # of a component sampler, and the second element is the component sampler itself. # It is modified by the capture_targets_and_algs function. @@ -174,7 +170,7 @@ end return nothing end if context isa Inference.GibbsContext - push!(targets_and_algs, (target_vns(context), sampler)) + push!(targets_and_algs, (context.target_varnames, sampler)) end return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context)) end @@ -240,14 +236,14 @@ end chain = sample(test_model(-1), sampler, 2) expected_targets_and_algs_per_iteration = [ - ((:s,), mh), - ((:s, :m), mh), - ((:m,), pg), - ((:xs,), hmc), - ((:ys,), nuts), - ((:ys,), nuts), - ((:xs, :ys), hmc), - ((:s,), mh), + ((@varname(s),), mh), + ((@varname(s), @varname(m)), mh), + ((@varname(m),), pg), + ((@varname(xs),), hmc), + ((@varname(ys),), nuts), + ((@varname(ys),), nuts), + ((@varname(xs), @varname(ys)), hmc), + ((@varname(s),), mh), ] @test targets_and_algs == vcat( expected_targets_and_algs_per_iteration, expected_targets_and_algs_per_iteration From 697c208e45b900b113d94240396622084ac3e16a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 4 May 2025 02:51:31 +0100 Subject: [PATCH 7/7] Add Gibbs tests for non-identity VarNames and submodels --- test/mcmc/gibbs.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index a18a75c98..0ba2e85ff 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -723,6 +723,39 @@ end end end + @testset "non-identity varnames" begin + struct Wrap{T} + a::T + end + @model function model1(::Type{T}=Float64) where {T} + x = Vector{T}(undef, 1) + x[1] ~ Normal() + y = Wrap{T}(0.0) + return y.a ~ Normal() + end + model = model1() + spl = Gibbs(@varname(x[1]) => HMC(0.5, 10), @varname(y.a) => MH()) + @test sample(model, spl, 10) isa MCMCChains.Chains + spl = Gibbs((@varname(x[1]), @varname(y.a)) => HMC(0.5, 10)) + @test sample(model, spl, 10) isa MCMCChains.Chains + end + + @testset "submodels" begin + @model inner() = x ~ Normal() + @model function outer() + a ~ to_submodel(inner()) + _ignored ~ to_submodel(prefix(inner(), @varname(b)), false) + return _also_ignored ~ to_submodel(inner(), false) + end + model = outer() + spl = Gibbs( + @varname(a.x) => HMC(0.5, 10), @varname(b.x) => MH(), @varname(x) => MH() + ) + @test sample(model, spl, 10) isa MCMCChains.Chains + spl = Gibbs((@varname(a.x), @varname(b.x), @varname(x)) => MH()) + @test sample(model, spl, 10) isa MCMCChains.Chains + end + @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default