Skip to content

Commit bde1f74

Browse files
committed
Remove ModelGen (#134)
This PR removes `ModelGen` completely. The main motivation for it were the issues that @itsdfish and I experienced when working with multiple processes. The following MWE ```julia using Distributed addprocs(4) @Everywhere using DynamicPPL @Everywhere @model model() = begin end pmap(x -> model(), 1:4) ``` fails intermittently > if not all of these [`@model`] evaluations generate same evaluator and generator functions of the same name (i.e., these var"###evaluator#253" and var"###generator#254" functions). I assume one could maybe provoke the error by defining another model first on the main process before calling @Everywhere @model .... (copied from the discussion on Slack) With the changes in this PR, `@model model() = ...` generates only a single function `model` on all workers, and hence there are no issues anymore with the names of the generators and evaluators. The evaluator is created as a closure inside of `model`, which can be serialized and deserialized properly by Julia. So far I haven't been able to reproduce the issues above with this PR. The only user-facing change of the removal of `ModelGen` (apart from that one never has to construct it, which simplifies the docs and the example @denainjs asked about) is that the `logprob` macro now requires to specify `model = m` where `m = model()` instead of `model = model` (since that's just a regular function from which the default arguments etc of the resulting model can't be extracted). It feels slightly weird that the evaluation is not based "exactly" on the specified `Model` instance but that the other parts of `logprob` modify it (which was the reason I guess for using the model generator before here), but on the other hand this weird behaviour already exists when specifying `logprob` with a `Chains` object. (BTW I'm not sure if we should actually use string macros here, maybe regular functions would be nicer.) Additionally, I assume (though haven't tested it) that getting rid of the separate evaluator and generator functions will not only simplify serialization and deserialization when working with multiple processes but also when saving models and chains (see e.g. TuringLang/Turing.jl#1091). Co-authored-by: David Widmann <[email protected]>
1 parent fe04045 commit bde1f74

10 files changed

+311
-277
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
.DS_Store
55
/Manifest.toml
66
/dev/
7+
/test/gdemo_default.jls

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ julia = "1"
2323
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
2424
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
2525
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
26+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2627
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
2728
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2829
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
@@ -38,6 +39,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
3839
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3940
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4041
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
42+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
4143
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4244
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4345
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4850
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4951

5052
[targets]
51-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
53+
test = ["AdvancedHMC", "AdvancedMH", "Distributed", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "Serialization", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]

src/DynamicPPL.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ module DynamicPPL
33
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
6-
using MacroTools
76

87
import AbstractMCMC
8+
import MacroTools
99
import ZygoteRules
1010

1111
import Random
@@ -51,25 +51,20 @@ export AbstractVarInfo,
5151
inspace,
5252
subsumes,
5353
# Compiler
54-
ModelGen,
5554
@model,
5655
@varname,
5756
# Utilities
5857
vectorize,
5958
reconstruct,
6059
reconstruct!,
6160
Sample,
62-
Chain,
6361
init,
6462
vectorize,
6563
set_resume!,
6664
# Model
67-
ModelGen,
6865
Model,
6966
getmissings,
7067
getargnames,
71-
getdefaults,
72-
getgenerator,
7368
# Samplers
7469
Sampler,
7570
SampleFromPrior,

src/compiler.jl

Lines changed: 98 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ Macro to specify a probabilistic model.
4949
If `warn` is `true`, a warning is displayed if internal variable names are used in the model
5050
definition.
5151
52-
# Example
52+
# Examples
5353
5454
Model definition:
5555
5656
```julia
57-
@model function model_generator(x = default_x, y)
57+
@model function model(x, y = 42)
5858
...
5959
end
6060
```
6161
62-
To generate a `Model`, call `model_generator(x_value)`.
62+
To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
6363
"""
6464
macro model(expr, warn=true)
6565
esc(model(expr, warn))
@@ -69,7 +69,9 @@ function model(expr, warn)
6969
modelinfo = build_model_info(expr)
7070

7171
# Generate main body
72-
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn)
72+
modelinfo[:body] = generate_mainbody(
73+
modelinfo[:modeldef][:body], modelinfo[:allargs_exprs], warn
74+
)
7375

7476
return build_output(modelinfo)
7577
end
@@ -80,87 +82,74 @@ end
8082
Builds the `model_info` dictionary from the model's expression.
8183
"""
8284
function build_model_info(input_expr)
83-
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
85+
# Break up the model definition and extract its name, arguments, and function body
8486
modeldef = MacroTools.splitdef(input_expr)
85-
# Function body of the model is empty
87+
88+
# Print a warning if function body of the model is empty
8689
warn_empty(modeldef[:body])
87-
# Construct model_info dictionary
88-
89-
# Extracting the argument symbols from the model definition
90-
combinedargs = vcat(modeldef[:args], modeldef[:kwargs])
91-
arg_syms = map(combinedargs) do arg
92-
# @model demo(x)
93-
if (arg isa Symbol)
94-
arg
95-
# @model demo(::Type{T}) where {T}
96-
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
97-
T
98-
# @model demo(x::T = 1)
99-
elseif MacroTools.@capture(arg, x_::T_ = val_)
100-
x
101-
# @model demo(x = 1)
102-
elseif MacroTools.@capture(arg, x_ = val_)
103-
x
104-
else
105-
throw(ArgumentError("Unsupported argument $arg to the `@model` macro."))
106-
end
107-
end
108-
if length(arg_syms) == 0
109-
args_nt = :(NamedTuple())
110-
else
111-
nt_type = Expr(:curly, :NamedTuple,
112-
Expr(:tuple, QuoteNode.(arg_syms)...),
113-
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
90+
91+
## Construct model_info dictionary
92+
93+
# Shortcut if the model does not have any arguments
94+
if !haskey(modeldef, :args) && !haskey(modeldef, :kwargs)
95+
modelinfo = Dict(
96+
:allargs_exprs => [],
97+
:allargs_syms => [],
98+
:allargs_namedtuple => NamedTuple(),
99+
:defaults_namedtuple => NamedTuple(),
100+
:modeldef => modeldef,
114101
)
115-
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
102+
return modelinfo
116103
end
117-
args = map(combinedargs) do arg
118-
if (arg isa Symbol)
119-
arg
120-
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
121-
if in(T, modeldef[:whereparams])
122-
S = :Any
123-
else
124-
ind = findfirst(modeldef[:whereparams]) do x
125-
MacroTools.@capture(x, T1_ <: S_) && T1 == T
126-
end
127-
ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause"))
128-
end
129-
Expr(:kw, :($T::Type{<:$S}), Tval)
130-
else
131-
arg
104+
105+
# Extract the positional and keyword arguments from the model definition.
106+
allargs = vcat(modeldef[:args], modeldef[:kwargs])
107+
108+
# Split the argument expressions and the default values.
109+
allargs_exprs_defaults = map(allargs) do arg
110+
MacroTools.@match arg begin
111+
(x_ = val_) => (x, val)
112+
x_ => (x, NO_DEFAULT)
113+
end
114+
end
115+
116+
# Extract the expressions of the arguments, without default values.
117+
allargs_exprs = first.(allargs_exprs_defaults)
118+
119+
# Extract the names of the arguments.
120+
allargs_syms = map(allargs_exprs_defaults) do (arg, _)
121+
MacroTools.@match arg begin
122+
(::Type{T_}) | (name_::Type{T_}) => T
123+
name_::T_ => name
124+
x_ => x
132125
end
133126
end
134-
args_nt = to_namedtuple_expr(arg_syms)
135127

128+
# Build named tuple expression of the argument symbols and variables of the same name.
129+
allargs_namedtuple = to_namedtuple_expr(allargs_syms)
130+
131+
# Extract default values of the positional and keyword arguments.
136132
default_syms = []
137-
default_vals = []
138-
foreach(combinedargs) do arg
139-
# @model demo(::Type{T}) where {T}
140-
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
141-
push!(default_syms, T)
142-
push!(default_vals, Tval)
143-
# @model demo(x::T = 1)
144-
elseif MacroTools.@capture(arg, x_::T_ = val_)
145-
push!(default_syms, x)
146-
push!(default_vals, val)
147-
# @model demo(x = 1)
148-
elseif MacroTools.@capture(arg, x_ = val_)
149-
push!(default_syms, x)
133+
default_vals = []
134+
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
135+
if val !== NO_DEFAULT
136+
push!(default_syms, sym)
150137
push!(default_vals, val)
151138
end
152139
end
153-
defaults_nt = to_namedtuple_expr(default_syms, default_vals)
154140

155-
modelderiv = Dict(
156-
:modelargs => args,
157-
:modelargsyms => arg_syms,
158-
:modelargsnt => args_nt,
159-
:modeldefaultsnt => defaults_nt,
141+
# Build named tuple expression of the argument symbols with default values.
142+
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)
143+
144+
modelinfo = Dict(
145+
:allargs_exprs => allargs_exprs,
146+
:allargs_syms => allargs_syms,
147+
:allargs_namedtuple => allargs_namedtuple,
148+
:defaults_namedtuple => defaults_namedtuple,
149+
:modeldef => modeldef,
160150
)
161-
model_info = merge(modeldef, modelderiv)
162151

163-
return model_info
152+
return modelinfo
164153
end
165154

166155
"""
@@ -312,55 +301,53 @@ hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
312301
hasmissing(T::Type) = false
313302

314303
"""
315-
build_output(model_info)
304+
build_output(modelinfo)
316305
317306
Builds the output expression.
318307
"""
319-
function build_output(model_info)
320-
# Arguments with default values
321-
args = model_info[:modelargs]
322-
# Argument symbols without default values
323-
arg_syms = model_info[:modelargsyms]
324-
# Arguments namedtuple
325-
args_nt = model_info[:modelargsnt]
326-
# Default values of the arguments
327-
# Arguments namedtuple
328-
defaults_nt = model_info[:modeldefaultsnt]
329-
# Model generator name
330-
model_gen = model_info[:name]
331-
# Main body of the model
332-
main_body = model_info[:modelbody]
333-
334-
unwrap_data_expr = Expr(:block)
335-
for var in arg_syms
336-
push!(unwrap_data_expr.args,
337-
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
338-
end
308+
function build_output(modelinfo)
309+
## Build the anonymous evaluator from the user-provided model definition.
310+
311+
# Remove the name.
312+
evaluatordef = deepcopy(modelinfo[:modeldef])
313+
delete!(evaluatordef, :name)
314+
315+
# Add the internal arguments to the user-specified arguments (positional + keywords).
316+
evaluatordef[:args] = vcat(
317+
[
318+
:(_rng::$(Random.AbstractRNG)),
319+
:(_model::$(DynamicPPL.Model)),
320+
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
321+
:(_sampler::$(DynamicPPL.AbstractSampler)),
322+
:(_context::$(DynamicPPL.AbstractContext)),
323+
],
324+
modelinfo[:allargs_exprs],
325+
)
339326

340-
@gensym(evaluator, generator)
341-
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
327+
# Delete the keyword arguments.
328+
evaluatordef[:kwargs] = []
342329

343-
# construct the user-facing model generator
344-
model_info[:name] = generator
345-
model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor))
346-
generator_expr = MacroTools.combinedef(model_info)
330+
# Replace the user-provided function body with the version created by DynamicPPL.
331+
evaluatordef[:body] = modelinfo[:body]
347332

348-
return quote
349-
function $evaluator(
350-
_rng::$(Random.AbstractRNG),
351-
_model::$(DynamicPPL.Model),
352-
_varinfo::$(DynamicPPL.AbstractVarInfo),
353-
_sampler::$(DynamicPPL.AbstractSampler),
354-
_context::$(DynamicPPL.AbstractContext),
355-
)
356-
$unwrap_data_expr
357-
$main_body
358-
end
333+
## Build the model function.
359334

360-
$(generator_expr)
335+
# Extract the named tuple expression of all arguments and the default values.
336+
allargs_namedtuple = modelinfo[:allargs_namedtuple]
337+
defaults_namedtuple = modelinfo[:defaults_namedtuple]
361338

362-
$(Base).@__doc__ $model_gen = $model_gen_constructor
339+
# Update the function body of the user-specified model.
340+
# We use a name for the anonymous evaluator that does not conflict with other variables.
341+
modeldef = modelinfo[:modeldef]
342+
@gensym evaluator
343+
modeldef[:body] = quote
344+
$evaluator = $(combinedef_anonymous(evaluatordef))
345+
return $(DynamicPPL.Model)(
346+
$evaluator, $allargs_namedtuple, $defaults_namedtuple
347+
)
363348
end
349+
350+
return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
364351
end
365352

366353

0 commit comments

Comments
 (0)