diff --git a/docs/src/julia_syntax.md b/docs/src/julia_syntax.md new file mode 100644 index 000000000..1fb2ecf98 --- /dev/null +++ b/docs/src/julia_syntax.md @@ -0,0 +1,187 @@ +# How to Specify and Create a `BUGSModel` + +Creating a `BUGSModel` requires two key components: a BUGS program that defines the model structure, and values for the specific variables that parameterize the model. + +To understand how to specify a model properly, it is important to distinguish between the different types of values you can provide to the JuliaBUGS compiler: + +* **Constants**: Values used in loop bounds and index resolution + * These are essential for model specification as they determine the model's dimensionality (how many variables are created) and establish the dependency structure between variables + +* **Independent variables** (also called features, predictors, or covariates): Non-stochastic inputs required for forward simulation of the model + * Examples include predictor variables in a regression model or time points in a time series model + +* **Observations**: Values for stochastic variables that you wish to condition on + * These are not necessary to specify the model structure, but when provided, they become the data that your model is conditioned on + * (Note: In some advanced cases, stochastic variables can contribute to the log density without being part of a strictly generative model) + +* **Initialization values**: Starting points for MCMC sampling + * While optional in many cases, some models (particularly those with weakly informative priors or complex structures) require carefully chosen initialization values for effective sampling + +## Syntax from previous BUGS software and its R packages + +Traditionally, BUGS models were created through a software interface following these steps: +1. Write the model in a text file +2. Check the model syntax (parsing) +3. Compile the model with program text and data +4. 
Initialize the sampling process (optional) + +R interface packages for BUGS maintained this workflow pattern through text-based interfaces that closely mirrored the original software. + +JuliaBUGS initially adopted this familiar workflow to accommodate users with prior BUGS experience. Specifically, JuliaBUGS provides a `@bugs` macro that accepts model definitions either as strings or within a `begin...end` block: + +```julia +# Example using string macro +@bugs""" +model { + for( i in 1 : N ) { + r[i] ~ dbin(p[i],n[i]) + b[i] ~ dnorm(0.0,tau) + logit(p[i]) <- alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + + alpha12 * x1[i] * x2[i] + b[i] + } + alpha0 ~ dnorm(0.0,1.0E-6) + alpha1 ~ dnorm(0.0,1.0E-6) + alpha2 ~ dnorm(0.0,1.0E-6) + alpha12 ~ dnorm(0.0,1.0E-6) + tau ~ dgamma(0.001,0.001) + sigma <- 1 / sqrt(tau) +} +""" + +# Example using block macro +@bugs begin + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic(alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + + b[i]) + end + alpha0 ~ dnorm(0.0, 1.0e-6) + alpha1 ~ dnorm(0.0, 1.0e-6) + alpha2 ~ dnorm(0.0, 1.0e-6) + alpha12 ~ dnorm(0.0, 1.0e-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) +end +``` + +In both cases, the macro returns a Julia AST representation of the model. The `compile` function then takes this AST and user-provided values (as a `NamedTuple`) to create a `BUGSModel` instance. + +While we maintain this interface for compatibility, we now also offer a more idiomatic Julia approach. + +## The Interface + +JuliaBUGS provides a Julian interface inspired by Turing.jl's model macro syntax. The `@model` macro creates a "model creating function" that returns a model object supporting operations like `AbstractMCMC.sample` (which samples MCMC chains) and `condition` (which modifies the model by incorporating observations). 
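As a rough illustration of how one parameter container can hold both observed values and still-unknown parameters, the sketch below uses only Base Julia and does not call JuliaBUGS. `Placeholder`, `DemoParams`, and `to_data` are hypothetical stand-ins, loosely modelled on the `ParameterPlaceholder` type and the `_param_struct_to_NT` helper introduced in this PR. They are not the package's API.

```julia
# Hypothetical stand-in for JuliaBUGS.ParameterPlaceholder (assumption, not the real type).
struct Placeholder end

# Hypothetical parameter container: every field defaults to a placeholder,
# meaning "no observation supplied, treat this as a parameter to infer".
Base.@kwdef struct DemoParams
    r = Placeholder()
    tau = Placeholder()
end

# Keep only the fields that carry real values; placeholders are dropped,
# so the resulting NamedTuple contains just the observed data.
function to_data(params)
    kept = Pair{Symbol,Any}[]
    for name in fieldnames(typeof(params))
        value = getfield(params, name)
        value isa Placeholder || push!(kept, name => value)
    end
    return (; kept...)
end

all_unknown = to_data(DemoParams())               # nothing observed
with_obs = to_data(DemoParams(; r=[10, 23, 17]))  # `r` observed, `tau` unknown
```

In this sketch, supplying a value for a field is what turns it into an observation, while leaving the default keeps it as a parameter. This is the same distinction the text draws between observations and the remaining stochastic variables.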
+ +### The `@model` Macro + +```julia +JuliaBUGS.@model function model_definition((;r, b, alpha0, alpha1, alpha2, alpha12, tau)::SeedsParams, x1, x2, N, n) + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic(alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i]) + end + alpha0 ~ dnorm(0.0, 1.0E-6) + alpha1 ~ dnorm(0.0, 1.0E-6) + alpha2 ~ dnorm(0.0, 1.0E-6) + alpha12 ~ dnorm(0.0, 1.0E-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) +end +``` + +The `@model` macro requires a specific function signature: + +1. The first argument must declare stochastic parameters (variables defined with `~`) using destructuring assignment with the format `(; param1, param2, ...)`. +2. We recommend providing a type annotation (e.g., `(; r, b, ...)::SeedsParams`). If `SeedsParams` is defined using `@parameters`, the macro automatically defines a constructor `SeedsParams(model::BUGSModel)` for extracting parameter values from the model. +3. Alternatively, you can use a `NamedTuple` instead of a custom type. In this case, no type annotation is needed, but you would need to manually create a `NamedTuple` with `ParameterPlaceholder()` values or arrays of `missing` values for parameters that don't have observations. +4. The remaining arguments must specify all constants and independent variables required by the model (variables used on the RHS but not on the LHS). + +The `@parameters` macro simplifies creating structs to hold model parameters: + +```julia +JuliaBUGS.@parameters struct SeedsParams + r + b + alpha0 + alpha1 + alpha2 + alpha12 + tau +end +``` + +This macro applies `Base.@kwdef` to enable keyword initialization and creates a no-argument constructor. By default, fields are initialized to `JuliaBUGS.ParameterPlaceholder`. The concrete types and sizes of parameters are determined during compilation when the model function is called with constants. 
A constructor `SeedsParams(::BUGSModel)` is created for easy extraction of parameter values. + +### Example + +```julia +julia> @model function seeds( + (; r, b, alpha0, alpha1, alpha2, alpha12, tau)::SeedsParams, x1, x2, N, n + ) + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic( + alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i] + ) + end + alpha0 ~ dnorm(0.0, 1.0E-6) + alpha1 ~ dnorm(0.0, 1.0E-6) + alpha2 ~ dnorm(0.0, 1.0E-6) + alpha12 ~ dnorm(0.0, 1.0E-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) + end +seeds (generic function with 1 method) + +julia> (; x1, x2, N, n) = JuliaBUGS.BUGSExamples.seeds.data; # extract data from existing BUGS example + +julia> @parameters struct SeedsParams + r + b + alpha0 + alpha1 + alpha2 + alpha12 + tau + end + +julia> m = seeds(SeedsParams(), x1, x2, N, n) +BUGSModel (parameters are in transformed (unconstrained) space, with dimension 47): + + Model parameters: + alpha2 + b[21], b[20], b[19], b[18], b[17], b[16], b[15], b[14], b[13], b[12], b[11], b[10], b[9], b[8], b[7], b[6], b[5], b[4], b[3], b[2], b[1] + r[21], r[20], r[19], r[18], r[17], r[16], r[15], r[14], r[13], r[12], r[11], r[10], r[9], r[8], r[7], r[6], r[5], r[4], r[3], r[2], r[1] + tau + alpha12 + alpha1 + alpha0 + + Variable sizes and types: + b: size = (21,), type = Vector{Float64} + p: size = (21,), type = Vector{Float64} + n: size = (21,), type = Vector{Int64} + alpha2: type = Float64 + sigma: type = Float64 + alpha12: type = Float64 + alpha0: type = Float64 + N: type = Int64 + tau: type = Float64 + alpha1: type = Float64 + r: size = (21,), type = Vector{Float64} + x1: size = (21,), type = Vector{Int64} + x2: size = (21,), type = Vector{Int64} + +julia> SeedsParams(m) +SeedsParams: + r = [0.0, 0.0, 0.0, 0.0, 39.0, 0.0, 0.0, 72.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 4.0, 12.0, 0.0, 0.0, 0.0, 0.0] + b = [-Inf, -Inf, -Inf, -Inf, Inf, -Inf, -Inf, Inf, -Inf, -Inf … -Inf, -Inf, -Inf, -Inf, Inf, Inf, 
-Inf, -Inf, -Inf, -Inf] + alpha0 = -1423.52 + alpha1 = 1981.99 + alpha2 = -545.664 + alpha12 = 1338.25 + tau = 0.0 +``` diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 36d5a5cb1..b9b27f31f 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -43,20 +43,30 @@ include("source_gen.jl") include("BUGSExamples/BUGSExamples.jl") function check_input(input::NamedTuple) + valid_pairs = Pair{Symbol,Any}[] for (k, v) in pairs(input) - if v isa AbstractArray - if !(eltype(v) <: Union{Int,Float64,Missing}) + if v === missing + continue # Skip missing values + elseif v isa AbstractArray + # Allow arrays containing Int, Float64, or Missing + allowed_eltypes = Union{Int,Float64,Missing} + if !(eltype(v) <: allowed_eltypes) error( - "For array input, only Int, Float64, or Missing types are supported. Received: $(typeof(v)).", + "For array input '$k', only elements of type $allowed_eltypes are supported. Received array with eltype: $(eltype(v)).", ) end - elseif v === missing - error("Scalars cannot be missing. Received: $k") - elseif !(v isa Union{Int,Float64}) - error("Scalars must be of type Int or Float64. Received: $k") + push!(valid_pairs, k => v) + elseif v isa Union{Int,Float64} + # Allow scalar Int or Float64 + push!(valid_pairs, k => v) + else + # Error for other scalar types + error( + "Scalar input '$k' must be of type Int or Float64. Received: $(typeof(v))." 
+ ) end end - return input + return NamedTuple(valid_pairs) end function check_input(input::Dict{KT,VT}) where {KT,VT} if isempty(input) @@ -177,6 +187,16 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N ) return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params) end +# function compile( +# model_str::String, +# data::NamedTuple, +# initial_params::NamedTuple=NamedTuple(); +# replace_period::Bool=true, +# no_enclosure::Bool=false, +# ) +# model_def = _bugs_string_input(model_str, replace_period, no_enclosure) +# return compile(model_def, data, initial_params) +# end """ @register_primitive(expr) @@ -253,6 +273,8 @@ Only defined with `MCMCChains` extension. """ function gen_chains end +include("model_macro.jl") + include("experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl") end diff --git a/src/model_macro.jl b/src/model_macro.jl new file mode 100644 index 000000000..0c18d9e69 --- /dev/null +++ b/src/model_macro.jl @@ -0,0 +1,259 @@ +using MacroTools + +# The `@capture` macro from MacroTools is used to pattern-match Julia code. +# When a variable in the pattern is followed by a single underscore (e.g., `var_`), +# it captures a single component of the Julia expression and binds it locally to that +# variable name. If a variable is followed by double underscores (e.g., `vars__`), +# it captures multiple components into an array. + +struct ParameterPlaceholder end + +macro parameters(struct_expr) + if MacroTools.@capture(struct_expr, struct struct_name_ + struct_fields__ + end) + return _generate_struct_definition( + struct_name, struct_fields, __source__, __module__ + ) + else + return :(throw( + ArgumentError( + "Expected a struct definition like '@parameters struct MyParams ... 
end'" + ), + )) + end +end + +function _generate_struct_definition(struct_name, struct_fields, __source__, __module__) + if !isa(struct_name, Symbol) + return :(throw( + ArgumentError( + "Parametrized types (e.g., `struct MyParams{T}`) are not supported yet" + ), + )) + end + + if !all(isa.(struct_fields, Symbol)) + return :(throw( + ArgumentError( + "Field types are determined by JuliaBUGS automatically. Specifying types for fields is not supported yet.", + ), + )) + end + + show_method_expr = MacroTools.@q function Base.show( + io::IO, mime::MIME"text/plain", params::$(esc(struct_name)) + ) + # Use IOContext for potentially compact/limited printing of field values + ioc = IOContext(io, :compact => true, :limit => true) + + println(ioc, "$(nameof(typeof(params))):") + fields = fieldnames(typeof(params)) + + # Handle empty structs gracefully + if isempty(fields) + print(ioc, " (no fields)") + return nothing + end + + # Calculate maximum field name length for alignment + max_len = maximum(length ∘ string, fields) + for field in fields + value = getfield(params, field) + field_str = rpad(string(field), max_len) + print(ioc, " ", field_str, " = ") + if value isa JuliaBUGS.ParameterPlaceholder + # Use the IOContext here as well + printstyled(ioc, ""; color=:light_black) + else + # Capture the string representation using the context + # Use the basic `show` for a more compact representation, especially for arrays + str_representation = sprint(show, value; context=ioc) + # Print the captured string with color + printstyled(ioc, str_representation; color=:cyan) + end + # Use the IOContext for the newline too + println(ioc) + end + end + + kw_assignments = map(f -> Expr(:kw, esc(f), :(ParameterPlaceholder())), struct_fields) + kwarg_constructor_expr = MacroTools.@q function $(esc(struct_name))(; + $(kw_assignments...) 
+ ) + return $(esc(struct_name))($(map(esc, struct_fields)...)) + end + return MacroTools.@q begin + begin + struct $(esc(struct_name)) + $(map(esc, struct_fields)...) + end + $(kwarg_constructor_expr) + end + + $(show_method_expr) + + function $(esc(struct_name))(model::BUGSModel) + return getparams($(esc(struct_name)), model) + end + end +end + +macro model(model_function_expr) + return _generate_model_definition(model_function_expr, __source__, __module__) +end + +function _generate_model_definition(model_function_expr, __source__, __module__) + MacroTools.@capture( + #! format: off + model_function_expr, + function model_name_(param_destructure_, constant_variables__) + body_expr__ + end + #! format: on + ) || return :(throw(ArgumentError("Expected a model function definition"))) + + model_def = _add_line_number_nodes(Expr(:block, body_expr...)) # hack, see _add_line_number_nodes + Parser.warn_cumulative_density_deviance(model_def) # refer to parser/bugs_macro.jl + + bugs_ast = Parser.bugs_top(model_def, __source__) + + param_type = nothing + MacroTools.@capture( + param_destructure, (((; param_fields__)::param_type_) | ((; param_fields__))) + ) || return :(throw( + ArgumentError( + "The first argument of the model function must be a destructuring assignment of the form `(; a, b)`, optionally with a type annotation defined using `@parameters`.", + ), + )) + + illegal_constant_variables = Any[] + constant_variables_symbols = map(constant_variables) do constant_variable + if constant_variable isa Symbol + return constant_variable + elseif MacroTools.@capture( + constant_variable, ((name_ = default_value_) | (name_::type_)) + ) + # `@capture` binds the match to `name` (without the trailing underscore) + return name + else + push!(illegal_constant_variables, constant_variable) + end + end + if !isempty(illegal_constant_variables) + formatted_vars = join(illegal_constant_variables, ", ", " and ") + return MacroTools.@q error( + string( + "The following arguments use syntax that the model function does not currently support: ", + $(QuoteNode(formatted_vars)), + ". Please report 
this issue at https://github.com/TuringLang/JuliaBUGS.jl/issues", + ), + ) + end + + vars_and_numdims = extract_variable_names_and_numdims(bugs_ast) + vars_assigned_to = extract_variables_assigned_to(bugs_ast) + stochastic_vars = [vars_assigned_to[2]..., vars_assigned_to[4]...] + deterministic_vars = [vars_assigned_to[1]..., vars_assigned_to[3]...] + all_vars = collect(keys(vars_and_numdims)) + constants = setdiff(all_vars, vcat(stochastic_vars, deterministic_vars)) + + # Check if all constants used in the model are included in function arguments + # (compare against the argument names, not the raw argument expressions) + if !all(in(constant_variables_symbols), constants) + missing_constants = setdiff(constants, constant_variables_symbols) + formatted_vars = join(missing_constants, ", ", " and ") + return MacroTools.@q error( + string( + "The following constants used in the model are not included in the function arguments: ", + $(QuoteNode(formatted_vars)), + ), + ) + end + + # Check if all stochastic variables are included in the parameters struct + missing_stochastic_vars = setdiff(stochastic_vars, param_fields) + if !isempty(missing_stochastic_vars) + formatted_vars = join(missing_stochastic_vars, ", ", " and ") + return MacroTools.@q error( + string( + "The following stochastic variables used in the model are not included in the parameters ", + "in the first argument of the model function: ", + $(QuoteNode(formatted_vars)), + ), + ) + end + + func_expr = MacroTools.@q function ($(esc(model_name)))( + params_struct, $(esc.(constant_variables)...) + ) + (; $(esc.(param_fields)...)) = params_struct + data = _param_struct_to_NT((; + $([esc.(param_fields)..., esc.(constant_variables_symbols)...]...) 
+ )) + model_def = $(QuoteNode(bugs_ast)) + return compile(model_def, data) + end + + if param_type === nothing + return func_expr + else + return MacroTools.@q begin + function JuliaBUGS.getparams($(esc(param_type)), model::BUGSModel) + env = model.evaluation_env + field_names = fieldnames($(esc(param_type))) + kwargs = Dict{Symbol,Any}() + + for field in field_names + if haskey(env, field) + kwargs[field] = env[field] + end + end + + return $(esc(param_type))(; kwargs...) + end + $func_expr + end + end +end + +function _param_struct_to_NT(param_struct) + field_names = fieldnames(typeof(param_struct)) + pairs = Pair{Symbol,Any}[] + + for field_name in field_names + value = getfield(param_struct, field_name) + if !(value isa ParameterPlaceholder) + push!(pairs, field_name => value) + end + end + + return NamedTuple(pairs) +end + +# This function addresses a discrepancy in how Julia's parser handles LineNumberNode insertion. +# When parsing a function body, the parser only adds a LineNumberNode before the first statement. +# In contrast, when parsing a "begin ... end" block, it inserts a LineNumberNode before each statement. +# The `bugs_top` function assumes input comes from a macro and expects a LineNumberNode before each statement. +# As a workaround, this function ensures that a LineNumberNode precedes every statement in the model function's body. +function _add_line_number_nodes(expr) + if !(expr isa Expr) + return expr + end + + if Meta.isexpr(expr, :block) + new_args = [] + + for arg in expr.args + if !(arg isa LineNumberNode) && + (isempty(new_args) || !(new_args[end] isa LineNumberNode)) + push!(new_args, LineNumberNode(0, :none)) # use a dummy LineNumberNode + end + + push!(new_args, arg isa Expr ? _add_line_number_nodes(arg) : arg) + end + + return Expr(:block, new_args...) + else + new_args = map(arg -> _add_line_number_nodes(arg), expr.args) + return Expr(expr.head, new_args...) 
+ end +end diff --git a/test/model_macro.jl b/test/model_macro.jl new file mode 100644 index 000000000..0de6eccee --- /dev/null +++ b/test/model_macro.jl @@ -0,0 +1,110 @@ +using JuliaBUGS +using JuliaBUGS: @parameters, @model + +@testset "model macro" begin + @parameters struct Tp + r + b + alpha0 + alpha1 + alpha2 + alpha12 + tau + end + + #! format: off + @model function seeds( + (; r, b, alpha0, alpha1, alpha2, alpha12, tau)::Tp, x1, x2, N, n + ) + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic( + alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i] + ) + end + alpha0 ~ dnorm(0.0, 1.0E-6) + alpha1 ~ dnorm(0.0, 1.0E-6) + alpha2 ~ dnorm(0.0, 1.0E-6) + alpha12 ~ dnorm(0.0, 1.0E-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) + end + #! format: on + + # Try destructuring the random variables but forgetting to include one (tau). + @test_throws ErrorException begin + #! format: off + @model function seeds( + # tau is missing + (; r, b, alpha0, alpha1, alpha2, alpha12)::Tp, x1, x2, N, n + ) + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic( + alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i] + ) + end + alpha0 ~ dnorm(0.0, 1.0E-6) + alpha1 ~ dnorm(0.0, 1.0E-6) + alpha2 ~ dnorm(0.0, 1.0E-6) + alpha12 ~ dnorm(0.0, 1.0E-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) + end + #! format: on + end + + # Try leaving out one constant variable. + @test_throws ErrorException begin + #! 
format: off + @model function seeds( + # x1 is missing + (; r, b, alpha0, alpha1, alpha2, alpha12, tau)::Tp, x2, N, n + ) + for i in 1:N + r[i] ~ dbin(p[i], n[i]) + b[i] ~ dnorm(0.0, tau) + p[i] = logistic( + alpha0 + alpha1 * x1[i] + alpha2 * x2[i] + alpha12 * x1[i] * x2[i] + b[i] + ) + end + alpha0 ~ dnorm(0.0, 1.0E-6) + alpha1 ~ dnorm(0.0, 1.0E-6) + alpha2 ~ dnorm(0.0, 1.0E-6) + alpha12 ~ dnorm(0.0, 1.0E-6) + tau ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(tau) + end + #! format: on + end + + data = JuliaBUGS.BUGSExamples.seeds.data + m = seeds(Tp(), data.x1, data.x2, data.N, data.n) + + # use NamedTuple to pass parameters + # with missing values + N = data.N + params_nt = ( + r=fill(missing, N), + b=fill(missing, N), + alpha0=missing, + alpha1=missing, + alpha2=missing, + alpha12=missing, + tau=missing, + ) + m = seeds(params_nt, data.x1, data.x2, data.N, data.n) + + params_nt_with_data = ( + r=data.r, + b=JuliaBUGS.ParameterPlaceholder(), + alpha0=JuliaBUGS.ParameterPlaceholder(), + alpha1=JuliaBUGS.ParameterPlaceholder(), + alpha2=JuliaBUGS.ParameterPlaceholder(), + alpha12=JuliaBUGS.ParameterPlaceholder(), + tau=JuliaBUGS.ParameterPlaceholder(), + ) + m = seeds(params_nt_with_data, data.x1, data.x2, data.N, data.n) +end diff --git a/test/runtests.jl b/test/runtests.jl index 97b841e7b..57c241078 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,6 +54,7 @@ if test_group == "elementary" || test_group == "all" include("parser/test_parser.jl") include("passes.jl") include("graphs.jl") + include("model_macro.jl") end if test_group == "compilation" || test_group == "all"