Skip to content

Commit

Permalink
infer more completely everything that the optimizer/codegen requires (#…
Browse files Browse the repository at this point in the history
…56565)

Inlining wants to know information about every isa_compileable_sig
method as well as everything it might consider inlining (which is almost
the same thing). So even if inference could bail on computing the type
since it already reached the maximum fixed point, it should keep going
to get that information. This now uses two loops here now: one to
compute the inference types information, then a second loop go back and
get coverage of all of the compileable targets (unless that particular
target is predicted to be inlined or dropped later).

(system image size contribution seems to be fairly negligible)
  • Loading branch information
vtjnash authored Nov 15, 2024
1 parent 5ec3215 commit caa2f7d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 79 deletions.
148 changes: 84 additions & 64 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,24 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

(; valid_worlds, applicable, info) = matches
update_valid_age!(sv, valid_worlds)
update_valid_age!(sv, valid_worlds) # need to record the negative world now, since even if we don't generate any useful information, inlining might want to add an invoke edge and it won't have this information anymore
if bail_out_toplevel_call(interp, sv)
napplicable = length(applicable)
for i = 1:napplicable
sig = applicable[i].match.spec_types
if !isdispatchtuple(sig)
# only infer fully concrete call sites in top-level expressions (ignoring even isa_compileable_sig matches)
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
end
end
end

# final result
gfresult = Future{CallMeta}()
# intermediate work for computing gfresult
rettype = exctype = Bottom
conditionals = nothing # keeps refinement information of call argument types when the return type is boolean
seenall = true
const_results = nothing # or const_results::Vector{Union{Nothing,ConstResult}} if any const results are available
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
Expand All @@ -69,16 +79,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
f = Core.Box(f)
atype = Core.Box(atype)
function infercalls(interp, sv)
napplicable = length(applicable)
multiple_matches = napplicable > 1
local napplicable = length(applicable)
local multiple_matches = napplicable > 1
while i <= napplicable
(; match, edges, edge_idx) = applicable[i]
method = match.method
sig = match.spec_types
if bail_out_toplevel_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
# only infer concrete call sites in top-level expressions
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
seenall = false
if bail_out_call(interp, InferenceLoopState(rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information: bailing on doing more abstract inference.")
break
end
# TODO: this is unmaintained now as it didn't seem to improve things, though it does avoid hard-coding the union split at the higher level,
Expand Down Expand Up @@ -162,17 +170,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
Any[Bottom for _ in 1:length(argtypes)]
end
for i = 1:length(argtypes)
cnd = conditional_argtype(𝕃ᵢ, this_conditional, sig, argtypes, i)
cnd = conditional_argtype(𝕃ᵢ, this_conditional, match.spec_types, argtypes, i)
conditionals[1][i] = conditionals[1][i] ᵢ cnd.thentype
conditionals[2][i] = conditionals[2][i] ᵢ cnd.elsetype
end
end
edges[edge_idx] = edge
if i < napplicable && bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
seenall = false
i = napplicable # break in outer function
end

i += 1
return true
end # function handle1
Expand All @@ -184,12 +188,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end # while

if const_results !== nothing
@assert napplicable == nmatches(info) == length(const_results)
info = ConstCallInfo(info, const_results)
end

if seenall
seenall = i > napplicable
if seenall # small optimization to skip some work that is already implied
if const_results !== nothing
@assert napplicable == nmatches(info) == length(const_results)
info = ConstCallInfo(info, const_results)
end
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
Expand All @@ -198,51 +202,67 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if sv isa InferenceState && fargs !== nothing
slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
end
rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals)
if call_result_unused(si) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
# but we ourselves locally don't typically care about it locally
# (beyond checking if it always throws).
# So avoid adding an edge, since we don't want to bother attempting
# to improve our result even if it does change (to always throw),
# and avoid keeping track of a more complex result type.
rettype = Any
end
# if from_interprocedural added any pclimitations to the set inherited from the arguments,
# some of those may be part of our cycles, so those can be deleted now
# TODO: and those might need to be deleted later too if the cycle grows to include them?
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
# irinterp doesn't fail into unresolved cycles, but it's not a good solution.
# We should revisit this once we have a better story for handling cycles in irinterp.
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in callers_in_cycle(sv)
delete!(sv.pclimitations, caller)
end
end
end
else
# there is unanalyzed candidate, widen type and effects to the top
rettype = exctype = Any
all_effects = Effects()
const_results = nothing
end

rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals)

# Also considering inferring the compilation signature for this method, so
# it is available to the compiler, unless it should not end up needing it (for an invoke).
if (isa(sv, InferenceState) && infer_compilation_signature(interp) &&
(seenall && 1 == napplicable) && (!is_removable_if_unused(all_effects) || !call_result_unused(si)))
(; match) = applicable[1]
method = match.method
sig = match.spec_types
mi = specialize_method(match; preexisting=true)
if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
csig = get_compileable_sig(method, sig, match.sparams)
if csig !== nothing && csig !== sig
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)::Future
end
end
end

