From caa2f7d52b430f50c8038a7f6766edba28a3fb65 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Fri, 15 Nov 2024 11:44:35 -0500 Subject: [PATCH] infer more completely everything that the optimizer/codegen requires (#56565) Inlining wants to know information about every isa_compileable_sig method as well as everything it might consider inlining (which is almost the same thing). So even if inference could bail on computing the type since it already reached the maximum fixed point, it should keep going to get that information. This now uses two loops here: one to compute the inference type information, then a second loop to go back and get coverage of all of the compileable targets (unless that particular target is predicted to be inlined or dropped later). (system image size contribution seems to be fairly negligible) --- Compiler/src/abstractinterpretation.jl | 148 ++++++++++++++----------- Compiler/src/inferencestate.jl | 10 +- Compiler/test/AbstractInterpreter.jl | 13 +-- Compiler/test/inference.jl | 2 +- 4 files changed, 94 insertions(+), 79 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 093c5889f809e..f98b9336d97c0 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -51,14 +51,24 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end (; valid_worlds, applicable, info) = matches - update_valid_age!(sv, valid_worlds) + update_valid_age!(sv, valid_worlds) # need to record the negative world now, since even if we don't generate any useful information, inlining might want to add an invoke edge and it won't have this information anymore + if bail_out_toplevel_call(interp, sv) + napplicable = length(applicable) + for i = 1:napplicable + sig = applicable[i].match.spec_types + if !isdispatchtuple(sig) + # only infer fully concrete call sites in top-level expressions (ignoring even isa_compileable_sig matches) + 
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression") + return Future(CallMeta(Any, Any, Effects(), NoCallInfo())) + end + end + end # final result gfresult = Future{CallMeta}() # intermediate work for computing gfresult rettype = exctype = Bottom conditionals = nothing # keeps refinement information of call argument types when the return type is boolean - seenall = true const_results = nothing # or const_results::Vector{Union{Nothing,ConstResult}} if any const results are available fargs = arginfo.fargs all_effects = EFFECTS_TOTAL @@ -69,16 +79,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), f = Core.Box(f) atype = Core.Box(atype) function infercalls(interp, sv) - napplicable = length(applicable) - multiple_matches = napplicable > 1 + local napplicable = length(applicable) + local multiple_matches = napplicable > 1 while i <= napplicable (; match, edges, edge_idx) = applicable[i] method = match.method sig = match.spec_types - if bail_out_toplevel_call(interp, InferenceLoopState(sig, rettype, all_effects), sv) - # only infer concrete call sites in top-level expressions - add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression") - seenall = false + if bail_out_call(interp, InferenceLoopState(rettype, all_effects), sv) + add_remark!(interp, sv, "Call inference reached maximally imprecise information: bailing on doing more abstract inference.") break end # TODO: this is unmaintained now as it didn't seem to improve things, though it does avoid hard-coding the union split at the higher level, @@ -162,17 +170,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), Any[Bottom for _ in 1:length(argtypes)] end for i = 1:length(argtypes) - cnd = conditional_argtype(𝕃ᵢ, this_conditional, sig, argtypes, i) + cnd = conditional_argtype(𝕃ᵢ, this_conditional, match.spec_types, argtypes, i) conditionals[1][i] = conditionals[1][i] ⊔ᵢ 
cnd.thentype conditionals[2][i] = conditionals[2][i] ⊔ᵢ cnd.elsetype end end edges[edge_idx] = edge - if i < napplicable && bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv) - add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.") - seenall = false - i = napplicable # break in outer function - end + i += 1 return true end # function handle1 @@ -184,12 +188,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end end # while - if const_results !== nothing - @assert napplicable == nmatches(info) == length(const_results) - info = ConstCallInfo(info, const_results) - end - - if seenall + seenall = i > napplicable + if seenall # small optimization to skip some work that is already implied + if const_results !== nothing + @assert napplicable == nmatches(info) == length(const_results) + info = ConstCallInfo(info, const_results) + end if !fully_covering(matches) || any_ambig(matches) # Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature. all_effects = Effects(all_effects; nothrow=false) @@ -198,51 +202,67 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if sv isa InferenceState && fargs !== nothing slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv) end + rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals) + if call_result_unused(si) && !(rettype === Bottom) + add_remark!(interp, sv, "Call result type was widened because the return value is unused") + # We're mainly only here because the optimizer might want this code, + # but we ourselves locally don't typically care about it locally + # (beyond checking if it always throws). + # So avoid adding an edge, since we don't want to bother attempting + # to improve our result even if it does change (to always throw), + # and avoid keeping track of a more complex result type. 
+ rettype = Any + end + # if from_interprocedural added any pclimitations to the set inherited from the arguments, + # some of those may be part of our cycles, so those can be deleted now + # TODO: and those might need to be deleted later too if the cycle grows to include them? + if isa(sv, InferenceState) + # TODO (#48913) implement a proper recursion handling for irinterp: + # This works just because currently the `:terminate` condition guarantees that + # irinterp doesn't fail into unresolved cycles, but it's not a good solution. + # We should revisit this once we have a better story for handling cycles in irinterp. + if !isempty(sv.pclimitations) # remove self, if present + delete!(sv.pclimitations, sv) + for caller in callers_in_cycle(sv) + delete!(sv.pclimitations, caller) + end + end + end else # there is unanalyzed candidate, widen type and effects to the top rettype = exctype = Any all_effects = Effects() + const_results = nothing end - rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals) - # Also considering inferring the compilation signature for this method, so - # it is available to the compiler, unless it should not end up needing it (for an invoke). 
- if (isa(sv, InferenceState) && infer_compilation_signature(interp) && - (seenall && 1 == napplicable) && (!is_removable_if_unused(all_effects) || !call_result_unused(si))) - (; match) = applicable[1] - method = match.method - sig = match.spec_types - mi = specialize_method(match; preexisting=true) - if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv) - csig = get_compileable_sig(method, sig, match.sparams) - if csig !== nothing && csig !== sig - abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)::Future - end - end - end - - if call_result_unused(si) && !(rettype === Bottom) - add_remark!(interp, sv, "Call result type was widened because the return value is unused") - # We're mainly only here because the optimizer might want this code, - # but we ourselves locally don't typically care about it locally - # (beyond checking if it always throws). - # So avoid adding an edge, since we don't want to bother attempting - # to improve our result even if it does change (to always throw), - # and avoid keeping track of a more complex result type. - rettype = Any - end - if isa(sv, InferenceState) - # TODO (#48913) implement a proper recursion handling for irinterp: - # This works just because currently the `:terminate` condition guarantees that - # irinterp doesn't fail into unresolved cycles, but it's not a good solution. - # We should revisit this once we have a better story for handling cycles in irinterp. - if !isempty(sv.pclimitations) # remove self, if present - delete!(sv.pclimitations, sv) - for caller in callers_in_cycle(sv) - delete!(sv.pclimitations, caller) + # it is available to the compiler in case it ends up needing it for the invoke. 
+ if isa(sv, InferenceState) && infer_compilation_signature(interp) && (!is_removable_if_unused(all_effects) || !call_result_unused(si)) + i = 1 + function infercalls2(interp, sv) + local napplicable = length(applicable) + local multiple_matches = napplicable > 1 + while i <= napplicable + (; match, edges, edge_idx) = applicable[i] + i += 1 + method = match.method + sig = match.spec_types + mi = specialize_method(match; preexisting=true) + if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv) + csig = get_compileable_sig(method, sig, match.sparams) + if csig !== nothing && (!seenall || csig !== sig) # corresponds to whether the first look already looked at this, so repeating abstract_call_method is not useful + sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), csig, method.sig)::SimpleVector + if match.sparams === sp_[2] + mresult = abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)::Future + isready(mresult) || return false # wait for mresult Future to resolve off the callstack before continuing + end + end + end end + return true end + # start making progress on the first call + infercalls2(interp, sv) || push!(sv.tasks, infercalls2) end gfresult[] = CallMeta(rettype, exctype, all_effects, info, slotrefinements) @@ -1787,6 +1807,14 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: i = 1 while i <= length(ctypes) ct = ctypes[i] + if bail_out_apply(interp, InferenceLoopState(res, all_effects), sv) + add_remark!(interp, sv, "_apply_iterate inference reached maximally imprecise information: bailing on analysis of more methods.") + # there is unanalyzed candidate, widen type and effects to the top + let retinfo = NoCallInfo() # NOTE this is necessary to prevent the inlining processing + applyresult[] = CallMeta(Any, Any, Effects(), retinfo) + return true + end + end lct = length(ct) # truncate argument list at the first Vararg for k = 1:lct-1 @@ 
-1808,14 +1836,6 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: res = tmerge(typeinf_lattice(interp), res, rt) exctype = tmerge(typeinf_lattice(interp), exctype, exct) all_effects = merge_effects(all_effects, effects) - if i < length(ctypes) && bail_out_apply(interp, InferenceLoopState(ctypes[i], res, all_effects), sv) - add_remark!(interp, sv, "_apply_iterate inference reached maximally imprecise information. Bailing on.") - # there is unanalyzed candidate, widen type and effects to the top - let retinfo = NoCallInfo() # NOTE this is necessary to prevent the inlining processing - applyresult[] = CallMeta(Any, Any, Effects(), retinfo) - return true - end - end end i += 1 end diff --git a/Compiler/src/inferencestate.jl b/Compiler/src/inferencestate.jl index fd421af733943..0ba37888b34d5 100644 --- a/Compiler/src/inferencestate.jl +++ b/Compiler/src/inferencestate.jl @@ -1032,17 +1032,15 @@ decode_statement_effects_override(sv::AbsIntState) = decode_statement_effects_override(get_curr_ssaflag(sv)) struct InferenceLoopState - sig rt effects::Effects - function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects) - new(sig, rt, effects) + function InferenceLoopState(@nospecialize(rt), effects::Effects) + new(rt, effects) end end -bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState) = - sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig) -bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState) = false +bail_out_toplevel_call(::AbstractInterpreter, sv::InferenceState) = sv.restrict_abstract_call_sites +bail_out_toplevel_call(::AbstractInterpreter, ::IRInterpretationState) = false bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) = state.rt === Any && !is_foldable(state.effects) diff --git a/Compiler/test/AbstractInterpreter.jl b/Compiler/test/AbstractInterpreter.jl index 
a49647ad4ea43..1939f4a19c05f 100644 --- a/Compiler/test/AbstractInterpreter.jl +++ b/Compiler/test/AbstractInterpreter.jl @@ -70,18 +70,15 @@ end |> !Core.Compiler.is_nonoverlayed # account for overlay possibility in unanalyzed matching method callstrange(::Float64) = strangesin(x) -callstrange(::Nothing) = Core.compilerbarrier(:type, nothing) # trigger inference bail out +callstrange(::Number) = Core.compilerbarrier(:type, nothing) # trigger inference bail out +callstrange(::Any) = 1.0 callstrange_entry(x) = callstrange(x) # needs to be defined here because of world age let interp = MTOverlayInterp(Set{Any}()) matches = Core.Compiler.findall(Tuple{typeof(callstrange),Any}, Core.Compiler.method_table(interp)) @test matches !== nothing - @test Core.Compiler.length(matches) == 2 - if Core.Compiler.getindex(matches, 1).method == which(callstrange, (Nothing,)) - @test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed - @test "Call inference reached maximally imprecise information. Bailing on." in interp.meta - else - @warn "`nonoverlayed` test for inference bailing out is skipped since the method match sort order is changed." - end + @test Core.Compiler.length(matches) == 3 + @test Base.infer_effects(callstrange_entry, (Any,); interp) |> !Core.Compiler.is_nonoverlayed + @test "Call inference reached maximally imprecise information: bailing on doing more abstract inference." 
in interp.meta end # but it should never apply for the native compilation diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index 8a14774e2404f..e6bbf05caeabe 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4114,7 +4114,7 @@ callsig_backprop_any(::Any) = nothing callsig_backprop_lhs(::Int) = nothing callsig_backprop_bailout(::Val{0}) = 0 callsig_backprop_bailout(::Val{1}) = undefvar # undefvar::Any triggers `bail_out_call` -callsig_backprop_bailout(::Val{2}) = 2 +callsig_backprop_bailout(::Val) = 2 callsig_backprop_addinteger(a::Integer, b::Integer) = a + b # results in too many matching methods and triggers `bail_out_call`) @test Base.infer_return_type(callsig_backprop_addinteger) == Any let effects = Base.infer_effects(callsig_backprop_addinteger)