Description
This is a slight modification of the basic example in the docs:
```julia
using Libtask

@noinline function g(t; n="a", p=3.0)  # prevent the compiler being too smart
    produce(t)
end

function f()
    for t in 1:2
        g(t; n="b", p=2.0)
        t += 1
    end
    return nothing
end

t = TapedTask(nothing, f)
consume(t)
```
I've been mucking around trying to get this to work, in the interests of making Turing models with kwargs work with SMC/PG (see TuringLang/Turing.jl#2007). Right now, with [email protected], these are the `might_produce` overloads that I need:
```julia
Libtask.might_produce(::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(g),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{<:Function,String,Float64,typeof(g),Vararg}}) = true
```
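To see which call signatures the two patterns above accept, you can test them with plain `<:` in Julia, without Libtask. This is a minimal sketch, assuming Julia 1.9 or later (where `Core.kwcall` exists). `body_stub` is a hypothetical stand-in for the compiler-generated keyword body function, whose real name is gensym'd.

```julia
# Same keyword signature as `g` in the example above.
g(t; n="a", p=3.0) = t

# Hypothetical stand-in for the gensym'd keyword body function (e.g. var"#g#3").
body_stub(n, p, f, t) = nothing

# Pattern from the first overload: the Core.kwcall entry point.
const KwcallPattern = Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(g),Vararg}
# Pattern from the second overload: the body function, with the keyword
# argument types spelled out in declaration order.
const BodyPattern = Tuple{<:Function,String,Float64,typeof(g),Vararg}

# g(1; n="b", p=2.0) enters through Core.kwcall with a NamedTuple of the kwargs...
kwcall_sig = Tuple{typeof(Core.kwcall),NamedTuple{(:n, :p),Tuple{String,Float64}},typeof(g),Int}
# ...and the body function receives the kwargs as leading positional arguments.
body_sig = Tuple{typeof(body_stub),String,Float64,typeof(g),Int}
# Passing p as an Int changes the body signature, so the pattern no longer matches.
body_sig_int = Tuple{typeof(body_stub),String,Int,typeof(g),Int}

println(kwcall_sig <: KwcallPattern)   # true
println(body_sig <: BodyPattern)       # true
println(body_sig_int <: BodyPattern)   # false
```

The last check shows why the second overload is fragile. It only matches when the keyword values have exactly the types it lists.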
What this means for Turing
The first overload (the one with `Core.kwcall`) seems easy enough to handle from Turing's side. The second one is more annoying, though. It arises because Julia lowers a function with keyword arguments into an extra, compiler-generated body function. That function takes the keyword values as leading positional arguments, followed by the original function and then its positional arguments. If you print `x.args[1]` here (`Libtask.jl/src/copyable_task.jl`, line 403 at commit f154425):

```julia
Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes)
```

you get

```
x.args[1] = MethodInstance for var"#g#3"(::String, ::Float64, ::typeof(g), ::Int64)
```

I think the name (`#g#3`) is gensym'd afresh each time it needs to be (e.g. if the function is redefined).
Unfortunately, this makes it quite hard to statically specify the correct `might_produce` overload, because we need to know the types of the keyword arguments. (Thankfully, the keyword arguments always appear in the order in which they are declared in the original function. The order in which they are given at the call site doesn't matter.) We also can't use `Vararg` to skip over them, since it can only appear as the last parameter of a `Tuple` type. We could use a macro to generate the overload instead.
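One way such a macro might look is sketched here. `@kwbody_might_produce` is a hypothetical helper, not part of Libtask. The local `might_produce` is a stub so the sketch runs without Libtask; in real code you would pass `Libtask.might_produce` as the first argument. It only generates the body-function overload. The keyword types must still be supplied by the user, in declaration order.

```julia
# Stub standing in for Libtask.might_produce: treat unknown signatures as non-producing.
might_produce(::Type) = false

# Hypothetical helper (not part of Libtask). Given the might_produce function to
# extend, a user function `f`, and the types of f's keyword arguments in
# declaration order, define an overload matching signatures of the form
# (::Function, kwtypes..., ::typeof(f), positional args...).
macro kwbody_might_produce(target, f, kwtypes...)
    sig = :(Type{<:Tuple{<:Function,$(kwtypes...),typeof($f),Vararg}})
    return esc(:($target(::$sig) = true))
end

g(t; n="a", p=3.0) = t

# Expands to: might_produce(::Type{<:Tuple{<:Function,String,Float64,typeof(g),Vararg}}) = true
@kwbody_might_produce might_produce g String Float64

println(might_produce(Tuple{typeof(identity),String,Float64,typeof(g),Int}))  # true
println(might_produce(Tuple{typeof(identity),String,Int,typeof(g),Int}))      # false
```

Because `Tuple` types are covariant, an abstract type such as `Real` in place of `Float64` would match any real-valued keyword argument. This is the same effect as the `<:Real` in the Turing overload further down.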
On the plus side, it is at least possible, although it requires extra work from the user. The scary part is that if you aren't careful and don't define exactly the right `might_produce` methods, it fails silently.
Proof of principle with Turing (NB: this won't work on `main`; you'll need TuringLang/Turing.jl#2660 and TuringLang/AdvancedPS.jl#118):
```julia
using Turing, Libtask

@model function m(y; n=0)
    x ~ Normal(n)
    y ~ Normal(x)
end

# Entry point via Core.kwcall, and the gensym'd body function with one Real kwarg.
Libtask.might_produce(::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(m),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{<:Function,<:Real,typeof(m),Vararg}}) = true

# Posterior mean of x is (n + y) / 2.
mean(sample(m(5.0), PG(20), 1000))           # approx 2.5
mean(sample(m(5.0; n=10.0), PG(20), 1000))   # approx 7.5
```