Skip to content

Commit

Permalink
Extremely WIP/RFC: Extend invoke to accept CodeInstance
Browse files Browse the repository at this point in the history
This is an alternative mechanism to #56650 that largely achieves
the same result, but by hooking into `invoke` rather than a generated
function. They are orthogonal mechanisms, and its possible we want both.
However, in #56650, both Jameson and Valentin were skeptical of the
generated function signature bottleneck. This PR is sort of a hybrid
of mechanism in #52964 and what I proposed in #56650 (comment).

In particular, this PR:

1. Extends `invoke` to support a CodeInstance in place of its usual
   `types` argument.

2. Adds a new `typeinf` optimized generic. The semantics of this optimized
   generic allow the compiler to instead call a companion `typeinf_edge`
   function, allowing a mid-inference interpreter switch (like #52964),
   without being forced through a concrete signature bottleneck. However,
   if calling `typeinf_edge` does not work (e.g. because the compiler
   version is mismatched), this still has well defined semantics, you
   just don't get inference support.

The additional benefit of the `typeinf` optimized generic is that it lets
custom cache owners tell the runtime how to "cure" code instances that
have lost their native code. Currently the runtime only knows how to
do that for `owner == nothing` CodeInstances (by re-running inference).
This extension is not implemented, but the idea is that the runtime would
be permitted to call the `typeinf` optimized generic on the dead
CodeInstance's `owner` and `def` fields to obtain a cured CodeInstance (or
a user-actionable error from the plugin).

This PR includes an implementation of `with_new_compiler` from #56650.
This PR includes just enough compiler support to make the compiler
optimize this to the same code that #56650 produced:

```
julia> @code_typed with_new_compiler(sin, 1.0)
CodeInfo(
1 ─      $(Expr(:foreigncall, :(:jl_get_tls_world_age), UInt64, svec(), 0, :(:ccall)))::UInt64
│   %2 =   builtin Core.getfield(args, 1)::Float64
│   %3 =    invoke sin(%2::Float64)::Float64
└──      return %3
) => Float64
```

However, the implementation here is extremely incomplete. I'm putting
it up only as a directional sketch to see if people prefer it over #56650.
If so, I would prepare a cleaned up version of this PR that has the
optimized generics as well as the curing support, but not the full
inference integration (which needs a fair bit more work).
  • Loading branch information
Keno committed Nov 23, 2024
1 parent 1bf2ef9 commit 6521b62
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 7 deletions.
15 changes: 15 additions & 0 deletions Compiler/extras/CompilerDevTools/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.12.0-DEV"
manifest_format = "2.0"
project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5"

[[deps.Compiler]]
path = "../.."
uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
version = "0.0.2"

[[deps.CompilerDevTools]]
path = "."
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
version = "0.0.0"
5 changes: 5 additions & 0 deletions Compiler/extras/CompilerDevTools/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name = "CompilerDevTools"
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"

[deps]
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
45 changes: 45 additions & 0 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
__precompile__(false)
module CompilerDevTools

using Compiler
using Core.IR

include(joinpath(dirname(pathof(Compiler)), "..", "test", "newinterp.jl"))

@newinterp SplitCacheInterp
struct SplitCacheOwner; end

import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge

