diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml index fc34a49d..36011f6a 100644 --- a/.github/workflows/Testing.yaml +++ b/.github/workflows/Testing.yaml @@ -11,9 +11,10 @@ jobs: strategy: matrix: version: - - '1.10' + - 'min' - '1' - - 'pre' + # TODO(mhauru) Reenable the below once there is a 'pre' version different from '1'. + # - 'pre' os: - ubuntu-latest - windows-latest diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 00000000..73aff547 --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,11 @@ +# 0.9.6 + +Add support for Julia v1.12. + +# 0.9.0 + +From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where previously they were only deprecated. Additionally, the internals have been completely overhauled, and the public interface more precisely defined. See the docs for more info. + +# 0.6.0 + +From v0.6.0 Libtask is implemented by recording all the computing to a tape and copying that tape. Before that version, it is based on a tricky hack on the Julia internals. You can check the commit history of this repo to see the details. diff --git a/NEWS.md b/NEWS.md deleted file mode 100644 index db7b2b8b..00000000 --- a/NEWS.md +++ /dev/null @@ -1,8 +0,0 @@ -- From v0.6.0, Libtask is implemented by recording all the computing - to a tape and copying that tape. Before that version, it is based on - a tricky hack on the Julia internals. You can check the commit - history of this repo to see the details. - -- From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where - previously they were only deprecated. Additionally, the internals have been completely - overhauled, and the public interface more precisely defined. See the docs for more info. diff --git a/Project.toml b/Project.toml index 744e97cd..4967752d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.5" +version = "0.9.6" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/src/bbcode.jl b/src/bbcode.jl index 2f86e269..bdffd55f 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -140,22 +140,44 @@ end collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts)) -struct BBCode - blocks::Vector{BBlock} - argtypes::Vector{Any} - sptypes::Vector{CC.VarState} - linetable::Vector{Core.LineInfoNode} - meta::Vector{Expr} -end +@static if VERSION >= v"1.12-" + struct BBCode + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + debuginfo::CC.DebugInfoStream + meta::Vector{Expr} + valid_worlds::CC.WorldRange + end -function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) - return BBCode( - new_blocks, - CC.copy(ir.argtypes), - CC.copy(ir.sptypes), - CC.copy(ir.linetable), - CC.copy(ir.meta), - ) + function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) + return BBCode( + new_blocks, + CC.copy(ir.argtypes), + CC.copy(ir.sptypes), + CC.copy(ir.debuginfo), + CC.copy(ir.meta), + ir.valid_worlds, + ) + end +else + struct BBCode + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} + end + + function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) + return BBCode( + new_blocks, + CC.copy(ir.argtypes), + CC.copy(ir.sptypes), + CC.copy(ir.linetable), + CC.copy(ir.meta), + ) + end end # Makes use of the above outer constructor for `BBCode`. @@ -352,20 +374,42 @@ function CC.IRCode(bb_code::BBCode) insts = _ids_to_line_numbers(bb_code) cfg = control_flow_graph(bb_code) insts = _lines_to_blocks(insts, cfg) - return IRCode( - CC.InstructionStream( - map(x -> x.stmt, insts), - map(x -> x.type, insts), - map(x -> x.info, insts), - map(x -> x.line, insts), - map(x -> x.flag, insts), - ), - cfg, - CC.copy(bb_code.linetable), - CC.copy(bb_code.argtypes), - CC.copy(bb_code.meta), - CC.copy(bb_code.sptypes), - ) + @static if VERSION >= v"1.12-" + # See e.g. here for how the NTuple{3,Int}s get flattened for InstructionStream: + # https://github.com/JuliaLang/julia/blob/16a2bf0a3b106b03dda23b8c9478aab90ffda5e1/Compiler/src/ssair/ir.jl#L299 + lines = map(x -> x.line, insts) + lines = collect(Iterators.flatten(lines)) + return IRCode( + CC.InstructionStream( + map(x -> x.stmt, insts), + collect(Any, map(x -> x.type, insts)), + collect(CC.CallInfo, map(x -> x.info, insts)), + lines, + map(x -> x.flag, insts), + ), + cfg, + CC.copy(bb_code.debuginfo), + CC.copy(bb_code.argtypes), + CC.copy(bb_code.meta), + CC.copy(bb_code.sptypes), + bb_code.valid_worlds, + ) + else + return IRCode( + CC.InstructionStream( + map(x -> x.stmt, insts), + map(x -> x.type, insts), + map(x -> x.info, insts), + map(x -> x.line, insts), + map(x -> x.flag, insts), + ), + cfg, + CC.copy(bb_code.linetable), + CC.copy(bb_code.argtypes), + CC.copy(bb_code.meta), + CC.copy(bb_code.sptypes), + ) + end end function _lower_switch_statements(bb_code::BBCode) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 6d33439e..de14fd87 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -90,7 +90,9 @@ function build_callable(sig::Type{<:Tuple}) unoptimised_ir = IRCode(bb) optimised_ir = optimise_ir!(unoptimised_ir) mc_ret_type = callable_ret_type(sig, types) - mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true) + mc = optimized_misty_closure( + mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true + ) mc_cache[key] = mc return mc, refs[end] end @@ -277,6 +279,13 @@ The above gives the broad outline of how `TapedTask`s are implemented. We refer readers to the code, which is extensively commented to explain implementation details. """ function TapedTask(taped_globals::Any, fargs...; kwargs...) + @static if v"1.12.1" > VERSION >= v"1.12.0-" + @warn """ + Libtask.jl does not work correctly on Julia v1.12.0 and may crash your Julia + session. Please upgrade to at least v1.12.1. See + https://github.com/JuliaLang/julia/issues/59222 for the bug in question. + """ + end all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...) seed_id!() # a BBCode thing. mc, count_ref = build_callable(typeof(all_args)) @@ -441,8 +450,10 @@ get_value(x) = x expression, otherwise `false`. """ function is_produce_stmt(x)::Bool - if Meta.isexpr(x, :invoke) && length(x.args) == 3 && x.args[1] isa Core.MethodInstance - return x.args[1].specTypes <: Tuple{typeof(produce),Any} + if Meta.isexpr(x, :invoke) && + length(x.args) == 3 && + x.args[1] isa Union{Core.MethodInstance,Core.CodeInstance} + return get_mi(x.args[1]).specTypes <: Tuple{typeof(produce),Any} elseif Meta.isexpr(x, :call) && length(x.args) == 2 return get_value(x.args[1]) === produce else @@ -465,7 +476,7 @@ function stmt_might_produce(x, ret_type::Type)::Bool # Statement will terminate in the usual fashion, so _do_ bother recusing. is_produce_stmt(x) && return true - Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes) + Meta.isexpr(x, :invoke) && return might_produce(get_mi(x.args[1]).specTypes) if Meta.isexpr(x, :call) # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. f = get_function(x.args[1]) @@ -1029,7 +1040,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} # Derive TapedTask for this statement. (callable, callable_args) = if Meta.isexpr(stmt, :invoke) - sig = stmt.args[1].specTypes + sig = get_mi(stmt.args[1]).specTypes v = Any[Any] (LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end]) elseif Meta.isexpr(stmt, :call) @@ -1144,7 +1155,13 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} new_argtypes = vcat(typeof(refs), copy(ir.argtypes)) # Return BBCode and the `Ref`s. - new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) + @static if VERSION >= v"1.12-" + new_ir = BBCode( + new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds + ) + else + new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) + end return new_ir, refs, possible_produce_types end diff --git a/src/utils.jl b/src/utils.jl index e4449657..cd958ae1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,7 @@ +function get_mi(ci::Core.CodeInstance) + @static isdefined(CC, :get_ci_mi) ? CC.get_ci_mi(ci) : ci.def +end +get_mi(mi::Core.MethodInstance) = mi """ replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} @@ -68,7 +72,11 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) ir = CC.compact!(ir) # CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp)) - CC.verify_linetable(ir.linetable, true) + @static if VERSION >= v"1.12-" + CC.verify_linetable(ir.debuginfo, div(length(ir.debuginfo.codelocs), 3), true) + else + CC.verify_linetable(ir.linetable, true) + end if show_ir println("Post-optimization") display(ir) @@ -96,13 +104,27 @@ end # Run type inference and constant propagation on the ir. Credit to @oxinabox: # https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) - method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=# - min_world = world = get_inference_world(interp) - max_world = Base.get_world_counter() - irsv = CC.IRInterpretationState( - interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world - ) - rt = CC._ir_abstract_constant_propagation(interp, irsv) + @static if VERSION >= v"1.12-" + nargs = length(ir.argtypes) - 1 + # TODO(mhauru) How should we figure out isva? I don't think it's in ir or mi. + isva = false + propagate_inbounds = true + spec_info = CC.SpecInfo(nargs, isva, propagate_inbounds, nothing) + min_world = world = get_inference_world(interp) + max_world = Base.get_world_counter() + irsv = CC.IRInterpretationState( + interp, spec_info, ir, mi, ir.argtypes, world, min_world, max_world + ) + rt = CC.ir_abstract_constant_propagation(interp, irsv) + else + method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=# + min_world = world = get_inference_world(interp) + max_world = Base.get_world_counter() + irsv = CC.IRInterpretationState( + interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world + ) + rt = CC._ir_abstract_constant_propagation(interp, irsv) + end return ir end @@ -168,19 +190,85 @@ function opaque_closure( ) # This implementation is copied over directly from `Core.OpaqueClosure`. ir = CC.copy(ir) - nargs = length(ir.argtypes) - 1 - sig = Base.Experimental.compute_oc_signature(ir, nargs, isva) + @static if VERSION >= v"1.12-" + # On v1.12 OpaqueClosure expects the first arg to be the environment. + ir.argtypes[1] = typeof(env) + end + nargtypes = length(ir.argtypes) + nargs = nargtypes - 1 + @static if VERSION >= v"1.12-" + sig = CC.compute_oc_signature(ir, nargs, isva) + else + sig = Base.Experimental.compute_oc_signature(ir, nargs, isva) + end src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) - src.slotnames = fill(:none, nargs + 1) - src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slotnames = [Symbol(:_, i) for i in 1:nargtypes] + src.slotflags = fill(zero(UInt8), nargtypes) src.slottypes = copy(ir.argtypes) - src.rettype = ret_type + @static if VERSION > v"1.12-" + ir.debuginfo.def === nothing && + (ir.debuginfo.def = :var"generated IR for OpaqueClosure") + src.min_world = ir.valid_worlds.min_world + src.max_world = ir.valid_worlds.max_world + src.isva = isva + src.nargs = nargtypes + end src = CC.ir_to_codeinf!(src, ir) + src.rettype = ret_type return Base.Experimental.generate_opaque_closure( sig, Union{}, ret_type, src, nargs, isva, env...; do_compile )::Core.OpaqueClosure{sig,ret_type} end +function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...) + oc = opaque_closure(rtype, ir, env...; kwargs...) + world = UInt(oc.world) + set_world_bounds_for_optimization!(oc) + optimized_oc = optimize_opaque_closure(oc, rtype, env...; kwargs...) + return optimized_oc +end + +function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...) + method = oc.source + ci = method.specializations.cache + world = UInt(oc.world) + ir = reinfer_and_inline(ci, world) + ir === nothing && return oc # nothing to optimize + return opaque_closure(rtype, ir, env...; kwargs...) +end + +# Allows optimization to make assumptions about binding access, +# enabling inlining and other optimizations. +function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure) + ci = oc.source.specializations.cache + ci.inferred === nothing && return nothing + ci.inferred.min_world = oc.world + return ci.inferred.max_world = oc.world +end + +function reinfer_and_inline(ci::Core.CodeInstance, world::UInt) + interp = CC.NativeInterpreter(world) + mi = get_mi(ci) + argtypes = collect(Any, mi.specTypes.parameters) + irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world) + irsv === nothing && return nothing + for stmt in irsv.ir.stmts + inst = stmt[:inst] + if Meta.isexpr(inst, :loopinfo) || + Meta.isexpr(inst, :pop_exception) || + isa(inst, CC.GotoIfNot) || + isa(inst, CC.GotoNode) || + Meta.isexpr(inst, :copyast) + continue + end + stmt[:flag] |= CC.IR_FLAG_REFINED + end + CC.ir_abstract_constant_propagation(interp, irsv) + state = CC.InliningState(interp) + ir = CC.ssa_inlining_pass!(irsv.ir, state, CC.propagate_inbounds(irsv)) + return ir +end + """ misty_closure( ret_type::Type, @@ -202,3 +290,15 @@ function misty_closure( ) return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)) end + +function optimized_misty_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, +) + return MistyClosure( + optimized_opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir) + ) +end