Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.4"
version = "0.9.5"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme

```@docs; canonical=true
Libtask.might_produce(::Type{<:Tuple})
Libtask.@might_produce
```
4 changes: 2 additions & 2 deletions perf/p0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end

# Case 1: Sample from the prior.
rng = MersenneTwister()
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand All @@ -27,7 +27,7 @@ println("Run a tape...")
@btime t.tf(args...)

# Case 2: SMC sampler
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand Down
2 changes: 1 addition & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Random.seed!(rng, 2)
iterations = 500
model_fun = infiniteGMM(data)

m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
m = Turing.Inference.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
f = m.evaluator[1]
args = m.evaluator[2:end]

Expand Down
61 changes: 60 additions & 1 deletion src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,70 @@ end
`true` if a call to method with signature `sig` is permitted to contain
`Libtask.produce` statements.

This is an opt-in mechanism. the fallback method of this function returns `false` indicating
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
that, by default, we assume that calls do not contain `Libtask.produce` statements.
"""
might_produce(::Type{<:Tuple}) = false

"""
@might_produce(f)

If `f` is a function that may call `Libtask.produce` inside it, then `@might_produce(f)`
will generate the appropriate methods needed to ensure that `Libtask.might_produce` returns
`true` for all relevant signatures of `f`. This works even if `f` has methods with keyword
arguments.

```jldoctest might_produce_macro
julia> # For this demonstration we need to mark `g` as not being inlineable.
@noinline function g(x; y, z=0)
produce(x + y + z)
end
g (generic function with 1 method)

julia> function f()
g(1; y=2, z=3)
end
f (generic function with 1 method)

julia> # This returns nothing because `g` isn't yet marked as being able to `produce`.
consume(Libtask.TapedTask(nothing, f))

julia> Libtask.@might_produce(g)

julia> # Now it works!
consume(Libtask.TapedTask(nothing, f))
6
"""
macro might_produce(f)
# See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro.
quote
function $(Libtask).might_produce(::Type{<:Tuple{typeof($(esc(f))),Vararg}})
return true
Comment on lines +394 to +395
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a little bit of a sledgehammer: we're basically saying, 'any invocation of f with any positional arguments might produce'. This is not necessarily true because some methods of f might produce and some might not.

But since there isn't any real downside to marking all methods are produceable, I don't think this is a huge issue. And if someone wants to be surgical, they can still use the non-macro version.

end
possible_n_kwargs = unique(map(length ∘ Base.kwarg_decl, methods($(esc(f)))))
if possible_n_kwargs != [0]
# Oddly we need to interpolate the module and not the function: either
# `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of
# those cause the entire `Libtask.might_produce` to be treated as a single
# symbol. See https://discourse.julialang.org/t/128613
function $(Libtask).might_produce(
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}}
)
return true
end
for n in possible_n_kwargs
# We only need `Any` and not `<:Any` because tuples are covariant.
kwarg_types = fill(Any, n)
function $(Libtask).might_produce(
::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}}
)
return true
end
end
end
end
end

# Helper struct used in `derive_copyable_task_ir`.
struct TupleRef
n::Int
Expand Down
49 changes: 49 additions & 0 deletions test/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,53 @@
@test Libtask.consume(tt) === :a
@test Libtask.consume(tt) === nothing
end

@testset "@might_produce macro" begin
# Positional arguments only
@noinline g1(x) = produce(x)
f1(x) = g1(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g1)
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === nothing

# Keyword arguments only
@noinline g2(x; y=1, z=2) = produce(x + y + z)
f2(x) = g2(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g2)
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === nothing

# A function with multiple methods.
# The function reference is used to ensure that it really doesn't get inlined
# (otherwise, for reasons that are yet unknown, these functions do get inlined when
# inside a testset)
@noinline g3(x) = produce(x)
@noinline g3(x, y; z) = produce(x + y + z)
@noinline g3(x, y, z; p, q) = produce(x + y + z + p + q)
function f3(x, fref)
fref[](x)
fref[](x, 1; z=2)
fref[](x, 1, 2; p=3, q=4)
return nothing
end
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g3)
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === 10
@test Libtask.consume(tt) === nothing
end
end
Loading