Compiler.cache_owner(::SplitCacheInterp) = SplitCacheOwner()
let typeinf_world_age = Base.tls_world_age()
@eval @noinline typeinf(::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
Base.invoke_in_world($typeinf_world_age, Compiler.typeinf_ext, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)

@eval @noinline function typeinf_edge(::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
# TODO: This isn't quite right, we're just sketching things for now
interp = SplitCacheInterp(; world)
Compiler.typeinf_edge(interp, mi.def, mi.specTypes, Core.svec(), parent_frame, false, false)
end
end

# TODO: This needs special compiler support to properly case split for multiple
# method matches, etc. This annotation is not sound, but just for demo purpoes.
@Base.assume_effects :total @noinline function mi_for_tt(tt, world=Base.tls_world_age())
interp = SplitCacheInterp(; world)
match, _ = Compiler.findsup(tt, Compiler.method_table(interp))
Base.specialize_method(match)
end

function with_new_compiler(f, args...)
tt = Base.signature_type(f, typeof(args))
world = Base.tls_world_age()
new_compiler_ci = Core.OptimizedGenerics.CompilerPlugins.typeinf(
SplitCacheOwner(), mi_for_tt(tt), Compiler.SOURCE_MODE_ABI
)
invoke(f, new_compiler_ci, args...)
end

export with_new_compiler

end
25 changes: 25 additions & 0 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2218,6 +2218,15 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return Future(CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo()))
typesarg = argtype_by_index(argtypes, 3)
if hasintersect(widenconst(typesarg), CodeInstance)
if isa(typesarg, Const) && isa(typesarg.val, CodeInstance)
ci = typesarg.val
return Future(CallMeta(ci.rettype, ci.exctype,
decode_effects(ci.ipo_purity_bits), NoCallInfo()))
end
return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
end
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
isexact || return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
unwrapped = unwrap_unionall(types)
Expand Down Expand Up @@ -2670,6 +2679,22 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
end
elseif is_return_type(f)
return return_type_tfunc(interp, argtypes, si, sv)
elseif f === Core.OptimizedGenerics.CompilerPlugins.typeinf
mresult = try
invokelatest(Core.OptimizedGenerics.CompilerPlugins.typeinf_edge,
(argtypes[2]::Const).val,
(argtypes[3]::Const).val,
sv,
get_inference_world(interp),
(argtypes[4]::Const).val)
catch
return Future(CallMeta(Any, Any, EFFECTS_UNKNOWN, NoCallInfo()))
end
return Future{CallMeta}(mresult, interp, sv) do mresult, interp, sv
update_valid_age!(sv, WorldRange(mresult.edge.min_world, mresult.edge.max_world))
# TODO: `Const` isn't right here - we need a special lattice element that treats the IPO and non-IPO bits separately
return CallMeta(Const(mresult.edge), nothing, EFFECTS_TOTAL, MethodResultPure())
end
elseif la == 3 && f === Core.:(!==)
# mark !== as exactly a negated call to ===
let callfuture = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods)::Future,
Expand Down
10 changes: 9 additions & 1 deletion Compiler/src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
# especially try to make sure any recursive and leaf functions have concrete signatures,
# since we won't be able to specialize & infer them at runtime

activate_codegen!() = ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
function activate_codegen!()
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
Core.eval(Compiler, quote
let typeinf_world_age = Base.tls_world_age()
@eval Core.OptimizedGenerics.CompilerPlugins.typeinf(::Nothing, mi::MethodInstance, source_mode::UInt8) =
Base.invoke_in_world($(Expr(:$, :typeinf_world_age)), typeinf_ext_toplevel, mi, Base.tls_world_age(), source_mode)
end
end)
end