if call_result_unused(si) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
# but we ourselves locally don't typically care about it locally
# (beyond checking if it always throws).
# So avoid adding an edge, since we don't want to bother attempting
# to improve our result even if it does change (to always throw),
# and avoid keeping track of a more complex result type.
rettype = Any
end
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
# irinterp doesn't fail into unresolved cycles, but it's not a good solution.
# We should revisit this once we have a better story for handling cycles in irinterp.
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in callers_in_cycle(sv)
delete!(sv.pclimitations, caller)
# it is available to the compiler in case it ends up needing it for the invoke.
if isa(sv, InferenceState) && infer_compilation_signature(interp) && (!is_removable_if_unused(all_effects) || !call_result_unused(si))
i = 1
function infercalls2(interp, sv)
local napplicable = length(applicable)
local multiple_matches = napplicable > 1
while i <= napplicable
(; match, edges, edge_idx) = applicable[i]
i += 1
method = match.method
sig = match.spec_types
mi = specialize_method(match; preexisting=true)
if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
csig = get_compileable_sig(method, sig, match.sparams)
if csig !== nothing && (!seenall || csig !== sig) # corresponds to whether the first look already looked at this, so repeating abstract_call_method is not useful
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), csig, method.sig)::SimpleVector
if match.sparams === sp_[2]
mresult = abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)::Future
isready(mresult) || return false # wait for mresult Future to resolve off the callstack before continuing
end
end
end
end
return true
end
# start making progress on the first call
infercalls2(interp, sv) || push!(sv.tasks, infercalls2)
end

gfresult[] = CallMeta(rettype, exctype, all_effects, info, slotrefinements)
Expand Down Expand Up @@ -1787,6 +1807,14 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
i = 1
while i <= length(ctypes)
ct = ctypes[i]
if bail_out_apply(interp, InferenceLoopState(res, all_effects), sv)
add_remark!(interp, sv, "_apply_iterate inference reached maximally imprecise information: bailing on analysis of more methods.")
# there is unanalyzed candidate, widen type and effects to the top
let retinfo = NoCallInfo() # NOTE this is necessary to prevent the inlining processing
applyresult[] = CallMeta(Any, Any, Effects(), retinfo)
return true
end
end
lct = length(ct)
# truncate argument list at the first Vararg
for k = 1:lct-1
Expand All @@ -1808,14 +1836,6 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
res = tmerge(typeinf_lattice(interp), res, rt)
exctype = tmerge(typeinf_lattice(interp), exctype, exct)
all_effects = merge_effects(all_effects, effects)
if i < length(ctypes) && bail_out_apply(interp, InferenceLoopState(ctypes[i], res, all_effects), sv)
add_remark!(interp, sv, "_apply_iterate inference reached maximally imprecise information. Bailing on.")
# there is unanalyzed candidate, widen type and effects to the top
let retinfo = NoCallInfo() # NOTE this is necessary to prevent the inlining processing
applyresult[] = CallMeta(Any, Any, Effects(), retinfo)
return true
end
end
end
i += 1
end
Expand Down
10 changes: 4 additions & 6 deletions Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,17 +1032,15 @@ decode_statement_effects_override(sv::AbsIntState) =
decode_statement_effects_override(get_curr_ssaflag(sv))

struct InferenceLoopState
sig
rt
effects::Effects
function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects)
new(sig, rt, effects)
function InferenceLoopState(@nospecialize(rt), effects::Effects)
new(rt, effects)
end
end

bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState) =
sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState) = false
bail_out_toplevel_call(::AbstractInterpreter, sv::InferenceState) = sv.restrict_abstract_call_sites
bail_out_toplevel_call(::AbstractInterpreter, ::IRInterpretationState) = false

bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) =
state.rt === Any && !is_foldable(state.effects)
Expand Down
13 changes: 5 additions & 8 deletions Compiler/test/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,15 @@ end |> !Core.Compiler.is_nonoverlayed

# account for overlay possibility in unanalyzed matching method
callstrange(::Float64) = strangesin(x)
callstrange(::Nothing) = Core.compilerbarrier(:type, nothing) # trigger inference bail out
callstrange(::Number) = Core.compilerbarrier(:type, nothing) # trigger inference bail out
callstrange(::Any) = 1.0
callstrange_entry(x) = callstrange(x) # needs to be defined here because of world age
let interp = MTOverlayInterp(Set{Any}())
matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp))
@test matches !== nothing
@test Core.Compiler.length(matches) == 2
if Core.Compiler.getindex(matches, 1).method == which(callstrange, (Nothing,))
@test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed
@test "Call inference reached maximally imprecise information. Bailing on." in interp.meta
else
@warn "`nonoverlayed` test for inference bailing out is skipped since the method match sort order is changed."
end
@test Core.Compiler.length(matches) == 3
@test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed
@test "Call inference reached maximally imprecise information: bailing on doing more abstract inference." in interp.meta
end

# but it should never apply for the native compilation
Expand Down
2 changes: 1 addition & 1 deletion Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4114,7 +4114,7 @@ callsig_backprop_any(::Any) = nothing
callsig_backprop_lhs(::Int) = nothing
callsig_backprop_bailout(::Val{0}) = 0
callsig_backprop_bailout(::Val{1}) = undefvar # undefvar::Any triggers `bail_out_call`
callsig_backprop_bailout(::Val{2}) = 2
callsig_backprop_bailout(::Val) = 2
callsig_backprop_addinteger(a::Integer, b::Integer) = a + b # results in too many matching methods and triggers `bail_out_call`)
@test Base.infer_return_type(callsig_backprop_addinteger) == Any
let effects = Base.infer_effects(callsig_backprop_addinteger)
Expand Down

0 comments on commit caa2f7d

Please sign in to comment.