diff --git a/.gitignore b/.gitignore index 1388e96ac..28fed4507 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ .DS_Store /Manifest.toml /dev/ +/test/gdemo_default.jls \ No newline at end of file diff --git a/Project.toml b/Project.toml index 210a7cc64..030e81e8d 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ julia = "1" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" @@ -38,6 +39,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] +test = ["AdvancedHMC", "AdvancedMH", "Distributed", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "Serialization", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6f1ca2d7c..d3d39bb4e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -3,9 +3,9 @@ module DynamicPPL using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel using Distributions using Bijectors -using MacroTools import AbstractMCMC +import MacroTools import ZygoteRules import Random @@ -51,7 +51,6 @@ export AbstractVarInfo, inspace, subsumes, # Compiler - ModelGen, @model, @varname, # Utilities @@ -59,17 +58,13 @@ export AbstractVarInfo, reconstruct, reconstruct!, Sample, - Chain, init, vectorize, set_resume!, # Model - ModelGen, Model, getmissings, getargnames, - getdefaults, - getgenerator, # Samplers Sampler, SampleFromPrior, diff --git a/src/compiler.jl b/src/compiler.jl index 5b56ce2c5..0f3148b1a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -49,17 +49,17 @@ Macro to specify a probabilistic model. If `warn` is `true`, a warning is displayed if internal variable names are used in the model definition. -# Example +# Examples Model definition: ```julia -@model function model_generator(x = default_x, y) +@model function model(x, y = 42) ... end ``` -To generate a `Model`, call `model_generator(x_value)`. +To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`. """ macro model(expr, warn=true) esc(model(expr, warn)) @@ -69,7 +69,9 @@ function model(expr, warn) modelinfo = build_model_info(expr) # Generate main body - modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn) + modelinfo[:body] = generate_mainbody( + modelinfo[:modeldef][:body], modelinfo[:allargs_exprs], warn + ) return build_output(modelinfo) end @@ -80,87 +82,74 @@ end Builds the `model_info` dictionary from the model's expression. """ function build_model_info(input_expr) - # Extract model name (:name), arguments (:args), (:kwargs) and definition (:body) + # Break up the model definition and extract its name, arguments, and function body modeldef = MacroTools.splitdef(input_expr) - # Function body of the model is empty + + # Print a warning if function body of the model is empty warn_empty(modeldef[:body]) - # Construct model_info dictionary - - # Extracting the argument symbols from the model definition - combinedargs = vcat(modeldef[:args], modeldef[:kwargs]) - arg_syms = map(combinedargs) do arg - # @model demo(x) - if (arg isa Symbol) - arg - # @model demo(::Type{T}) where {T} - elseif MacroTools.@capture(arg, ::Type{T_} = Tval_) - T - # @model demo(x::T = 1) - elseif MacroTools.@capture(arg, x_::T_ = val_) - x - # @model demo(x = 1) - elseif MacroTools.@capture(arg, x_ = val_) - x - else - throw(ArgumentError("Unsupported argument $arg to the `@model` macro.")) - end - end - if length(arg_syms) == 0 - args_nt = :(NamedTuple()) - else - nt_type = Expr(:curly, :NamedTuple, - Expr(:tuple, QuoteNode.(arg_syms)...), - Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...) + + ## Construct model_info dictionary + + # Shortcut if the model does not have any arguments + if !haskey(modeldef, :args) && !haskey(modeldef, :kwargs) + modelinfo = Dict( + :allargs_exprs => [], + :allargs_syms => [], + :allargs_namedtuple => NamedTuple(), + :defaults_namedtuple => NamedTuple(), + :modeldef => modeldef, ) - args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...)) + return modelinfo end - args = map(combinedargs) do arg - if (arg isa Symbol) - arg - elseif MacroTools.@capture(arg, ::Type{T_} = Tval_) - if in(T, modeldef[:whereparams]) - S = :Any - else - ind = findfirst(modeldef[:whereparams]) do x - MacroTools.@capture(x, T1_ <: S_) && T1 == T - end - ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause")) - end - Expr(:kw, :($T::Type{<:$S}), Tval) - else - arg + + # Extract the positional and keyword arguments from the model definition. + allargs = vcat(modeldef[:args], modeldef[:kwargs]) + + # Split the argument expressions and the default values. + allargs_exprs_defaults = map(allargs) do arg + MacroTools.@match arg begin + (x_ = val_) => (x, val) + x_ => (x, NO_DEFAULT) + end + end + + # Extract the expressions of the arguments, without default values. + allargs_exprs = first.(allargs_exprs_defaults) + + # Extract the names of the arguments. + allargs_syms = map(allargs_exprs_defaults) do (arg, _) + MacroTools.@match arg begin + (::Type{T_}) | (name_::Type{T_}) => T + name_::T_ => name + x_ => x end end - args_nt = to_namedtuple_expr(arg_syms) + # Build named tuple expression of the argument symbols and variables of the same name. + allargs_namedtuple = to_namedtuple_expr(allargs_syms) + + # Extract default values of the positional and keyword arguments. default_syms = [] - default_vals = [] - foreach(combinedargs) do arg - # @model demo(::Type{T}) where {T} - if MacroTools.@capture(arg, ::Type{T_} = Tval_) - push!(default_syms, T) - push!(default_vals, Tval) - # @model demo(x::T = 1) - elseif MacroTools.@capture(arg, x_::T_ = val_) - push!(default_syms, x) - push!(default_vals, val) - # @model demo(x = 1) - elseif MacroTools.@capture(arg, x_ = val_) - push!(default_syms, x) + default_vals = [] + for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) + if val !== NO_DEFAULT + push!(default_syms, sym) push!(default_vals, val) end end - defaults_nt = to_namedtuple_expr(default_syms, default_vals) - modelderiv = Dict( - :modelargs => args, - :modelargsyms => arg_syms, - :modelargsnt => args_nt, - :modeldefaultsnt => defaults_nt, + # Build named tuple expression of the argument symbols with default values. + defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) + + modelinfo = Dict( + :allargs_exprs => allargs_exprs, + :allargs_syms => allargs_syms, + :allargs_namedtuple => allargs_namedtuple, + :defaults_namedtuple => defaults_namedtuple, + :modeldef => modeldef, ) - model_info = merge(modeldef, modelderiv) - return model_info + return modelinfo end """ @@ -312,55 +301,53 @@ hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true hasmissing(T::Type) = false """ - build_output(model_info) + build_output(modelinfo) Builds the output expression. """ -function build_output(model_info) - # Arguments with default values - args = model_info[:modelargs] - # Argument symbols without default values - arg_syms = model_info[:modelargsyms] - # Arguments namedtuple - args_nt = model_info[:modelargsnt] - # Default values of the arguments - # Arguments namedtuple - defaults_nt = model_info[:modeldefaultsnt] - # Model generator name - model_gen = model_info[:name] - # Main body of the model - main_body = model_info[:modelbody] - - unwrap_data_expr = Expr(:block) - for var in arg_syms - push!(unwrap_data_expr.args, - :($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var))) - end +function build_output(modelinfo) + ## Build the anonymous evaluator from the user-provided model definition. + + # Remove the name. + evaluatordef = deepcopy(modelinfo[:modeldef]) + delete!(evaluatordef, :name) + + # Add the internal arguments to the user-specified arguments (positional + keywords). + evaluatordef[:args] = vcat( + [ + :(_rng::$(Random.AbstractRNG)), + :(_model::$(DynamicPPL.Model)), + :(_varinfo::$(DynamicPPL.AbstractVarInfo)), + :(_sampler::$(DynamicPPL.AbstractSampler)), + :(_context::$(DynamicPPL.AbstractContext)), + ], + modelinfo[:allargs_exprs], + ) - @gensym(evaluator, generator) - model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt)) + # Delete the keyword arguments. + evaluatordef[:kwargs] = [] - # construct the user-facing model generator - model_info[:name] = generator - model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)) - generator_expr = MacroTools.combinedef(model_info) + # Replace the user-provided function body with the version created by DynamicPPL. + evaluatordef[:body] = modelinfo[:body] - return quote - function $evaluator( - _rng::$(Random.AbstractRNG), - _model::$(DynamicPPL.Model), - _varinfo::$(DynamicPPL.AbstractVarInfo), - _sampler::$(DynamicPPL.AbstractSampler), - _context::$(DynamicPPL.AbstractContext), - ) - $unwrap_data_expr - $main_body - end + ## Build the model function. - $(generator_expr) + # Extract the named tuple expression of all arguments and the default values. + allargs_namedtuple = modelinfo[:allargs_namedtuple] + defaults_namedtuple = modelinfo[:defaults_namedtuple] - $(Base).@__doc__ $model_gen = $model_gen_constructor + # Update the function body of the user-specified model. + # We use a name for the anonymous evaluator that does not conflict with other variables. + modeldef = modelinfo[:modeldef] + @gensym evaluator + modeldef[:body] = quote + $evaluator = $(combinedef_anonymous(evaluatordef)) + return $(DynamicPPL.Model)( + $evaluator, $allargs_namedtuple, $defaults_namedtuple + ) end + + return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef))) end diff --git a/src/model.jl b/src/model.jl index d894bd042..739772b02 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,112 +1,70 @@ """ - struct ModelGen{G, defaultnames, Tdefaults} - generator::G - defaults::Tdefaults - end - -A `ModelGen` struct with model generator function of type `G`, and default arguments `defaultnames` -with values `Tdefaults`. -""" -struct ModelGen{G, argnames, defaultnames, Tdefaults} - generator::G - defaults::NamedTuple{defaultnames, Tdefaults} - - function ModelGen{argnames}( - generator::G, - defaults::NamedTuple{defaultnames, Tdefaults} - ) where {G, argnames, defaultnames, Tdefaults} - return new{G, argnames, defaultnames, Tdefaults}(generator, defaults) + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} + f::F + args::NamedTuple{argnames,Targs} + defaults::NamedTuple{defaultnames,Tdefaults} end -end - -(m::ModelGen)(args...; kwargs...) = m.generator(args...; kwargs...) +A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` +types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing +arguments `missings`. -""" - getdefaults(modelgen::ModelGen) - -Get a named tuple of the default argument values defined for a model defined by a generating function. -""" -getdefaults(modelgen::ModelGen) = modelgen.defaults +Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. -""" - getargnames(modelgen::ModelGen) - -Get a tuple of the argument names of the `modelgen`. -""" -getargnames(model::ModelGen{_G, argnames}) where {argnames, _G} = argnames - - - -""" - struct Model{F, argnames, Targs, missings} - f::F - args::NamedTuple{argnames, Targs} - modelgen::Tgen - end +An argument with a type of `Missing` will be in `missings` by default. However, in +non-traditional use-cases `missings` can be defined differently. All variables in `missings` +are treated as random variables rather than observations. -A `Model` struct with model evaluation function of type `F`, arguments names `argnames`, arguments -types `Targs`, missing arguments `missings`, and corresponding model generator. `argnames` and -`missings` are tuples of symbols, e.g. `(:a, :b)`. An argument with a type of `Missing` will be in -`missings` by default. However, in non-traditional use-cases `missings` can be defined differently. -All variables in `missings` are treated as random variables rather than observations. +The default arguments are used internally when constructing instances of the same model with +different arguments. -# Example +# Examples ```julia julia> Model(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0)) +Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple()) -julia> Model{(:y,)}(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(:y,),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0)) +julia> Model(f, (x = 1.0, y = 2.0), (x = 42,)) +Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) + +julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings +Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F, argnames, Targs, missings, Tgen} <: AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel f::F - args::NamedTuple{argnames, Targs} - modelgen::Tgen + args::NamedTuple{argnames,Targs} + defaults::NamedTuple{defaultnames,Tdefaults} """ - Model{missings}(f, args::NamedTuple, modelgen::ModelGen) + Model{missings}(f, args::NamedTuple, defaults::NamedTuple) - Create a model with evalutation function `f` and missing arguments overwritten by `missings`. + Create a model with evaluation function `f` and missing arguments overwritten by `missings`. """ function Model{missings}( f::F, - args::NamedTuple{argnames, Targs}, - modelgen::Tgen - ) where {missings, F, argnames, Targs, Tgen<:ModelGen} - return new{F, argnames, Targs, missings, Tgen}(f, args, modelgen) + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple{defaultnames,Tdefaults}, + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(f, args, defaults) end end """ - Model(f, args::NamedTuple, modelgen::ModelGen) + Model(f, args::NamedTuple[, defaults::NamedTuple = ()]) + +Create a model with evaluation function `f` and missing arguments deduced from `args`. - Create a model with evalutation function `f` and missing arguments deduced from `args`. +Default arguments `defaults` are used internally when constructing instances of the same +model with different arguments. """ @generated function Model( f::F, - args::NamedTuple{argnames, Targs}, - modelgen::ModelGen{_G, argnames} -) where {F, argnames, Targs, _G} + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple = NamedTuple(), +) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(f, args, modelgen)) -end - - -""" - Model{missings}(modelgen::ModelGen, args::NamedTuple) - -Create a copy of the model described by `modelgen(args...)`, with missing arguments -overwritten by `missings`. -""" -function Model{missings}( - modelgen::ModelGen, - args::NamedTuple{argnames, Targs} -) where {missings, argnames, Targs} - model = modelgen(args...) - return Model{missings}(model.f, args, modelgen) + return :(Model{$missings}(f, args, defaults)) end """ @@ -154,7 +112,7 @@ function evaluate_threadunsafe(rng, model, varinfo, sampler, context) if has_eval_num(sampler) sampler.state.eval_num += 1 end - return model.f(rng, model, varinfo, sampler, context) + return _evaluate(rng, model, varinfo, sampler, context) end """ @@ -174,17 +132,27 @@ function evaluate_threadsafe(rng, model, varinfo, sampler, context) sampler.state.eval_num += 1 end wrapper = ThreadSafeVarInfo(varinfo) - result = model.f(rng, model, wrapper, sampler, context) + result = _evaluate(rng, model, wrapper, sampler, context) setlogp!(varinfo, getlogp(wrapper)) return result end +""" + _evaluate(rng, model::Model, varinfo, sampler, context) + +Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +""" +@generated function _evaluate(rng, model::Model{_F,argnames}, varinfo, sampler, context) where {_F,argnames} + unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] + return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) +end + """ getargnames(model::Model) Get a tuple of the argument names of the `model`. """ -getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames +getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames """ @@ -192,19 +160,11 @@ getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames Get a tuple of the names of the missing arguments of the `model`. """ -getmissings(model::Model{_F, _a, _T, missings}) where {missings, _F, _a, _T} = missings +getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings getmissing(model::Model) = getmissings(model) @deprecate getmissing(model) getmissings(model) - -""" - getgenerator(model::Model) - -Get the model generator associated with `model`. -""" -getgenerator(model::Model) = model.modelgen - """ logjoint(model::Model, varinfo::AbstractVarInfo) @@ -214,7 +174,7 @@ See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) model(varinfo, SampleFromPrior(), DefaultContext()) - return getlogp(varinfo) + return getlogp(varinfo) end """ @@ -226,7 +186,7 @@ See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) model(varinfo, SampleFromPrior(), PriorContext()) - return getlogp(varinfo) + return getlogp(varinfo) end """ @@ -238,5 +198,5 @@ See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) model(varinfo, SampleFromPrior(), LikelihoodContext()) - return getlogp(varinfo) + return getlogp(varinfo) end diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 6fb4e89fa..96f461367 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -22,11 +22,11 @@ function get_exprs(str::String) end function logprob(ex1, ex2) - ptype, modelgen, vi = probtype(ex1, ex2) + ptype, model, vi = probtype(ex1, ex2) if ptype isa Val{:prior} - return logprior(ex1, ex2, modelgen, vi) + return logprior(ex1, ex2, model, vi) elseif ptype isa Val{:likelihood} - return loglikelihood(ex1, ex2, modelgen, vi) + return loglikelihood(ex1, ex2, model, vi) end end @@ -34,13 +34,12 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names if :chain in namesr if isdefined(ntr.chain.info, :model) model = ntr.chain.info.model - @assert model isa Model - modelgen = getgenerator(model) elseif isdefined(ntr, :model) - modelgen = ntr.model + model = ntr.model else throw("The model is not defined. Please make sure the model is either saved in the chain or passed on the RHS of |.") end + @assert model isa Model if isdefined(ntr.chain.info, :vi) _vi = ntr.chain.info.vi @assert _vi isa VarInfo @@ -52,14 +51,16 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names else vi = nothing end - defaults = getdefaults(modelgen) - valid_arg(arg) = isdefined(ntl, arg) || isdefined(ntr, arg) || - isdefined(defaults, arg) && getfield(defaults, arg) !== missing - @assert all(valid_arg, getargnames(modelgen)) - return Val(:likelihood), modelgen, vi + defaults = model.defaults + @assert all(getargnames(model)) do arg + isdefined(ntl, arg) || isdefined(ntr, arg) || + isdefined(defaults, arg) && getfield(defaults, arg) !== missing + end + return Val(:likelihood), model, vi else @assert isdefined(ntr, :model) - modelgen = ntr.model + model = ntr.model + @assert model isa Model if isdefined(ntr, :varinfo) _vi = ntr.varinfo @assert _vi isa VarInfo @@ -67,16 +68,17 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names else vi = nothing end - return probtype(ntl, ntr, modelgen), modelgen, vi + return probtype(ntl, ntr, model), model, vi end end + function probtype( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - modelgen::ModelGen{_G, argnames, defaultnames} -) where {leftnames, rightnames, argnames, defaultnames, _G} - defaults = getdefaults(modelgen) - prior_rhs = all(n -> n in (:model, :varinfo) || + model::Model{_F,argnames,defaultnames} +) where {leftnames,rightnames,argnames,defaultnames,_F} + defaults = model.defaults + prior_rhs = all(n -> n in (:model, :varinfo) || n in argnames && getfield(right, n) !== missing, rightnames) function get_arg(arg) if arg in leftnames @@ -118,7 +120,7 @@ missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has function logprior( left::NamedTuple, right::NamedTuple, - modelgen::ModelGen, + _model::Model, _vi::Union{Nothing, VarInfo} ) # For model args on the LHS of |, use their passed value but add the symbol to @@ -133,7 +135,7 @@ function logprior( # All `observe` and `dot_observe` calls are no-op in the PriorContext # When all of model args are on the lhs of |, this is also equal to the logjoint. - model = make_prior_model(left, right, modelgen) + model = make_prior_model(left, right, _model) vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." @@ -145,12 +147,12 @@ end @generated function make_prior_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - modelgen::ModelGen{_G, argnames, defaultnames} -) where {leftnames, rightnames, argnames, defaultnames, _G} + model::Model{_F,argnames,defaultnames} +) where {leftnames,rightnames,argnames,defaultnames,_F} argvals = [] missings = [] warnings = [] - + for argname in argnames if argname in leftnames push!(argvals, :(deepcopy(left.$argname))) @@ -158,18 +160,19 @@ end elseif argname in rightnames push!(argvals, :(right.$argname)) elseif argname in defaultnames - push!(argvals, :(getdefaults(modelgen).$argname)) + push!(argvals, :(model.defaults.$argname)) else push!(warnings, :(@warn($(warn_msg(argname))))) push!(argvals, :(nothing)) end end - # `args` is inserted as properly typed NamedTuple expression; + # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal return quote $(warnings...) - Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))) + Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)), + model.defaults) end end @@ -178,10 +181,10 @@ warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used." function Distributions.loglikelihood( left::NamedTuple, right::NamedTuple, - modelgen::ModelGen, + _model::Model, _vi::Union{Nothing, VarInfo}, ) - model = make_likelihood_model(left, right, modelgen) + model = make_likelihood_model(left, right, _model) vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi if isdefined(right, :chain) # Element-wise likelihood for each value in chain @@ -205,11 +208,11 @@ end @generated function make_likelihood_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - modelgen::ModelGen{_G, argnames, defaultnames} -) where {leftnames, rightnames, argnames, defaultnames, _G} + model::Model{_F,argnames,defaultnames}, +) where {leftnames,rightnames,argnames,defaultnames,_F} argvals = [] missings = [] - + for argname in argnames if argname in leftnames push!(argvals, :(left.$argname)) @@ -217,15 +220,16 @@ end push!(argvals, :(right.$argname)) push!(missings, argname) elseif argname in defaultnames - push!(argvals, :(getdefaults(modelgen).$argname)) + push!(argvals, :(model.defaults.$argname)) else throw("This point should not be reached. Please open an issue in the DynamicPPL.jl repository.") end end - # `args` is inserted as properly typed NamedTuple expression; + # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal - return :(Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))) + return :(Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)), + model.defaults)) end _setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c) diff --git a/src/utils.jl b/src/utils.jl index 0d542bc4b..cd23edc2c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,35 @@ +# singleton for indicating if no default arguments are present +struct NoDefault end +const NO_DEFAULT = NoDefault() + +# FIXME: This is copied from MacroTools and should be removed when a MacroTools release with +# support for anonymous functions is available (> 0.5.5). +function combinedef_anonymous(dict::Dict) + rtype = get(dict, :rtype, nothing) + params = get(dict, :params, []) + wparams = get(dict, :whereparams, []) + body = MacroTools.block(dict[:body]) + + if isempty(dict[:kwargs]) + arg = :($(dict[:args]...),) + else + arg = Expr(:tuple, Expr(:parameters, dict[:kwargs]...), dict[:args]...) + end + if isempty(wparams) + if rtype==nothing + MacroTools.@q($arg -> $body) + else + MacroTools.@q(($arg::$rtype) -> $body) + end + else + if rtype === nothing + MacroTools.@q(($arg where {$(wparams...)}) -> $body) + else + MacroTools.@q(($arg::$rtype where {$(wparams...)}) -> $body) + end + end +end + """ getargs_dottilde(x) @@ -6,18 +38,11 @@ Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` o """ getargs_dottilde(x) = nothing function getargs_dottilde(expr::Expr) - # Check if the expression is of the form `L .~ R`. - if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~ - return expr.args[2], expr.args[3] + return MacroTools.@match expr begin + (.~)(L_, R_) => (L, R) + (~).(L_, R_) => (L, R) + x_ => nothing end - - # Check if the expression is of the form `(~).(L, R)`. - if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ && - Meta.isexpr(expr.args[2], :tuple, 2) - return expr.args[2].args[1], expr.args[2].args[2] - end - - return end """ @@ -28,10 +53,10 @@ otherwise. """ getargs_tilde(x) = nothing function getargs_tilde(expr::Expr) - if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~ - return expr.args[2], expr.args[3] + return MacroTools.@match expr begin + (~)(L_, R_) => (L, R) + x_ => nothing end - return end ############################################ diff --git a/test/prob_macro.jl b/test/prob_macro.jl index fb8d6e95f..5601e11de 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -23,26 +23,27 @@ Random.seed!(129) loglike = logpdf(Normal(mval, 1), xval) logjoint = logprior + loglike - @test logprob"m = mval | model = demo" == logprior - @test logprob"m = mval | x = xval, model = demo" == logprior - @test logprob"x = xval | m = mval, model = demo" == loglike - @test logprob"x = xval, m = mval | model = demo" == logjoint + model = demo(xval) + @test logprob"m = mval | model = model" == logprior + @test logprob"m = mval | x = xval, model = model" == logprior + @test logprob"x = xval | m = mval, model = model" == loglike + @test logprob"x = xval, m = mval | model = model" == logjoint varinfo = VarInfo(demo(xval)) - @test logprob"m = mval | model = demo, varinfo = varinfo" == logprior - @test logprob"m = mval | x = xval, model = demo, varinfo = varinfo" == logprior - @test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike + @test logprob"m = mval | model = model, varinfo = varinfo" == logprior + @test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior + @test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike varinfo = VarInfo(demo(missing)) - @test logprob"x = xval, m = mval | model = demo, varinfo = varinfo" == logjoint + @test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint chain = sample(demo(xval), IS(), iters; save_state = true) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) lps = logpdf.(Normal.(vec(chain["m"]), 1), xval) @test logprob"x = xval | chain = chain" == lps - @test logprob"x = xval | chain = chain2, model = demo" == lps + @test logprob"x = xval | chain = chain2, model = model" == lps varinfo = VarInfo(demo(xval)) @test logprob"x = xval | chain = chain, varinfo = varinfo" == lps - @test logprob"x = xval | chain = chain2, model = demo, varinfo = varinfo" == lps + @test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps end @testset "vector" begin @@ -61,13 +62,14 @@ Random.seed!(129) loglike = like(mval, xval) logjoint = logprior + loglike - @test logprob"m = mval | model = demo" == logprior - @test logprob"x = xval | m = mval, model = demo" == loglike - @test logprob"x = xval, m = mval | model = demo" == logjoint + model = demo(xval) + @test logprob"m = mval | model = model" == logprior + @test logprob"x = xval | m = mval, model = model" == loglike + @test logprob"x = xval, m = mval | model = model" == logjoint varinfo = VarInfo(demo(xval)) - @test logprob"m = mval | model = demo, varinfo = varinfo" == logprior - @test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike + @test logprob"m = mval | model = model, varinfo = varinfo" == logprior + @test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike # Currently, we cannot easily pre-allocate `VarInfo` for vector data chain = sample(demo(xval), HMC(0.5, 1), iters; save_state = true) @@ -78,8 +80,8 @@ Random.seed!(129) like([chain[iter, name, 1] for name in names], xval) end @test logprob"x = xval | chain = chain" == lps - @test logprob"x = xval | chain = chain2, model = demo" == lps + @test logprob"x = xval | chain = chain2, model = model" == lps @test logprob"x = xval | chain = chain, varinfo = varinfo" == lps - @test logprob"x = xval | chain = chain2, model = demo, varinfo = varinfo" == lps + @test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps end end diff --git a/test/runtests.jl b/test/runtests.jl index d643e38ca..59fb59042 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,9 @@ using ForwardDiff using Tracker using Zygote +using Distributed using Random +using Serialization using Test dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] @@ -29,6 +31,8 @@ include("test_util.jl") include("threadsafe.jl") + include("serialization.jl") + @testset "compat" begin include(joinpath("compat", "ad.jl")) end diff --git a/test/serialization.jl b/test/serialization.jl new file mode 100644 index 000000000..7afd8f585 --- /dev/null +++ b/test/serialization.jl @@ -0,0 +1,54 @@ +@testset "serialization.jl" begin + Random.seed!(1234) + + @testset "saving and loading" begin + # Save model. + open("gdemo_default.jls", "w") do io + serialize(io, gdemo_default) + end + + # Sample from deserialized model. + gdemo_default_copy = open(deserialize, "gdemo_default.jls", "r") + samples = [gdemo_default_copy() for _ in 1:1_000] + samples_s = first.(samples) + samples_m = last.(samples) + + @test mean(samples_s) ≈ 3 atol=0.1 + @test mean(samples_m) ≈ 0 atol=0.1 + end + + @testset "pmap" begin + # Add worker processes. + addprocs() + @info "serialization test: using $(nworkers()) processes" + + # Load packages on all processes. + @everywhere begin + using DynamicPPL + using Distributions + end + + # Define model on all proceses. + @everywhere @model function model() + m ~ Normal(0, 1) + end + + # Generate `Model` objects on all processes. + models = pmap(_ -> model(), 1:100) + @test models isa Vector{<:Model} + @test length(models) == 100 + + # Sample from model on all processes. + n = 1_000 + samples1 = pmap(_ -> model()(), 1:n) + m = model() + samples2 = pmap(_ -> m(), 1:n) + + for samples in (samples1, samples2) + @test samples isa Vector{Float64} + @test length(samples) == n + @test mean(samples) ≈ 0 atol=0.1 + @test std(samples) ≈ 1 atol=0.1 + end + end +end