
Supporting functions with keyword arguments #197

@penelopeysm


This is a slight modification of the basic example in the docs:

```julia
using Libtask
@noinline function g(t; n="a", p=3.0) # prevent the compiler being too smart
    produce(t)
end
function f()
    for t in 1:2
        g(t; n="b", p=2.0)
        t += 1
    end
    return nothing
end

t = TapedTask(nothing, f)
consume(t)
```

I've been mucking around with trying to get this to work, in the interests of making Turing models with kwargs work with SMC/PG (TuringLang/Turing.jl#2007). With the Libtask version I'm using, these are the `might_produce` overloads that I need:

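```julia
# Matches the lowered keyword call Core.kwcall(kwargs::NamedTuple, g, args...)
Libtask.might_produce(::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(g),Vararg}}) = true
# Matches the compiler-generated body function var"#g#N"(n::String, p::Float64, g, args...)
Libtask.might_produce(::Type{<:Tuple{<:Function,String,Float64,typeof(g),Vararg}}) = true
```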

What this means for Turing

The first overload (with `kwcall`) seems easy enough to handle from a Turing point of view. The second one is more annoying, though. It arises because Julia lowers a function with keyword arguments into a new, compiler-generated function that holds the actual body. If you print `x.args[1]` in this check in Libtask:

```julia
Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes)
```

you get

```
x.args[1] = MethodInstance for var"#g#3"(::String, ::Float64, ::typeof(g), ::Int64)
```

where `#g#3` is, I think, a gensym'd name that Julia generates afresh whenever it needs to (e.g. if the function is redefined).
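The `kwcall` half of this can be seen directly in Julia's lowered code. A minimal sketch, assuming Julia 1.9 or later (where keyword calls lower to `Core.kwcall`); `g` here is a stand-in with the same keyword signature as above and does not call `produce`, so Libtask is not needed:

```julia
# Stand-in with the same keyword signature as `g` above (no Libtask needed).
@noinline g(t; n="a", p=3.0) = t

# Lower (but don't run) a keyword call. The keyword arguments are packed into a
# NamedTuple and the call becomes `Core.kwcall(kwargs, g, 1)`, which is the
# shape the first `might_produce` overload matches. The kwcall method in turn
# invokes the gensym'd body function that the second overload has to match.
lowered = Meta.@lower g(1; n="b", p=2.0)
println(lowered)
```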

Unfortunately, this means that it's quite hard to statically specify the correct `might_produce` method, as we need to know the types of the keyword arguments. (Thankfully, the keyword arguments appear in the order they are declared in the original function; the order in which they are passed at the call site doesn't matter.) We also can't use `Vararg` to skip over them, since `Vararg` can only appear as the last type parameter. We could use a macro to generate the overloads.
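Because the second pattern hard-codes the keyword-argument types, a call whose keyword values have different types simply falls outside it. This can be checked with plain subtyping, without Libtask. A minimal sketch; `BodySig` is a hypothetical name for the pattern from the second overload, and `typeof(identity)` merely stands in for the type of the gensym'd `var"#g#3"`, which cannot be named reliably:

```julia
# Stand-in with the same keyword signature as `g` above.
g(t; n="a", p=3.0) = t

# The pattern used by the second `might_produce` overload.
const BodySig = Tuple{<:Function,String,Float64,typeof(g),Vararg}

# Body-function signature for `g(1; n="b", p=2.0)`: keyword values first, in
# declaration order, then `g`, then the positional arguments.
# `typeof(identity)` stands in for the gensym'd body function's type.
matching = Tuple{typeof(identity),String,Float64,typeof(g),Int}

# The same call with `p=2` (an Int) instead of `p=2.0`.
mismatched = Tuple{typeof(identity),String,Int,typeof(g),Int}

println(matching <: BodySig)    # true
println(mismatched <: BodySig)  # false: the overload does not apply, and nothing warns
```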

On the plus side, it is at least possible, although it requires extra work from the user. The scary thing is that if you aren't careful and don't define exactly the right `might_produce` methods, it will silently fail.
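The macro idea could look roughly like the following. This is a minimal sketch written as a plain function using `@eval` rather than a macro; `mark_kwfunc_produces` is a hypothetical helper, not part of Libtask, and it assumes the `might_produce(::Type{<:Tuple})` overloading convention used in the overloads above. The caller still has to supply the keyword-argument types, in declaration order:

```julia
using Libtask

# Hypothetical helper (not part of Libtask): define both `might_produce`
# overloads needed for a function `f` that is called with keyword arguments.
# `kwtypes` are the types of *all* of `f`'s keyword arguments, listed in the
# order they are declared in `f`'s definition.
function mark_kwfunc_produces(f, kwtypes::Type...)
    # Matches the lowered call `Core.kwcall(kwargs::NamedTuple, f, args...)`.
    kwcall_sig = Tuple{typeof(Core.kwcall),N,typeof(f),Vararg} where {N<:NamedTuple}
    # Matches the compiler-generated body function `var"#f#N"(kwvals..., f, args...)`.
    # Its name is gensym'd, so any `Function` is accepted in first position.
    body_sig = Tuple{F,kwtypes...,typeof(f),Vararg} where {F<:Function}
    @eval Libtask.might_produce(::Type{<:$kwcall_sig}) = true
    @eval Libtask.might_produce(::Type{<:$body_sig}) = true
    return nothing
end

@noinline g(t; n="a", p=3.0) = produce(t)
mark_kwfunc_produces(g, String, Float64)
```

Because `Tuple` types are covariant, abstract types work too: `mark_kwfunc_produces(m, Real)` should generate overloads equivalent to the two in the Turing example below.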

Proof of principle with Turing (NB: this won't work on main; you'll need TuringLang/Turing.jl#2660 and TuringLang/AdvancedPS.jl#118):

```julia
using Turing, Libtask

@model function m(y; n=0)
    x ~ Normal(n)
    y ~ Normal(x)
end
# Matches the lowered keyword call Core.kwcall(kwargs, m, y)
Libtask.might_produce(::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(m),Vararg}}) = true
# Matches m's generated body function; `n` is its only keyword argument
Libtask.might_produce(::Type{<:Tuple{<:Function,<:Real,typeof(m),Vararg}}) = true

mean(sample(m(5.0), PG(20), 1000))          # approx 2.5
mean(sample(m(5.0; n=10.0), PG(20), 1000))  # approx 7.5
```
