Skip to content

Commit 8066c29

Browse files
authored
AbstractInterpreter: add a hook to customize bestguess calculation (#50744)
Currently, the code that updates `bestguess` using `ReturnNode` information includes hardcodes that relate to `Conditional` and `LimitedAccuracy`. These behaviors are actually lattice-dependent and therefore should be overloadable by `AbstractInterpreter`. Additionally, particularly in Diffractor, a clever strategy is required to update return types in a way that it takes into account information from both the original method and its rule method (xref: JuliaDiff/Diffractor.jl#202). This also requires such an overload to exist. In response to these needs, this commit introduces an implementation of a hook named `update_bestguess!`.
1 parent 9822257 commit 8066c29

File tree

2 files changed

+58
-48
lines changed

2 files changed

+58
-48
lines changed

base/compiler/abstractinterpretation.jl

+37-30
Original file line numberDiff line numberDiff line change
@@ -2887,17 +2887,49 @@ function init_vartable!(vartable::VarTable, frame::InferenceState)
28872887
return vartable
28882888
end
28892889

2890+
function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
2891+
currstate::VarTable, @nospecialize(rt))
2892+
bestguess = frame.bestguess
2893+
nargs = narguments(frame, #=include_va=#false)
2894+
slottypes = frame.slottypes
2895+
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
2896+
# narrow representation of bestguess slightly to prepare for tmerge with rt
2897+
if rt isa InterConditional && bestguess isa Const
2898+
slot_id = rt.slot
2899+
old_id_type = slottypes[slot_id]
2900+
if bestguess.val === true && rt.elsetype !== Bottom
2901+
bestguess = InterConditional(slot_id, old_id_type, Bottom)
2902+
elseif bestguess.val === false && rt.thentype !== Bottom
2903+
bestguess = InterConditional(slot_id, Bottom, old_id_type)
2904+
end
2905+
end
2906+
# copy limitations to return value
2907+
if !isempty(frame.pclimitations)
2908+
union!(frame.limitations, frame.pclimitations)
2909+
empty!(frame.pclimitations)
2910+
end
2911+
if !isempty(frame.limitations)
2912+
rt = LimitedAccuracy(rt, copy(frame.limitations))
2913+
end
2914+
𝕃ₚ = ipo_lattice(interp)
2915+
if !(𝕃ₚ, rt, bestguess)
2916+
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
2917+
frame.bestguess = tmerge(𝕃ₚ, bestguess, rt) # new (wider) return type for frame
2918+
return true
2919+
else
2920+
return false
2921+
end
2922+
end
2923+
28902924
# make as much progress on `frame` as possible (without handling cycles)
28912925
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
28922926
@assert !is_inferred(frame)
28932927
frame.dont_work_on_me = true # mark that this function is currently on the stack
28942928
W = frame.ip
2895-
nargs = narguments(frame, #=include_va=#false)
2896-
slottypes = frame.slottypes
28972929
ssavaluetypes = frame.ssavaluetypes
28982930
bbs = frame.cfg.blocks
28992931
nbbs = length(bbs)
2900-
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
2932+
𝕃ᵢ = typeinf_lattice(interp)
29012933

29022934
currbb = frame.currbb
29032935
if currbb != 1
@@ -2998,35 +3030,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
29983030
end
29993031
end
30003032
elseif isa(stmt, ReturnNode)
3001-
bestguess = frame.bestguess
30023033
rt = abstract_eval_value(interp, stmt.val, currstate, frame)
3003-
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
3004-
# narrow representation of bestguess slightly to prepare for tmerge with rt
3005-
if rt isa InterConditional && bestguess isa Const
3006-
let slot_id = rt.slot
3007-
old_id_type = slottypes[slot_id]
3008-
if bestguess.val === true && rt.elsetype !== Bottom
3009-
bestguess = InterConditional(slot_id, old_id_type, Bottom)
3010-
elseif bestguess.val === false && rt.thentype !== Bottom
3011-
bestguess = InterConditional(slot_id, Bottom, old_id_type)
3012-
end
3013-
end
3014-
end
3015-
# copy limitations to return value
3016-
if !isempty(frame.pclimitations)
3017-
union!(frame.limitations, frame.pclimitations)
3018-
empty!(frame.pclimitations)
3019-
end
3020-
if !isempty(frame.limitations)
3021-
rt = LimitedAccuracy(rt, copy(frame.limitations))
3022-
end
3023-
if !(𝕃ₚ, rt, bestguess)
3024-
# new (wider) return type for frame
3025-
bestguess = tmerge(𝕃ₚ, bestguess, rt)
3026-
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
3027-
frame.bestguess = bestguess
3034+
if update_bestguess!(interp, frame, currstate, rt)
30283035
for (caller, caller_pc) in frame.cycle_backedges
3029-
if !(caller.ssavaluetypes[caller_pc] === Any)
3036+
if caller.ssavaluetypes[caller_pc] !== Any
30303037
# no reason to revisit if that call-site doesn't affect the final result
30313038
push!(caller.ip, block_for_inst(caller.cfg, caller_pc))
30323039
end

base/compiler/typeinfer.jl

+21-18
Original file line numberDiff line numberDiff line change
@@ -870,26 +870,10 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
870870
# since the inliner will request to use it later
871871
cache = :local
872872
else
873+
rt = cached_return_type(code)
873874
effects = ipo_effects(code)
874875
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
875-
rettype = code.rettype
876-
if isdefined(code, :rettype_const)
877-
rettype_const = code.rettype_const
878-
# the second subtyping/egal conditions are necessary to distinguish usual cases
879-
# from rare cases when `Const` wrapped those extended lattice type objects
880-
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
881-
rettype = PartialStruct(rettype, rettype_const)
882-
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
883-
rettype = rettype_const
884-
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
885-
rettype = rettype_const
886-
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
887-
rettype = rettype_const
888-
else
889-
rettype = Const(rettype_const)
890-
end
891-
end
892-
return EdgeCallResult(rettype, mi, effects)
876+
return EdgeCallResult(rt, mi, effects)
893877
end
894878
else
895879
cache = :global # cache edge targets by default
@@ -933,6 +917,25 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
933917
return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame))
934918
end
935919

920+
function cached_return_type(code::CodeInstance)
921+
rettype = code.rettype
922+
isdefined(code, :rettype_const) || return rettype
923+
rettype_const = code.rettype_const
924+
# the second subtyping/egal conditions are necessary to distinguish usual cases
925+
# from rare cases when `Const` wrapped those extended lattice type objects
926+
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
927+
return PartialStruct(rettype, rettype_const)
928+
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
929+
return rettype_const
930+
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
931+
return rettype_const
932+
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
933+
return rettype_const
934+
else
935+
return Const(rettype_const)
936+
end
937+
end
938+
936939
#### entry points for inferring a MethodInstance given a type signature ####
937940

938941
# compute an inferred AST and return type

0 commit comments

Comments
 (0)