Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 772d690

Browse files
committedJun 29, 2020
Remove ModelGen (#134)
This PR removes `ModelGen` completely. The main motivation for it were the issues that @itsdfish and I experienced when working with multiple processes. The following MWE ```julia using Distributed addprocs(4) @Everywhere using DynamicPPL @Everywhere @model model() = begin end pmap(x -> model(), 1:4) ``` fails intermittently > if not all of these [`@model`] evaluations generate same evaluator and generator functions of the same name (i.e., these var"###evaluator#253" and var"###generator#254" functions). I assume one could maybe provoke the error by defining another model first on the main process before calling @Everywhere @model .... (copied from the discussion on Slack) With the changes in this PR, `@model model() = ...` generates only a single function `model` on all workers, and hence there are no issues anymore with the names of the generators and evaluators. The evaluator is created as a closure inside of `model`, which can be serialized and deserialized properly by Julia. So far I haven't been able to reproduce the issues above with this PR. The only user-facing change of the removal of `ModelGen` (apart from that one never has to construct it, which simplifies the docs and the example @denainjs asked about) is that the `logprob` macro now requires to specify `model = m` where `m = model()` instead of `model = model` (since that's just a regular function from which the default arguments etc of the resulting model can't be extracted). It feels slightly weird that the evaluation is not based "exactly" on the specified `Model` instance but that the other parts of `logprob` modify it (which was the reason I guess for using the model generator before here), but on the other hand this weird behaviour already exists when specifying `logprob` with a `Chains` object. (BTW I'm not sure if we should actually use string macros here, maybe regular functions would be nicer.) Additionally, I assume (though haven't tested it) that getting rid of the separate evaluator and generator functions will not only simplify serialization and deserialization when working with multiple processes but also when saving models and chains (see e.g. TuringLang/Turing.jl#1091). Co-authored-by: David Widmann <[email protected]>
1 parent fe04045 commit 772d690

File tree

10 files changed

+311
-277
lines changed

10 files changed

+311
-277
lines changed
 

‎.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
.DS_Store
55
/Manifest.toml
66
/dev/
7+
/test/gdemo_default.jls

‎Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ julia = "1"
2323
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
2424
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
2525
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
26+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2627
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
2728
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2829
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
@@ -38,6 +39,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
3839
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3940
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4041
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
42+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
4143
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4244
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4345
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4850
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4951

5052
[targets]
51-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
53+
test = ["AdvancedHMC", "AdvancedMH", "Distributed", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "Serialization", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]

‎src/DynamicPPL.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ module DynamicPPL
33
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
6-
using MacroTools
76

87
import AbstractMCMC
8+
import MacroTools
99
import ZygoteRules
1010

1111
import Random
@@ -51,25 +51,20 @@ export AbstractVarInfo,
5151
inspace,
5252
subsumes,
5353
# Compiler
54-
ModelGen,
5554
@model,
5655
@varname,
5756
# Utilities
5857
vectorize,
5958
reconstruct,
6059
reconstruct!,
6160
Sample,
62-
Chain,
6361
init,
6462
vectorize,
6563
set_resume!,
6664
# Model
67-
ModelGen,
6865
Model,
6966
getmissings,
7067
getargnames,
71-
getdefaults,
72-
getgenerator,
7368
# Samplers
7469
Sampler,
7570
SampleFromPrior,

‎src/compiler.jl

Lines changed: 98 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ Macro to specify a probabilistic model.
4949
If `warn` is `true`, a warning is displayed if internal variable names are used in the model
5050
definition.
5151
52-
# Example
52+
# Examples
5353
5454
Model definition:
5555
5656
```julia
57-
@model function model_generator(x = default_x, y)
57+
@model function model(x, y = 42)
5858
...
5959
end
6060
```
6161
62-
To generate a `Model`, call `model_generator(x_value)`.
62+
To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
6363
"""
6464
macro model(expr, warn=true)
6565
esc(model(expr, warn))
@@ -69,7 +69,9 @@ function model(expr, warn)
6969
modelinfo = build_model_info(expr)
7070

7171
# Generate main body
72-
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn)
72+
modelinfo[:body] = generate_mainbody(
73+
modelinfo[:modeldef][:body], modelinfo[:allargs_exprs], warn
74+
)
7375

7476
return build_output(modelinfo)
7577
end
@@ -80,87 +82,74 @@ end
8082
Builds the `model_info` dictionary from the model's expression.
8183
"""
8284
function build_model_info(input_expr)
83-
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
85+
# Break up the model definition and extract its name, arguments, and function body
8486
modeldef = MacroTools.splitdef(input_expr)
85-
# Function body of the model is empty
87+
88+
# Print a warning if function body of the model is empty
8689
warn_empty(modeldef[:body])
87-
# Construct model_info dictionary
88-
89-
# Extracting the argument symbols from the model definition
90-
combinedargs = vcat(modeldef[:args], modeldef[:kwargs])
91-
arg_syms = map(combinedargs) do arg
92-
# @model demo(x)
93-
if (arg isa Symbol)
94-
arg
95-
# @model demo(::Type{T}) where {T}
96-
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
97-
T
98-
# @model demo(x::T = 1)
99-
elseif MacroTools.@capture(arg, x_::T_ = val_)
100-
x
101-
# @model demo(x = 1)
102-
elseif MacroTools.@capture(arg, x_ = val_)
103-
x
104-
else
105-
throw(ArgumentError("Unsupported argument $arg to the `@model` macro."))
106-
end
107-
end
108-
if length(arg_syms) == 0
109-
args_nt = :(NamedTuple())
110-
else
111-
nt_type = Expr(:curly, :NamedTuple,
112-
Expr(:tuple, QuoteNode.(arg_syms)...),
113-
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
90+
91+
## Construct model_info dictionary
92+
93+
# Shortcut if the model does not have any arguments
94+
if !haskey(modeldef, :args) && !haskey(modeldef, :kwargs)
95+
modelinfo = Dict(
96+
:allargs_exprs => [],
97+
:allargs_syms => [],
98+
:allargs_namedtuple => NamedTuple(),
99+
:defaults_namedtuple => NamedTuple(),
100+
:modeldef => modeldef,
114101
)
115-
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
102+
return modelinfo
116103
end
117-
args = map(combinedargs) do arg
118-
if (arg isa Symbol)
119-
arg
120-
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
121-
if in(T, modeldef[:whereparams])
122-
S = :Any
123-
else
124-
ind = findfirst(modeldef[:whereparams]) do x
125-
MacroTools.@capture(x, T1_ <: S_) && T1 == T
126-
end
127-
ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause"))
128-
end
129-
Expr(:kw, :($T::Type{<:$S}), Tval)
130-
else
131-
arg
104+
105+
# Extract the positional and keyword arguments from the model definition.
106+
allargs = vcat(modeldef[:args], modeldef[:kwargs])
107+
108+
# Split the argument expressions and the default values.
109+
allargs_exprs_defaults = map(allargs) do arg
110+
MacroTools.@match arg begin
111+
(x_ = val_) => (x, val)
112+
x_ => (x, NO_DEFAULT)
113+
end
114+
end
115+
116+
# Extract the expressions of the arguments, without default values.
117+
allargs_exprs = first.(allargs_exprs_defaults)
118+
119+
# Extract the names of the arguments.
120+
allargs_syms = map(allargs_exprs_defaults) do (arg, _)
121+
MacroTools.@match arg begin
122+
(::Type{T_}) | (name_::Type{T_}) => T
123+
name_::T_ => name
124+
x_ => x
132125
end
133126
end
134-
args_nt = to_namedtuple_expr(arg_syms)
135127

128+
# Build named tuple expression of the argument symbols and variables of the same name.
129+
allargs_namedtuple = to_namedtuple_expr(allargs_syms)
130+
131+
# Extract default values of the positional and keyword arguments.
136132
default_syms = []
137-
default_vals = []
138-
foreach(combinedargs) do arg
139-
# @model demo(::Type{T}) where {T}
140-
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
141-
push!(default_syms, T)
142-
push!(default_vals, Tval)
143-
# @model demo(x::T = 1)
144-
elseif MacroTools.@capture(arg, x_::T_ = val_)
145-
push!(default_syms, x)
146-
push!(default_vals, val)
147-
# @model demo(x = 1)
148-
elseif MacroTools.@capture(arg, x_ = val_)
149-
push!(default_syms, x)
133+
default_vals = []
134+
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
135+
if val !== NO_DEFAULT
136+
push!(default_syms, sym)
150137
push!(default_vals, val)
151138
end
152139
end
153-
defaults_nt = to_namedtuple_expr(default_syms, default_vals)
154140

155-
modelderiv = Dict(
156-
:modelargs => args,
157-
:modelargsyms => arg_syms,
158-
:modelargsnt => args_nt,
159-
:modeldefaultsnt => defaults_nt,
141+
# Build named tuple expression of the argument symbols with default values.
142+
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)
143+
144+
modelinfo = Dict(
145+
:allargs_exprs => allargs_exprs,
146+
:allargs_syms => allargs_syms,
147+
:allargs_namedtuple => allargs_namedtuple,
148+
:defaults_namedtuple => defaults_namedtuple,
149+
:modeldef => modeldef,
160150
)
161-
model_info = merge(modeldef, modelderiv)
162151

163-
return model_info
152+
return modelinfo
164153
end
165154

166155
"""
@@ -312,55 +301,53 @@ hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
312301
hasmissing(T::Type) = false
313302

314303
"""
315-
build_output(model_info)
304+
build_output(modelinfo)
316305
317306
Builds the output expression.
318307
"""
319-
function build_output(model_info)
320-
# Arguments with default values
321-
args = model_info[:modelargs]
322-
# Argument symbols without default values
323-
arg_syms = model_info[:modelargsyms]
324-
# Arguments namedtuple
325-
args_nt = model_info[:modelargsnt]
326-
# Default values of the arguments
327-
# Arguments namedtuple
328-
defaults_nt = model_info[:modeldefaultsnt]
329-
# Model generator name
330-
model_gen = model_info[:name]
331-
# Main body of the model
332-
main_body = model_info[:modelbody]
333-
334-
unwrap_data_expr = Expr(:block)
335-
for var in arg_syms
336-
push!(unwrap_data_expr.args,
337-
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
338-
end
308+
function build_output(modelinfo)
309+
## Build the anonymous evaluator from the user-provided model definition.
310+
311+
# Remove the name.
312+
evaluatordef = deepcopy(modelinfo[:modeldef])
313+
delete!(evaluatordef, :name)
314+
315+
# Add the internal arguments to the user-specified arguments (positional + keywords).
316+
evaluatordef[:args] = vcat(
317+
[
318+
:(_rng::$(Random.AbstractRNG)),
319+
:(_model::$(DynamicPPL.Model)),
320+
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
321+
:(_sampler::$(DynamicPPL.AbstractSampler)),
322+
:(_context::$(DynamicPPL.AbstractContext)),
323+
],
324+
modelinfo[:allargs_exprs],
325+
)
339326

340-
@gensym(evaluator, generator)
341-
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
327+
# Delete the keyword arguments.
328+
evaluatordef[:kwargs] = []
342329

343-
# construct the user-facing model generator
344-
model_info[:name] = generator
345-
model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor))
346-
generator_expr = MacroTools.combinedef(model_info)
330+
# Replace the user-provided function body with the version created by DynamicPPL.
331+
evaluatordef[:body] = modelinfo[:body]
347332

348-
return quote
349-
function $evaluator(
350-
_rng::$(Random.AbstractRNG),
351-
_model::$(DynamicPPL.Model),
352-
_varinfo::$(DynamicPPL.AbstractVarInfo),
353-
_sampler::$(DynamicPPL.AbstractSampler),
354-
_context::$(DynamicPPL.AbstractContext),
355-
)
356-
$unwrap_data_expr
357-
$main_body
358-
end
333+
## Build the model function.
359334

360-
$(generator_expr)
335+
# Extract the named tuple expression of all arguments and the default values.
336+
allargs_namedtuple = modelinfo[:allargs_namedtuple]
337+
defaults_namedtuple = modelinfo[:defaults_namedtuple]
361338

362-
$(Base).@__doc__ $model_gen = $model_gen_constructor
339+
# Update the function body of the user-specified model.
340+
# We use a name for the anonymous evaluator that does not conflict with other variables.
341+
modeldef = modelinfo[:modeldef]
342+
@gensym evaluator
343+
modeldef[:body] = quote
344+
$evaluator = $(combinedef_anonymous(evaluatordef))
345+
return $(DynamicPPL.Model)(
346+
$evaluator, $allargs_namedtuple, $defaults_namedtuple
347+
)
363348
end
349+
350+
return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
364351
end
365352

366353

‎src/model.jl

Lines changed: 55 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,70 @@
11
"""
2-
struct ModelGen{G, defaultnames, Tdefaults}
3-
generator::G
4-
defaults::Tdefaults
5-
end
6-
7-
A `ModelGen` struct with model generator function of type `G`, and default arguments `defaultnames`
8-
with values `Tdefaults`.
9-
"""
10-
struct ModelGen{G, argnames, defaultnames, Tdefaults}
11-
generator::G
12-
defaults::NamedTuple{defaultnames, Tdefaults}
13-
14-
function ModelGen{argnames}(
15-
generator::G,
16-
defaults::NamedTuple{defaultnames, Tdefaults}
17-
) where {G, argnames, defaultnames, Tdefaults}
18-
return new{G, argnames, defaultnames, Tdefaults}(generator, defaults)
2+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
3+
f::F
4+
args::NamedTuple{argnames,Targs}
5+
defaults::NamedTuple{defaultnames,Tdefaults}
196
end
20-
end
21-
22-
(m::ModelGen)(args...; kwargs...) = m.generator(args...; kwargs...)
237
8+
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
9+
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing
10+
arguments `missings`.
2411
25-
"""
26-
getdefaults(modelgen::ModelGen)
27-
28-
Get a named tuple of the default argument values defined for a model defined by a generating function.
29-
"""
30-
getdefaults(modelgen::ModelGen) = modelgen.defaults
12+
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
3113
32-
"""
33-
getargnames(modelgen::ModelGen)
34-
35-
Get a tuple of the argument names of the `modelgen`.
36-
"""
37-
getargnames(model::ModelGen{_G, argnames}) where {argnames, _G} = argnames
38-
39-
40-
41-
"""
42-
struct Model{F, argnames, Targs, missings}
43-
f::F
44-
args::NamedTuple{argnames, Targs}
45-
modelgen::Tgen
46-
end
14+
An argument with a type of `Missing` will be in `missings` by default. However, in
15+
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
16+
are treated as random variables rather than observations.
4717
48-
A `Model` struct with model evaluation function of type `F`, arguments names `argnames`, arguments
49-
types `Targs`, missing arguments `missings`, and corresponding model generator. `argnames` and
50-
`missings` are tuples of symbols, e.g. `(:a, :b)`. An argument with a type of `Missing` will be in
51-
`missings` by default. However, in non-traditional use-cases `missings` can be defined differently.
52-
All variables in `missings` are treated as random variables rather than observations.
18+
The default arguments are used internally when constructing instances of the same model with
19+
different arguments.
5320
54-
# Example
21+
# Examples
5522
5623
```julia
5724
julia> Model(f, (x = 1.0, y = 2.0))
58-
Model{typeof(f),(),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
25+
Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple())
5926
60-
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0))
61-
Model{typeof(f),(:y,),(:x, :y),Tuple{Float64,Float64}}((x = 1.0, y = 2.0))
27+
julia> Model(f, (x = 1.0, y = 2.0), (x = 42,))
28+
Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
29+
30+
julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
31+
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
6232
```
6333
"""
64-
struct Model{F, argnames, Targs, missings, Tgen} <: AbstractModel
34+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel
6535
f::F
66-
args::NamedTuple{argnames, Targs}
67-
modelgen::Tgen
36+
args::NamedTuple{argnames,Targs}
37+
defaults::NamedTuple{defaultnames,Tdefaults}
6838

6939
"""
70-
Model{missings}(f, args::NamedTuple, modelgen::ModelGen)
40+
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
7141
72-
Create a model with evalutation function `f` and missing arguments overwritten by `missings`.
42+
Create a model with evaluation function `f` and missing arguments overwritten by `missings`.
7343
"""
7444
function Model{missings}(
7545
f::F,
76-
args::NamedTuple{argnames, Targs},
77-
modelgen::Tgen
78-
) where {missings, F, argnames, Targs, Tgen<:ModelGen}
79-
return new{F, argnames, Targs, missings, Tgen}(f, args, modelgen)
46+
args::NamedTuple{argnames,Targs},
47+
defaults::NamedTuple{defaultnames,Tdefaults},
48+
) where {missings,F,argnames,Targs,defaultnames,Tdefaults}
49+
return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(f, args, defaults)
8050
end
8151
end
8252

8353
"""
84-
Model(f, args::NamedTuple, modelgen::ModelGen)
54+
Model(f, args::NamedTuple[, defaults::NamedTuple = ()])
55+
56+
Create a model with evaluation function `f` and missing arguments deduced from `args`.
8557
86-
Create a model with evalutation function `f` and missing arguments deduced from `args`.
58+
Default arguments `defaults` are used internally when constructing instances of the same
59+
model with different arguments.
8760
"""
8861
@generated function Model(
8962
f::F,
90-
args::NamedTuple{argnames, Targs},
91-
modelgen::ModelGen{_G, argnames}
92-
) where {F, argnames, Targs, _G}
63+
args::NamedTuple{argnames,Targs},
64+
defaults::NamedTuple = NamedTuple(),
65+
) where {F,argnames,Targs}
9366
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
94-
return :(Model{$missings}(f, args, modelgen))
95-
end
96-
97-
98-
"""
99-
Model{missings}(modelgen::ModelGen, args::NamedTuple)
100-
101-
Create a copy of the model described by `modelgen(args...)`, with missing arguments
102-
overwritten by `missings`.
103-
"""
104-
function Model{missings}(
105-
modelgen::ModelGen,
106-
args::NamedTuple{argnames, Targs}
107-
) where {missings, argnames, Targs}
108-
model = modelgen(args...)
109-
return Model{missings}(model.f, args, modelgen)
67+
return :(Model{$missings}(f, args, defaults))
11068
end
11169

11270
"""
@@ -154,7 +112,7 @@ function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
154112
if has_eval_num(sampler)
155113
sampler.state.eval_num += 1
156114
end
157-
return model.f(rng, model, varinfo, sampler, context)
115+
return _evaluate(rng, model, varinfo, sampler, context)
158116
end
159117

160118
"""
@@ -174,37 +132,39 @@ function evaluate_threadsafe(rng, model, varinfo, sampler, context)
174132
sampler.state.eval_num += 1
175133
end
176134
wrapper = ThreadSafeVarInfo(varinfo)
177-
result = model.f(rng, model, wrapper, sampler, context)
135+
result = _evaluate(rng, model, wrapper, sampler, context)
178136
setlogp!(varinfo, getlogp(wrapper))
179137
return result
180138
end
181139

140+
"""
141+
_evaluate(rng, model::Model, varinfo, sampler, context)
142+
143+
Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object.
144+
"""
145+
@generated function _evaluate(rng, model::Model{_F,argnames}, varinfo, sampler, context) where {_F,argnames}
146+
unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames]
147+
return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...)))
148+
end
149+
182150
"""
183151
getargnames(model::Model)
184152
185153
Get a tuple of the argument names of the `model`.
186154
"""
187-
getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames
155+
getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames
188156

189157

190158
"""
191159
getmissings(model::Model)
192160
193161
Get a tuple of the names of the missing arguments of the `model`.
194162
"""
195-
getmissings(model::Model{_F, _a, _T, missings}) where {missings, _F, _a, _T} = missings
163+
getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings
196164

197165
getmissing(model::Model) = getmissings(model)
198166
@deprecate getmissing(model) getmissings(model)
199167

200-
201-
"""
202-
getgenerator(model::Model)
203-
204-
Get the model generator associated with `model`.
205-
"""
206-
getgenerator(model::Model) = model.modelgen
207-
208168
"""
209169
logjoint(model::Model, varinfo::AbstractVarInfo)
210170
@@ -214,7 +174,7 @@ See [`logjoint`](@ref) and [`loglikelihood`](@ref).
214174
"""
215175
function logjoint(model::Model, varinfo::AbstractVarInfo)
216176
model(varinfo, SampleFromPrior(), DefaultContext())
217-
return getlogp(varinfo)
177+
return getlogp(varinfo)
218178
end
219179

220180
"""
@@ -226,7 +186,7 @@ See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
226186
"""
227187
function logprior(model::Model, varinfo::AbstractVarInfo)
228188
model(varinfo, SampleFromPrior(), PriorContext())
229-
return getlogp(varinfo)
189+
return getlogp(varinfo)
230190
end
231191

232192
"""
@@ -238,5 +198,5 @@ See also [`logjoint`](@ref) and [`logprior`](@ref).
238198
"""
239199
function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
240200
model(varinfo, SampleFromPrior(), LikelihoodContext())
241-
return getlogp(varinfo)
201+
return getlogp(varinfo)
242202
end

‎src/prob_macro.jl

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,24 @@ function get_exprs(str::String)
2222
end
2323

2424
function logprob(ex1, ex2)
25-
ptype, modelgen, vi = probtype(ex1, ex2)
25+
ptype, model, vi = probtype(ex1, ex2)
2626
if ptype isa Val{:prior}
27-
return logprior(ex1, ex2, modelgen, vi)
27+
return logprior(ex1, ex2, model, vi)
2828
elseif ptype isa Val{:likelihood}
29-
return loglikelihood(ex1, ex2, modelgen, vi)
29+
return loglikelihood(ex1, ex2, model, vi)
3030
end
3131
end
3232

3333
function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {namesl, namesr}
3434
if :chain in namesr
3535
if isdefined(ntr.chain.info, :model)
3636
model = ntr.chain.info.model
37-
@assert model isa Model
38-
modelgen = getgenerator(model)
3937
elseif isdefined(ntr, :model)
40-
modelgen = ntr.model
38+
model = ntr.model
4139
else
4240
throw("The model is not defined. Please make sure the model is either saved in the chain or passed on the RHS of |.")
4341
end
42+
@assert model isa Model
4443
if isdefined(ntr.chain.info, :vi)
4544
_vi = ntr.chain.info.vi
4645
@assert _vi isa VarInfo
@@ -52,31 +51,34 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names
5251
else
5352
vi = nothing
5453
end
55-
defaults = getdefaults(modelgen)
56-
valid_arg(arg) = isdefined(ntl, arg) || isdefined(ntr, arg) ||
57-
isdefined(defaults, arg) && getfield(defaults, arg) !== missing
58-
@assert all(valid_arg, getargnames(modelgen))
59-
return Val(:likelihood), modelgen, vi
54+
defaults = model.defaults
55+
@assert all(getargnames(model)) do arg
56+
isdefined(ntl, arg) || isdefined(ntr, arg) ||
57+
isdefined(defaults, arg) && getfield(defaults, arg) !== missing
58+
end
59+
return Val(:likelihood), model, vi
6060
else
6161
@assert isdefined(ntr, :model)
62-
modelgen = ntr.model
62+
model = ntr.model
63+
@assert model isa Model
6364
if isdefined(ntr, :varinfo)
6465
_vi = ntr.varinfo
6566
@assert _vi isa VarInfo
6667
vi = TypedVarInfo(_vi)
6768
else
6869
vi = nothing
6970
end
70-
return probtype(ntl, ntr, modelgen), modelgen, vi
71+
return probtype(ntl, ntr, model), model, vi
7172
end
7273
end
74+
7375
function probtype(
7476
left::NamedTuple{leftnames},
7577
right::NamedTuple{rightnames},
76-
modelgen::ModelGen{_G, argnames, defaultnames}
77-
) where {leftnames, rightnames, argnames, defaultnames, _G}
78-
defaults = getdefaults(modelgen)
79-
prior_rhs = all(n -> n in (:model, :varinfo) ||
78+
model::Model{_F,argnames,defaultnames}
79+
) where {leftnames,rightnames,argnames,defaultnames,_F}
80+
defaults = model.defaults
81+
prior_rhs = all(n -> n in (:model, :varinfo) ||
8082
n in argnames && getfield(right, n) !== missing, rightnames)
8183
function get_arg(arg)
8284
if arg in leftnames
@@ -118,7 +120,7 @@ missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has
118120
function logprior(
119121
left::NamedTuple,
120122
right::NamedTuple,
121-
modelgen::ModelGen,
123+
_model::Model,
122124
_vi::Union{Nothing, VarInfo}
123125
)
124126
# For model args on the LHS of |, use their passed value but add the symbol to
@@ -133,7 +135,7 @@ function logprior(
133135
# All `observe` and `dot_observe` calls are no-op in the PriorContext
134136

135137
# When all of model args are on the lhs of |, this is also equal to the logjoint.
136-
model = make_prior_model(left, right, modelgen)
138+
model = make_prior_model(left, right, _model)
137139
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
138140
foreach(keys(vi.metadata)) do n
139141
@assert n in keys(left) "Variable $n is not defined."
@@ -145,31 +147,32 @@ end
145147
@generated function make_prior_model(
146148
left::NamedTuple{leftnames},
147149
right::NamedTuple{rightnames},
148-
modelgen::ModelGen{_G, argnames, defaultnames}
149-
) where {leftnames, rightnames, argnames, defaultnames, _G}
150+
model::Model{_F,argnames,defaultnames}
151+
) where {leftnames,rightnames,argnames,defaultnames,_F}
150152
argvals = []
151153
missings = []
152154
warnings = []
153-
155+
154156
for argname in argnames
155157
if argname in leftnames
156158
push!(argvals, :(deepcopy(left.$argname)))
157159
push!(missings, argname)
158160
elseif argname in rightnames
159161
push!(argvals, :(right.$argname))
160162
elseif argname in defaultnames
161-
push!(argvals, :(getdefaults(modelgen).$argname))
163+
push!(argvals, :(model.defaults.$argname))
162164
else
163165
push!(warnings, :(@warn($(warn_msg(argname)))))
164166
push!(argvals, :(nothing))
165167
end
166168
end
167169

168-
# `args` is inserted as properly typed NamedTuple expression;
170+
# `args` is inserted as properly typed NamedTuple expression;
169171
# `missings` is splatted into a tuple at compile time and inserted as literal
170172
return quote
171173
$(warnings...)
172-
Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals)))
174+
Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)),
175+
model.defaults)
173176
end
174177
end
175178

@@ -178,10 +181,10 @@ warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used."
178181
function Distributions.loglikelihood(
179182
left::NamedTuple,
180183
right::NamedTuple,
181-
modelgen::ModelGen,
184+
_model::Model,
182185
_vi::Union{Nothing, VarInfo},
183186
)
184-
model = make_likelihood_model(left, right, modelgen)
187+
model = make_likelihood_model(left, right, _model)
185188
vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
186189
if isdefined(right, :chain)
187190
# Element-wise likelihood for each value in chain
@@ -205,27 +208,28 @@ end
205208
@generated function make_likelihood_model(
206209
left::NamedTuple{leftnames},
207210
right::NamedTuple{rightnames},
208-
modelgen::ModelGen{_G, argnames, defaultnames}
209-
) where {leftnames, rightnames, argnames, defaultnames, _G}
211+
model::Model{_F,argnames,defaultnames},
212+
) where {leftnames,rightnames,argnames,defaultnames,_F}
210213
argvals = []
211214
missings = []
212-
215+
213216
for argname in argnames
214217
if argname in leftnames
215218
push!(argvals, :(left.$argname))
216219
elseif argname in rightnames
217220
push!(argvals, :(right.$argname))
218221
push!(missings, argname)
219222
elseif argname in defaultnames
220-
push!(argvals, :(getdefaults(modelgen).$argname))
223+
push!(argvals, :(model.defaults.$argname))
221224
else
222225
throw("This point should not be reached. Please open an issue in the DynamicPPL.jl repository.")
223226
end
224227
end
225228

226-
# `args` is inserted as properly typed NamedTuple expression;
229+
# `args` is inserted as properly typed NamedTuple expression;
227230
# `missings` is splatted into a tuple at compile time and inserted as literal
228-
return :(Model{$(Tuple(missings))}(modelgen, $(to_namedtuple_expr(argnames, argvals))))
231+
return :(Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)),
232+
model.defaults))
229233
end
230234

231235
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)

‎src/utils.jl

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
# singleton for indicating if no default arguments are present
2+
struct NoDefault end
3+
const NO_DEFAULT = NoDefault()
4+
5+
# FIXME: This is copied from MacroTools and should be removed when a MacroTools release with
6+
# support for anonymous functions is available (> 0.5.5).
7+
function combinedef_anonymous(dict::Dict)
8+
rtype = get(dict, :rtype, nothing)
9+
params = get(dict, :params, [])
10+
wparams = get(dict, :whereparams, [])
11+
body = MacroTools.block(dict[:body])
12+
13+
if isempty(dict[:kwargs])
14+
arg = :($(dict[:args]...),)
15+
else
16+
arg = Expr(:tuple, Expr(:parameters, dict[:kwargs]...), dict[:args]...)
17+
end
18+
if isempty(wparams)
19+
if rtype==nothing
20+
MacroTools.@q($arg -> $body)
21+
else
22+
MacroTools.@q(($arg::$rtype) -> $body)
23+
end
24+
else
25+
if rtype === nothing
26+
MacroTools.@q(($arg where {$(wparams...)}) -> $body)
27+
else
28+
MacroTools.@q(($arg::$rtype where {$(wparams...)}) -> $body)
29+
end
30+
end
31+
end
32+
133
"""
234
@addlogprob!(ex)
335
@@ -17,18 +49,11 @@ Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` o
1749
"""
1850
getargs_dottilde(x) = nothing
1951
function getargs_dottilde(expr::Expr)
20-
# Check if the expression is of the form `L .~ R`.
21-
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~
22-
return expr.args[2], expr.args[3]
52+
return MacroTools.@match expr begin
53+
(.~)(L_, R_) => (L, R)
54+
(~).(L_, R_) => (L, R)
55+
x_ => nothing
2356
end
24-
25-
# Check if the expression is of the form `(~).(L, R)`.
26-
if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ &&
27-
Meta.isexpr(expr.args[2], :tuple, 2)
28-
return expr.args[2].args[1], expr.args[2].args[2]
29-
end
30-
31-
return
3257
end
3358

3459
"""
@@ -39,10 +64,10 @@ otherwise.
3964
"""
4065
getargs_tilde(x) = nothing
4166
function getargs_tilde(expr::Expr)
42-
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~
43-
return expr.args[2], expr.args[3]
67+
return MacroTools.@match expr begin
68+
(~)(L_, R_) => (L, R)
69+
x_ => nothing
4470
end
45-
return
4671
end
4772

4873
############################################

‎test/prob_macro.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,27 @@ Random.seed!(129)
2323
loglike = logpdf(Normal(mval, 1), xval)
2424
logjoint = logprior + loglike
2525

26-
@test logprob"m = mval | model = demo" == logprior
27-
@test logprob"m = mval | x = xval, model = demo" == logprior
28-
@test logprob"x = xval | m = mval, model = demo" == loglike
29-
@test logprob"x = xval, m = mval | model = demo" == logjoint
26+
model = demo(xval)
27+
@test logprob"m = mval | model = model" == logprior
28+
@test logprob"m = mval | x = xval, model = model" == logprior
29+
@test logprob"x = xval | m = mval, model = model" == loglike
30+
@test logprob"x = xval, m = mval | model = model" == logjoint
3031

3132
varinfo = VarInfo(demo(xval))
32-
@test logprob"m = mval | model = demo, varinfo = varinfo" == logprior
33-
@test logprob"m = mval | x = xval, model = demo, varinfo = varinfo" == logprior
34-
@test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike
33+
@test logprob"m = mval | model = model, varinfo = varinfo" == logprior
34+
@test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior
35+
@test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike
3536
varinfo = VarInfo(demo(missing))
36-
@test logprob"x = xval, m = mval | model = demo, varinfo = varinfo" == logjoint
37+
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint
3738

3839
chain = sample(demo(xval), IS(), iters; save_state = true)
3940
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
4041
lps = logpdf.(Normal.(vec(chain["m"]), 1), xval)
4142
@test logprob"x = xval | chain = chain" == lps
42-
@test logprob"x = xval | chain = chain2, model = demo" == lps
43+
@test logprob"x = xval | chain = chain2, model = model" == lps
4344
varinfo = VarInfo(demo(xval))
4445
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
45-
@test logprob"x = xval | chain = chain2, model = demo, varinfo = varinfo" == lps
46+
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
4647
end
4748

4849
@testset "vector" begin
@@ -61,13 +62,14 @@ Random.seed!(129)
6162
loglike = like(mval, xval)
6263
logjoint = logprior + loglike
6364

64-
@test logprob"m = mval | model = demo" == logprior
65-
@test logprob"x = xval | m = mval, model = demo" == loglike
66-
@test logprob"x = xval, m = mval | model = demo" == logjoint
65+
model = demo(xval)
66+
@test logprob"m = mval | model = model" == logprior
67+
@test logprob"x = xval | m = mval, model = model" == loglike
68+
@test logprob"x = xval, m = mval | model = model" == logjoint
6769

6870
varinfo = VarInfo(demo(xval))
69-
@test logprob"m = mval | model = demo, varinfo = varinfo" == logprior
70-
@test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike
71+
@test logprob"m = mval | model = model, varinfo = varinfo" == logprior
72+
@test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike
7173
# Currently, we cannot easily pre-allocate `VarInfo` for vector data
7274

7375
chain = sample(demo(xval), HMC(0.5, 1), iters; save_state = true)
@@ -78,8 +80,8 @@ Random.seed!(129)
7880
like([chain[iter, name, 1] for name in names], xval)
7981
end
8082
@test logprob"x = xval | chain = chain" == lps
81-
@test logprob"x = xval | chain = chain2, model = demo" == lps
83+
@test logprob"x = xval | chain = chain2, model = model" == lps
8284
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
83-
@test logprob"x = xval | chain = chain2, model = demo, varinfo = varinfo" == lps
85+
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
8486
end
8587
end

‎test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ using ForwardDiff
44
using Tracker
55
using Zygote
66

7+
using Distributed
78
using Random
9+
using Serialization
810
using Test
911

1012
dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1]
@@ -29,6 +31,8 @@ include("test_util.jl")
2931

3032
include("threadsafe.jl")
3133

34+
include("serialization.jl")
35+
3236
@testset "compat" begin
3337
include(joinpath("compat", "ad.jl"))
3438
end

‎test/serialization.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
@testset "serialization.jl" begin
2+
Random.seed!(1234)
3+
4+
@testset "saving and loading" begin
5+
# Save model.
6+
open("gdemo_default.jls", "w") do io
7+
serialize(io, gdemo_default)
8+
end
9+
10+
# Sample from deserialized model.
11+
gdemo_default_copy = open(deserialize, "gdemo_default.jls", "r")
12+
samples = [gdemo_default_copy() for _ in 1:1_000]
13+
samples_s = first.(samples)
14+
samples_m = last.(samples)
15+
16+
@test mean(samples_s) 3 atol=0.1
17+
@test mean(samples_m) 0 atol=0.1
18+
end
19+
20+
@testset "pmap" begin
21+
# Add worker processes.
22+
addprocs()
23+
@info "serialization test: using $(nworkers()) processes"
24+
25+
# Load packages on all processes.
26+
@everywhere begin
27+
using DynamicPPL
28+
using Distributions
29+
end
30+
31+
# Define model on all proceses.
32+
@everywhere @model function model()
33+
m ~ Normal(0, 1)
34+
end
35+
36+
# Generate `Model` objects on all processes.
37+
models = pmap(_ -> model(), 1:100)
38+
@test models isa Vector{<:Model}
39+
@test length(models) == 100
40+
41+
# Sample from model on all processes.
42+
n = 1_000
43+
samples1 = pmap(_ -> model()(), 1:n)
44+
m = model()
45+
samples2 = pmap(_ -> m(), 1:n)
46+
47+
for samples in (samples1, samples2)
48+
@test samples isa Vector{Float64}
49+
@test length(samples) == n
50+
@test mean(samples) 0 atol=0.1
51+
@test std(samples) 1 atol=0.1
52+
end
53+
end
54+
end

0 commit comments

Comments
 (0)
Please sign in to comment.