|
80 | 80 | Builds the `model_info` dictionary from the model's expression.
|
81 | 81 | """
|
82 | 82 | function build_model_info(input_expr)
|
83 |
| - # Extract model name (:name), arguments (:args), (:kwargs) and definition (:body) |
84 |
| - modeldef = MacroTools.splitdef(input_expr) |
85 |
| - # Function body of the model is empty |
| 83 | + # Break up the model definition and extract its name, arguments, and function body |
| 84 | + modeldef = ExprTools.splitdef(input_expr) |
| 85 | + |
| 86 | + # Print a warning if function body of the model is empty |
86 | 87 | warn_empty(modeldef[:body])
|
87 |
| - # Construct model_info dictionary |
| 88 | + |
| 89 | + ## Construct model_info dictionary |
| 90 | + |
| 91 | + # Shortcut if the model does not have any arguments |
| 92 | + if !haskey(modeldef, :args) |
| 93 | + modelinfo = Dict( |
| 94 | + :name => modeldef[:name], |
| 95 | + :main_body => modeldef[:body], |
| 96 | + :arg_syms => [], |
| 97 | + :args_nt => NamedTuple(), |
| 98 | + :defaults_nt => NamedTuple(), |
| 99 | + :args => [], |
| 100 | + :modeldef => modeldef, |
| 101 | + ) |
| 102 | + return modelinfo |
| 103 | + end |
88 | 104 |
|
89 | 105 | # Extracting the argument symbols from the model definition
|
90 | 106 | arg_syms = map(modeldef[:args]) do arg
|
@@ -158,7 +174,7 @@ function build_model_info(input_expr)
|
158 | 174 | :args_nt => args_nt,
|
159 | 175 | :defaults_nt => defaults_nt,
|
160 | 176 | :args => args,
|
161 |
| - :whereparams => modeldef[:whereparams] |
| 177 | + :modeldef => modeldef, |
162 | 178 | )
|
163 | 179 |
|
164 | 180 | return model_info
|
@@ -318,45 +334,60 @@ hasmissing(T::Type) = false
|
318 | 334 | Builds the output expression.
|
319 | 335 | """
|
320 | 336 | function build_output(model_info)
|
321 |
| - # Arguments with default values |
| 337 | + ## Build the anonymous evaluator from the user-provided model definition |
| 338 | + |
| 339 | + # Remove the name and use `function (....)` syntax |
| 340 | + modeldef = model_info[:modeldef] |
| 341 | + delete!(modeldef, :name) |
| 342 | + modeldef[:head] = :function |
| 343 | + |
| 344 | + # Define the input arguments (positional + keyword arguments), without default values |
| 345 | + origargs = map(vcat(get(modeldef, :args, Any[]), get(modeldef, :kwargs, Any[]))) do arg |
| 346 | + Meta.isexpr(arg, :kw) && length(arg.args) >= 1 ? arg.args[1] : arg |
| 347 | + end |
| 348 | + |
| 349 | + # Add our own arguments |
| 350 | + newargs = Any[:(_rng::$(Random.AbstractRNG)), |
| 351 | + :(_model::$(DynamicPPL.Model)), |
| 352 | + :(_varinfo::$(DynamicPPL.AbstractVarInfo)), |
| 353 | + :(_sampler::$(DynamicPPL.AbstractSampler)), |
| 354 | + :(_context::$(DynamicPPL.AbstractContext))] |
| 355 | + combinedargs = vcat(newargs, origargs) |
| 356 | + |
| 357 | + # Delete keyword arguments and update positional arguments |
| 358 | + delete!(modeldef, :kwargs) |
| 359 | + modeldef[:args] = combinedargs |
| 360 | + |
| 361 | + # Replace function body |
| 362 | + modeldef[:body] = model_info[:main_body] |
| 363 | + |
| 364 | + ## Extract other relevant information |
| 365 | + |
| 366 | + # All arguments with default values (if existent) |
322 | 367 | args = model_info[:args]
|
323 |
| - # Argument symbols without default values |
324 |
| - arg_syms = model_info[:arg_syms] |
325 |
| - # Arguments namedtuple |
| 368 | + # Named tuple of all arguments |
326 | 369 | args_nt = model_info[:args_nt]
|
327 |
| - # Default values of the arguments |
328 |
| - # Arguments namedtuple |
| 370 | + |
| 371 | + # Named tuple of the default values of the arguments |
329 | 372 | defaults_nt = model_info[:defaults_nt]
|
330 |
| - # Where parameters |
331 |
| - whereparams = model_info[:whereparams] |
332 |
| - # Model generator name |
| 373 | + |
| 374 | + # Model name |
333 | 375 | model = model_info[:name]
|
334 |
| - # Main body of the model |
335 |
| - main_body = model_info[:main_body] |
336 | 376 |
|
337 |
| - unwrap_data_expr = Expr(:block) |
338 |
| - for var in arg_syms |
339 |
| - push!(unwrap_data_expr.args, |
340 |
| - :($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var))) |
| 377 | + # Define model definition with only keyword arguments |
| 378 | + if isempty(args) |
| 379 | + model_kwform = () |
| 380 | + else |
| 381 | + # All arguments without default values (i.e., only symbols) |
| 382 | + arg_syms = model_info[:arg_syms] |
| 383 | + |
| 384 | + model_kwform = (:($model(; $(args...)) = $model($(arg_syms...))),) |
341 | 385 | end
|
342 | 386 |
|
343 |
| - model_kwform = isempty(args) ? () : (:($model(;$(args...)) = $model($(arg_syms...))),) |
344 | 387 | @gensym(evaluator)
|
345 |
| - |
346 | 388 | return quote
|
347 | 389 | $(Base).@__doc__ function $model($(args...))
|
348 |
| - $evaluator = let |
349 |
| - function ( |
350 |
| - _rng::$(Random.AbstractRNG), |
351 |
| - _model::$(DynamicPPL.Model), |
352 |
| - _varinfo::$(DynamicPPL.AbstractVarInfo), |
353 |
| - _sampler::$(DynamicPPL.AbstractSampler), |
354 |
| - _context::$(DynamicPPL.AbstractContext), |
355 |
| - ) |
356 |
| - $unwrap_data_expr |
357 |
| - $main_body |
358 |
| - end |
359 |
| - end |
| 390 | + $evaluator = $(ExprTools.combinedef(modeldef)) |
360 | 391 | return $(DynamicPPL.Model)($evaluator, $args_nt, $defaults_nt)
|
361 | 392 | end
|
362 | 393 | $(model_kwform...)
|
|
0 commit comments