Skip to content

[Merged by Bors] - Remove ModelGen #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
.DS_Store
/Manifest.toml
/dev/
/test/gdemo_default.jls
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ julia = "1"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
Expand All @@ -38,6 +39,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
test = ["AdvancedHMC", "AdvancedMH", "Distributed", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "Serialization", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
7 changes: 1 addition & 6 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions
using Bijectors
using MacroTools

import AbstractMCMC
import MacroTools
import ZygoteRules

import Random
Expand Down Expand Up @@ -51,25 +51,20 @@ export AbstractVarInfo,
inspace,
subsumes,
# Compiler
ModelGen,
@model,
@varname,
# Utilities
vectorize,
reconstruct,
reconstruct!,
Sample,
Chain,
init,
vectorize,
set_resume!,
# Model
ModelGen,
Model,
getmissings,
getargnames,
getdefaults,
getgenerator,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
209 changes: 98 additions & 111 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ Macro to specify a probabilistic model.
If `warn` is `true`, a warning is displayed if internal variable names are used in the model
definition.

# Example
# Examples

Model definition:

```julia
@model function model_generator(x = default_x, y)
@model function model(x, y = 42)
...
end
```

To generate a `Model`, call `model_generator(x_value)`.
To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
"""
macro model(expr, warn=true)
esc(model(expr, warn))
Expand All @@ -69,7 +69,9 @@ function model(expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn)
modelinfo[:body] = generate_mainbody(
modelinfo[:modeldef][:body], modelinfo[:allargs_exprs], warn
)

return build_output(modelinfo)
end
Expand All @@ -80,87 +82,74 @@ end
Builds the `model_info` dictionary from the model's expression.
"""
function build_model_info(input_expr)
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
# Break up the model definition and extract its name, arguments, and function body
modeldef = MacroTools.splitdef(input_expr)
# Function body of the model is empty

# Print a warning if function body of the model is empty
warn_empty(modeldef[:body])
# Construct model_info dictionary

# Extracting the argument symbols from the model definition
combinedargs = vcat(modeldef[:args], modeldef[:kwargs])
arg_syms = map(combinedargs) do arg
# @model demo(x)
if (arg isa Symbol)
arg
# @model demo(::Type{T}) where {T}
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
T
# @model demo(x::T = 1)
elseif MacroTools.@capture(arg, x_::T_ = val_)
x
# @model demo(x = 1)
elseif MacroTools.@capture(arg, x_ = val_)
x
else
throw(ArgumentError("Unsupported argument $arg to the `@model` macro."))
end
end
if length(arg_syms) == 0
args_nt = :(NamedTuple())
else
nt_type = Expr(:curly, :NamedTuple,
Expr(:tuple, QuoteNode.(arg_syms)...),
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)

## Construct model_info dictionary

# Shortcut if the model does not have any arguments
if !haskey(modeldef, :args) && !haskey(modeldef, :kwargs)
modelinfo = Dict(
:allargs_exprs => [],
:allargs_syms => [],
:allargs_namedtuple => NamedTuple(),
:defaults_namedtuple => NamedTuple(),
:modeldef => modeldef,
)
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
return modelinfo
end
args = map(combinedargs) do arg
if (arg isa Symbol)
arg
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
if in(T, modeldef[:whereparams])
S = :Any
else
ind = findfirst(modeldef[:whereparams]) do x
MacroTools.@capture(x, T1_ <: S_) && T1 == T
end
ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause"))
end
Expr(:kw, :($T::Type{<:$S}), Tval)
else
arg

# Extract the positional and keyword arguments from the model definition.
allargs = vcat(modeldef[:args], modeldef[:kwargs])

# Split the argument expressions and the default values.
allargs_exprs_defaults = map(allargs) do arg
MacroTools.@match arg begin
(x_ = val_) => (x, val)
x_ => (x, NO_DEFAULT)
end
end

# Extract the expressions of the arguments, without default values.
allargs_exprs = first.(allargs_exprs_defaults)

# Extract the names of the arguments.
allargs_syms = map(allargs_exprs_defaults) do (arg, _)
MacroTools.@match arg begin
(::Type{T_}) | (name_::Type{T_}) => T
name_::T_ => name
x_ => x
end
end
args_nt = to_namedtuple_expr(arg_syms)

# Build named tuple expression of the argument symbols and variables of the same name.
allargs_namedtuple = to_namedtuple_expr(allargs_syms)

# Extract default values of the positional and keyword arguments.
default_syms = []
default_vals = []
foreach(combinedargs) do arg
# @model demo(::Type{T}) where {T}
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
push!(default_syms, T)
push!(default_vals, Tval)
# @model demo(x::T = 1)
elseif MacroTools.@capture(arg, x_::T_ = val_)
push!(default_syms, x)
push!(default_vals, val)
# @model demo(x = 1)
elseif MacroTools.@capture(arg, x_ = val_)
push!(default_syms, x)
default_vals = []
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
if val !== NO_DEFAULT
push!(default_syms, sym)
push!(default_vals, val)
end
end
defaults_nt = to_namedtuple_expr(default_syms, default_vals)

modelderiv = Dict(
:modelargs => args,
:modelargsyms => arg_syms,
:modelargsnt => args_nt,
:modeldefaultsnt => defaults_nt,
# Build named tuple expression of the argument symbols with default values.
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)

modelinfo = Dict(
:allargs_exprs => allargs_exprs,
:allargs_syms => allargs_syms,
:allargs_namedtuple => allargs_namedtuple,
:defaults_namedtuple => defaults_namedtuple,
:modeldef => modeldef,
)
model_info = merge(modeldef, modelderiv)

return model_info
return modelinfo
end

"""
Expand Down Expand Up @@ -312,55 +301,53 @@ hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
hasmissing(T::Type) = false

"""
build_output(model_info)
build_output(modelinfo)

Builds the output expression.
"""
function build_output(model_info)
# Arguments with default values
args = model_info[:modelargs]
# Argument symbols without default values
arg_syms = model_info[:modelargsyms]
# Arguments namedtuple
args_nt = model_info[:modelargsnt]
# Default values of the arguments
# Arguments namedtuple
defaults_nt = model_info[:modeldefaultsnt]
# Model generator name
model_gen = model_info[:name]
# Main body of the model
main_body = model_info[:modelbody]

unwrap_data_expr = Expr(:block)
for var in arg_syms
push!(unwrap_data_expr.args,
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end
function build_output(modelinfo)
## Build the anonymous evaluator from the user-provided model definition.

# Remove the name.
evaluatordef = deepcopy(modelinfo[:modeldef])
delete!(evaluatordef, :name)

# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(_rng::$(Random.AbstractRNG)),
:(_model::$(DynamicPPL.Model)),
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
:(_sampler::$(DynamicPPL.AbstractSampler)),
:(_context::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)

@gensym(evaluator, generator)
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
# Delete the keyword arguments.
evaluatordef[:kwargs] = []

# construct the user-facing model generator
model_info[:name] = generator
model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor))
generator_expr = MacroTools.combinedef(model_info)
# Replace the user-provided function body with the version created by DynamicPPL.
evaluatordef[:body] = modelinfo[:body]

return quote
function $evaluator(
_rng::$(Random.AbstractRNG),
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.AbstractVarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$main_body
end
## Build the model function.

$(generator_expr)
# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]

$(Base).@__doc__ $model_gen = $model_gen_constructor
# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
@gensym evaluator
modeldef[:body] = quote
$evaluator = $(combinedef_anonymous(evaluatordef))
return $(DynamicPPL.Model)(
$evaluator, $allargs_namedtuple, $defaults_namedtuple
)
end

return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
end


Expand Down
Loading