function bootstrap!()
let time() = ccall(:jl_clock_now, Float64, ())
Expand Down
2 changes: 2 additions & 0 deletions Compiler/src/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,8 @@ function early_inline_special_case(ir::IRCode, stmt::Expr, flag::UInt32,
elseif (optimizer_lattice(state.interp), cond, Bool) && stmt.args[3] === stmt.args[4]
return SomeCase(stmt.args[3])
end
elseif f === Core.invoke && isa(argtypes[3], Const) && isa(argtypes[3].val, CodeInstance)
return SomeCase(Expr(:invoke, argtypes[3].val, stmt.args[2], stmt.args[4:end]...))
end
return nothing
end
Expand Down
27 changes: 27 additions & 0 deletions base/optimized_generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,31 @@ module KeyValue
function get end
end

# Compiler-recognized intrinsics for compiler plugins
"""
module CompilerPlugins
Implements a pair of functions `typeinf`/`typeinf_edge`. When the optimizer sees
a call to `typeinf`, it has license to instead call `typeinf_edge`, supplying the
current inference stack in `parent_frame` (but otherwise supplying the arguments
to `typeinf`). typeinf_edge will return the `CodeInstance` that `typeinf` would
have returned at runtime. The optimizer may perform a non-IPO replacement of
the call to `typeinf` by the result of `typeinf_edge`. In addition, the IPO-safe
fields of the `CodeInstance` may be propagated in IPO mode.
"""
module CompilerPlugins
"""
typeinf(owner, mi, source_mode)::CodeInstance
Return a `CodeInstance` for the given `mi` whose valid results include at
the least current tls world and satisfies the requirements of `source_mode`.
"""
function typeinf end

"""
typeinf_edge(owner, mi, parent_frame, world, abi_mode)::CodeInstance
"""
function typeinf_edge end
end

end
35 changes: 31 additions & 4 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1578,11 +1578,38 @@ JL_CALLABLE(jl_f_invoke)
JL_NARGSV(invoke, 2);
jl_value_t *argtypes = args[1];
JL_GC_PUSH1(&argtypes);
if (!jl_is_tuple_type(jl_unwrap_unionall(args[1])))
jl_value_t *res = NULL;
if (jl_is_tuple_type(jl_unwrap_unionall(args[1]))) {
if (!jl_tuple_isa(&args[2], nargs - 2, (jl_datatype_t*)argtypes))
jl_type_error("invoke: argument type error", argtypes, jl_f_tuple(NULL, &args[2], nargs - 2));
res = jl_gf_invoke(argtypes, args[0], &args[2], nargs - 1);
}
else if (jl_is_code_instance(args[1])) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)args[1];
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
if (jl_atomic_load_relaxed(&codeinst->min_world) > jl_current_task->world_age ||
jl_current_task->world_age > jl_atomic_load_relaxed(&codeinst->max_world)) {
jl_error("invoke: CodeInstance not valid for this world");
}
if (jl_tuple1_isa(args[0], nargs == 2 ? NULL : &args[2], nargs - 2, (jl_datatype_t*)codeinst->def->specTypes)) {
jl_type_error("invoke: argument type error", codeinst->def->specTypes, jl_f_tuple(args[0], &args[2], nargs - 2));
}
if (!invoke) {
jl_compile_codeinst(codeinst);
invoke = jl_atomic_load_acquire(&codeinst->invoke);
}
if (invoke) {
res = invoke(args[0], nargs == 2 ? NULL : &args[2], nargs - 2, codeinst);
} else {
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
res = jl_invoke(args[0], nargs == 2 ? NULL : &args[2], nargs - 2, codeinst->def);
}
}
else {
jl_type_error("invoke", (jl_value_t*)jl_anytuple_type_type, args[1]);
if (!jl_tuple_isa(&args[2], nargs - 2, (jl_datatype_t*)argtypes))
jl_type_error("invoke: argument type error", argtypes, jl_f_tuple(NULL, &args[2], nargs - 2));
jl_value_t *res = jl_gf_invoke(argtypes, args[0], &args[2], nargs - 1);
}
JL_GC_POP();
return res;
}
Expand Down
24 changes: 22 additions & 2 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state
argv[i-1] = eval_value(args[i], s);
jl_value_t *c = args[0];
assert(jl_is_code_instance(c) || jl_is_method_instance(c));
jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def;
jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth);
jl_value_t *result = NULL;
if (jl_is_code_instance(c)) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)c;
assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age &&
jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world));
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
if (!invoke) {
jl_compile_codeinst(codeinst);
invoke = jl_atomic_load_acquire(&codeinst->invoke);
}
if (invoke) {
result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst);

} else {
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def);
}
} else {
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c);
}
JL_GC_POP();
return result;
}
Expand Down

0 comments on commit 6521b62

Please sign in to comment.