Skip to content

Commit 67b5873

Browse files
committed
Use ExprTools.combinedef for building a correctly typed evaluator
1 parent c6a9f51 commit 67b5873

File tree

4 files changed

+80
-36
lines changed

4 files changed

+80
-36
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.8.1"
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10+
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
1011
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
@@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516
AbstractMCMC = "1"
1617
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
1718
Distributions = "0.22, 0.23"
19+
ExprTools = "0.1.1"
1820
MacroTools = "0.5.1"
1921
ZygoteRules = "0.2"
2022
julia = "1"

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Bijectors
66
using MacroTools
77

88
import AbstractMCMC
9+
import ExprTools
910
import ZygoteRules
1011

1112
import Random

src/compiler.jl

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,27 @@ end
8080
Builds the `model_info` dictionary from the model's expression.
8181
"""
8282
function build_model_info(input_expr)
83-
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
84-
modeldef = MacroTools.splitdef(input_expr)
85-
# Function body of the model is empty
83+
# Break up the model definition and extract its name, arguments, and function body
84+
modeldef = ExprTools.splitdef(input_expr)
85+
86+
# Print a warning if function body of the model is empty
8687
warn_empty(modeldef[:body])
87-
# Construct model_info dictionary
88+
89+
## Construct model_info dictionary
90+
91+
# Shortcut if the model does not have any arguments
92+
if !haskey(modeldef, :args)
93+
modelinfo = Dict(
94+
:name => modeldef[:name],
95+
:main_body => modeldef[:body],
96+
:arg_syms => [],
97+
:args_nt => NamedTuple(),
98+
:defaults_nt => NamedTuple(),
99+
:args => [],
100+
:modeldef => modeldef,
101+
)
102+
return modelinfo
103+
end
88104

89105
# Extracting the argument symbols from the model definition
90106
arg_syms = map(modeldef[:args]) do arg
@@ -158,7 +174,7 @@ function build_model_info(input_expr)
158174
:args_nt => args_nt,
159175
:defaults_nt => defaults_nt,
160176
:args => args,
161-
:whereparams => modeldef[:whereparams]
177+
:modeldef => modeldef,
162178
)
163179

164180
return model_info
@@ -318,45 +334,60 @@ hasmissing(T::Type) = false
318334
Builds the output expression.
319335
"""
320336
function build_output(model_info)
321-
# Arguments with default values
337+
## Build the anonymous evaluator from the user-provided model definition
338+
339+
# Remove the name and use `function (....)` syntax
340+
modeldef = model_info[:modeldef]
341+
delete!(modeldef, :name)
342+
modeldef[:head] = :function
343+
344+
# Define the input arguments (positional + keyword arguments), without default values
345+
origargs = map(vcat(get(modeldef, :args, Any[]), get(modeldef, :kwargs, Any[]))) do arg
346+
Meta.isexpr(arg, :kw) && length(arg.args) >= 1 ? arg.args[1] : arg
347+
end
348+
349+
# Add our own arguments
350+
newargs = Any[:(_rng::$(Random.AbstractRNG)),
351+
:(_model::$(DynamicPPL.Model)),
352+
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
353+
:(_sampler::$(DynamicPPL.AbstractSampler)),
354+
:(_context::$(DynamicPPL.AbstractContext))]
355+
combinedargs = vcat(newargs, origargs)
356+
357+
# Delete keyword arguments and update positional arguments
358+
delete!(modeldef, :kwargs)
359+
modeldef[:args] = combinedargs
360+
361+
# Replace function body
362+
modeldef[:body] = model_info[:main_body]
363+
364+
## Extract other relevant information
365+
366+
# All arguments with default values (if existent)
322367
args = model_info[:args]
323-
# Argument symbols without default values
324-
arg_syms = model_info[:arg_syms]
325-
# Arguments namedtuple
368+
# Named tuple of all arguments
326369
args_nt = model_info[:args_nt]
327-
# Default values of the arguments
328-
# Arguments namedtuple
370+
371+
# Named tuple of the default values of the arguments
329372
defaults_nt = model_info[:defaults_nt]
330-
# Where parameters
331-
whereparams = model_info[:whereparams]
332-
# Model generator name
373+
374+
# Model name
333375
model = model_info[:name]
334-
# Main body of the model
335-
main_body = model_info[:main_body]
336376

337-
unwrap_data_expr = Expr(:block)
338-
for var in arg_syms
339-
push!(unwrap_data_expr.args,
340-
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
377+
# Define model definition with only keyword arguments
378+
if isempty(args)
379+
model_kwform = ()
380+
else
381+
# All arguments without default values (i.e., only symbols)
382+
arg_syms = model_info[:arg_syms]
383+
384+
model_kwform = (:($model(; $(args...)) = $model($(arg_syms...))),)
341385
end
342386

343-
model_kwform = isempty(args) ? () : (:($model(;$(args...)) = $model($(arg_syms...))),)
344387
@gensym(evaluator)
345-
346388
return quote
347389
$(Base).@__doc__ function $model($(args...))
348-
$evaluator = let
349-
function (
350-
_rng::$(Random.AbstractRNG),
351-
_model::$(DynamicPPL.Model),
352-
_varinfo::$(DynamicPPL.AbstractVarInfo),
353-
_sampler::$(DynamicPPL.AbstractSampler),
354-
_context::$(DynamicPPL.AbstractContext),
355-
)
356-
$unwrap_data_expr
357-
$main_body
358-
end
359-
end
390+
$evaluator = $(ExprTools.combinedef(modeldef))
360391
return $(DynamicPPL.Model)($evaluator, $args_nt, $defaults_nt)
361392
end
362393
$(model_kwform...)

src/model.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
112112
if has_eval_num(sampler)
113113
sampler.state.eval_num += 1
114114
end
115-
return model.f(rng, model, varinfo, sampler, context)
115+
return _evaluate(rng, model, varinfo, sampler, context)
116116
end
117117

118118
"""
@@ -132,11 +132,21 @@ function evaluate_threadsafe(rng, model, varinfo, sampler, context)
132132
sampler.state.eval_num += 1
133133
end
134134
wrapper = ThreadSafeVarInfo(varinfo)
135-
result = model.f(rng, model, wrapper, sampler, context)
135+
result = _evaluate(rng, model, wrapper, sampler, context)
136136
setlogp!(varinfo, getlogp(wrapper))
137137
return result
138138
end
139139

140+
"""
141+
_evaluate(rng, model::Model, varinfo, sampler, context)
142+
143+
Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object.
144+
"""
145+
@generated function _evaluate(rng, model::Model{_F,argnames}, varinfo, sampler, context) where {_F,argnames}
146+
unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames]
147+
return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...)))
148+
end
149+
140150
"""
141151
getargnames(model::Model)
142152

0 commit comments

Comments
 (0)