Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.4"
version = "0.9.5"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme

```@docs; canonical=true
Libtask.might_produce(::Type{<:Tuple})
Libtask.@might_produce
```
61 changes: 60 additions & 1 deletion src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,70 @@ end
`true` if a call to method with signature `sig` is permitted to contain
`Libtask.produce` statements.

This is an opt-in mechanism. the fallback method of this function returns `false` indicating
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
that, by default, we assume that calls do not contain `Libtask.produce` statements.
"""
might_produce(::Type{<:Tuple}) = false

"""
@might_produce(f)

If `f` is a function that may call `Libtask.produce` inside it, then `@might_produce(f)`
will generate the appropriate methods needed to ensure that `Libtask.might_produce` returns
`true` for all relevant signatures of `f`. This works even if `f` has methods with keyword
arguments.

```jldoctest might_produce_macro
julia> # For this demonstration we need to mark `g` as not being inlineable.
@noinline function g(x; y, z=0)
produce(x + y + z)
end
g (generic function with 1 method)

julia> function f()
g(1; y=2, z=3)
end
f (generic function with 1 method)

julia> # This returns nothing because `g` isn't yet marked as being able to `produce`.
consume(Libtask.TapedTask(nothing, f))

julia> Libtask.@might_produce(g)

julia> # Now it works!
consume(Libtask.TapedTask(nothing, f))
6
"""
macro might_produce(f)
# See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro.
quote
function $(Libtask).might_produce(::Type{<:Tuple{typeof($(esc(f))),Vararg}})
return true
Comment on lines +394 to +395
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a little bit of a sledgehammer: we're basically saying, 'any invocation of f with any positional arguments might produce'. This is not necessarily true because some methods of f might produce and some might not.

But since there isn't any real downside to marking all methods are produceable, I don't think this is a huge issue. And if someone wants to be surgical, they can still use the non-macro version.

end
possible_n_kwargs = unique(map(length ∘ Base.kwarg_decl, methods($(esc(f)))))
if possible_n_kwargs != [0]
# Oddly we need to interpolate the module and not the function: either
# `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of
# those cause the entire `Libtask.might_produce` to be treated as a single
# symbol. See https://discourse.julialang.org/t/128613
function $(Libtask).might_produce(
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}}
)
return true
end
for n in possible_n_kwargs
# We only need `Any` and not `<:Any` because tuples are covariant.
kwarg_types = fill(Any, n)
function $(Libtask).might_produce(
::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}}
)
return true
end
end
end
end
end

# Helper struct used in `derive_copyable_task_ir`.
struct TupleRef
n::Int
Expand Down
54 changes: 54 additions & 0 deletions test/copyable_task.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
module Functions
using Libtask: produce
@noinline g3(x) = produce(x)
@noinline g3(x, y; z) = produce(x + y + z)
@noinline g3(x, y, z; p, q) = produce(x + y + z + p + q)
function f3(x)
g3(x)
g3(x, 1; z=2)
g3(x, 1, 2; p=3, q=4)
end
end

@testset "copyable_task" begin
for case in Libtask.TestUtils.test_cases()
case()
Expand Down Expand Up @@ -251,4 +263,46 @@
@test Libtask.consume(tt) === :a
@test Libtask.consume(tt) === nothing
end

@testset "@might_produce macro" begin
# Positional arguments only
@noinline g1(x) = produce(x)
f1(x) = g1(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g1)
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === nothing

# Keyword arguments only
@noinline g2(x; y=1, z=2) = produce(x + y + z)
f2(x) = g2(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g2)
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === nothing

# A function with multiple methods.
# Note: f3 and g3 are defined in the module at the top of this file. If
# they are defined directly in this testset, for reasons that are
# unclear, the `produce` calls are picked up even without using the
# `@might_produce` macro, meaning that it's impossible to test that the
# macro is having the intended behaviour.
tt = Libtask.TapedTask(nothing, Functions.f3, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(Functions.g3)
tt = Libtask.TapedTask(nothing, Functions.f3, 0)
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === 10
@test Libtask.consume(tt) === nothing
end
end
Loading