Skip to content

Commit

Permalink
precompile: do better union split and concrete compilation search (#5…
Browse files Browse the repository at this point in the history
…6496)

This fixes some bugs that prevent compile-all from working correctly at
all, and uses more of it for normal compile. Increases sysimg size from
about 140 to 170 MB of data and 11 to 15 MB of code
  • Loading branch information
vtjnash authored Nov 13, 2024
2 parents 072d9d1 + 882f940 commit aa05c98
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 90 deletions.
64 changes: 34 additions & 30 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
// Const returns do not do codegen, but juliac inspects codegen results so make a dummy fvar entry to represent it
if (jl_options.trim != JL_TRIM_NO && jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr) {
data->jl_fvar_map[codeinst] = std::make_tuple((uint32_t)-3, (uint32_t)-3);
} else {
}
else {
JL_GC_PROMISE_ROOTED(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
Expand Down Expand Up @@ -609,6 +610,9 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
else if (func == "jl_fptr_sparam") {
func_id = -2;
}
else if (decls.functionObject == "jl_f_opaque_closure_call") {
func_id = -4;
}
else {
//Safe b/c context is locked by params
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(func)));
Expand Down Expand Up @@ -896,19 +900,18 @@ struct Partition {
size_t weight;
};

static bool canPartition(const GlobalValue &G) {
if (auto F = dyn_cast<Function>(&G)) {
if (F->hasFnAttribute(Attribute::AlwaysInline))
return false;
}
return true;
static bool canPartition(const Function &F)
{
return !F.hasFnAttribute(Attribute::AlwaysInline);
}

static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partitions, const Module &M, size_t fvars_size, size_t gvars_size) {
static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partitions, const Module &M, DenseMap<GlobalValue *, unsigned> &fvars, DenseMap<GlobalValue *, unsigned> &gvars) {
bool bad = false;
#ifndef JL_NDEBUG
SmallVector<uint32_t, 0> fvars(fvars_size);
SmallVector<uint32_t, 0> gvars(gvars_size);
size_t fvars_size = fvars.size();
size_t gvars_size = gvars.size();
SmallVector<uint32_t, 0> fvars_partition(fvars_size);
SmallVector<uint32_t, 0> gvars_partition(gvars_size);
StringMap<uint32_t> GVNames;
for (uint32_t i = 0; i < partitions.size(); i++) {
for (auto &name : partitions[i].globals) {
Expand All @@ -919,18 +922,18 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
GVNames[name.getKey()] = i;
}
for (auto &fvar : partitions[i].fvars) {
if (fvars[fvar.second] != 0) {
if (fvars_partition[fvar.second] != 0) {
bad = true;
dbgs() << "Duplicate fvar " << fvar.first() << " in partitions " << i << " and " << fvars[fvar.second] - 1 << "\n";
dbgs() << "Duplicate fvar " << fvar.first() << " in partitions " << i << " and " << fvars_partition[fvar.second] - 1 << "\n";
}
fvars[fvar.second] = i+1;
fvars_partition[fvar.second] = i+1;
}
for (auto &gvar : partitions[i].gvars) {
if (gvars[gvar.second] != 0) {
if (gvars_partition[gvar.second] != 0) {
bad = true;
dbgs() << "Duplicate gvar " << gvar.first() << " in partitions " << i << " and " << gvars[gvar.second] - 1 << "\n";
dbgs() << "Duplicate gvar " << gvar.first() << " in partitions " << i << " and " << gvars_partition[gvar.second] - 1 << "\n";
}
gvars[gvar.second] = i+1;
gvars_partition[gvar.second] = i+1;
}
}
for (auto &GV : M.global_values()) {
Expand All @@ -941,13 +944,6 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
}
} else {
// Local global values are not partitioned
if (!canPartition(GV)) {
if (GVNames.count(GV.getName())) {
bad = true;
dbgs() << "Shouldn't have partitioned " << GV.getName() << ", but is in partition " << GVNames[GV.getName()] << "\n";
}
continue;
}
if (!GVNames.count(GV.getName())) {
bad = true;
dbgs() << "Global " << GV << " not in any partition\n";
Expand All @@ -967,13 +963,14 @@ static inline bool verify_partitioning(const SmallVectorImpl<Partition> &partiti
}
}
for (uint32_t i = 0; i < fvars_size; i++) {
if (fvars[i] == 0) {
if (fvars_partition[i] == 0) {
auto gv = find_if(fvars.begin(), fvars.end(), [i](auto var) { return var.second == i; });
bad = true;
dbgs() << "fvar " << i << " not in any partition\n";
dbgs() << "fvar " << gv->first->getName() << " at " << i << " not in any partition\n";
}
}
for (uint32_t i = 0; i < gvars_size; i++) {
if (gvars[i] == 0) {
if (gvars_partition[i] == 0) {
bad = true;
dbgs() << "gvar " << i << " not in any partition\n";
}
Expand Down Expand Up @@ -1035,8 +1032,6 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
for (auto &G : M.global_values()) {
if (G.isDeclaration())
continue;
if (!canPartition(G))
continue;
// Currently ccallable global aliases have extern linkage, we only want to make the
// internally linked functions/global variables extern+hidden
if (G.hasLocalLinkage()) {
Expand All @@ -1045,7 +1040,8 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
}
if (auto F = dyn_cast<Function>(&G)) {
partitioner.make(&G, getFunctionWeight(*F).weight);
} else {
}
else {
partitioner.make(&G, 1);
}
}
Expand Down Expand Up @@ -1117,7 +1113,9 @@ static SmallVector<Partition, 32> partitionModule(Module &M, unsigned threads) {
}
}

bool verified = verify_partitioning(partitions, M, fvars.size(), gvars.size());
bool verified = verify_partitioning(partitions, M, fvars, gvars);
if (!verified)
M.dump();
assert(verified && "Partitioning failed to partition globals correctly");
(void) verified;

Expand Down Expand Up @@ -1371,6 +1369,12 @@ static void materializePreserved(Module &M, Partition &partition) {
continue;
if (Preserve.contains(&F))
continue;
if (!canPartition(F)) {
F.setLinkage(GlobalValue::AvailableExternallyLinkage);
F.setVisibility(GlobalValue::HiddenVisibility);
F.setDSOLocal(true);
continue;
}
F.deleteBody();
F.setLinkage(GlobalValue::ExternalLinkage);
F.setVisibility(GlobalValue::HiddenVisibility);
Expand Down
8 changes: 7 additions & 1 deletion src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -3188,6 +3188,12 @@ JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tuplet
}
}

JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *env, size_t world)
{
jl_method_instance_t *mi = jl_specializations_get_linfo(m, types, env);
jl_compile_method_instance(mi, NULL, world);
}

JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
{
size_t world = jl_atomic_load_acquire(&jl_world_counter);
Expand All @@ -3197,7 +3203,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
if (mi == NULL)
return 0;
JL_GC_PROMISE_ROOTED(mi);
jl_compile_method_instance(mi, types, world);
jl_compile_method_instance(mi, NULL, world);
return 1;
}

Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ JL_DLLEXPORT jl_module_t *jl_debuginfo_module1(jl_value_t *debuginfo_def) JL_NOT
JL_DLLEXPORT const char *jl_debuginfo_name(jl_value_t *func) JL_NOTSAFEPOINT;

JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tupletype_t *types, size_t world);
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *sparams, size_t world);
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types);
JL_DLLEXPORT int jl_add_entrypoint(jl_tupletype_t *types);
jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *lam JL_PROPAGATES_ROOT, size_t world);
Expand Down
130 changes: 72 additions & 58 deletions src/precompile_utils.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// f{<:Union{...}}(...) is a common pattern
// and expanding the Union may give a leaf function
static void _compile_all_tvar_union(jl_value_t *methsig)
// This file is a part of Julia. License is MIT: https://julialang.org/license

// f(...) where {T<:Union{...}} is a common pattern
// and expanding the Union may give some leaf functions
static int _compile_all_tvar_union(jl_value_t *methsig)
{
int tvarslen = jl_subtype_env_size(methsig);
jl_value_t *sigbody = methsig;
Expand All @@ -13,86 +15,94 @@ static void _compile_all_tvar_union(jl_value_t *methsig)
assert(jl_is_unionall(sigbody));
idx[i] = 0;
env[2 * i] = (jl_value_t*)((jl_unionall_t*)sigbody)->var;
env[2 * i + 1] = jl_bottom_type; // initialize the list with Union{}, since T<:Union{} is always a valid option
jl_value_t *tv = env[2 * i];
while (jl_is_typevar(tv))
tv = ((jl_tvar_t*)tv)->ub;
if (jl_is_abstracttype(tv) && !jl_is_type_type(tv)) {
JL_GC_POP();
return 0; // Any as TypeVar is common and not useful here to try to analyze further
}
env[2 * i + 1] = tv;
sigbody = ((jl_unionall_t*)sigbody)->body;
}

for (i = 0; i < tvarslen; /* incremented by inner loop */) {
jl_value_t **sig = &roots[0];
int all = 1;
int incr = 0;
while (!incr) {
for (i = 0, incr = 1; i < tvarslen; i++) {
jl_value_t *tv = env[2 * i];
while (jl_is_typevar(tv))
tv = ((jl_tvar_t*)tv)->ub;
if (jl_is_uniontype(tv)) {
size_t l = jl_count_union_components(tv);
size_t j = idx[i];
env[2 * i + 1] = jl_nth_union_component(tv, j);
++j;
if (incr) {
if (j == l) {
idx[i] = 0;
}
else {
idx[i] = j;
incr = 0;
}
}
}
}
jl_value_t *sig = NULL;
JL_TRY {
// TODO: wrap in UnionAll for each tvar in env[2*i + 1] ?
// currently doesn't matter much, since jl_compile_hint doesn't work on abstract types
*sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
}
JL_CATCH {
goto getnext; // sigh, we found an invalid type signature. should we warn the user?
}
if (!jl_has_concrete_subtype(*sig))
goto getnext; // signature wouldn't be callable / is invalid -- skip it
if (jl_is_concrete_type(*sig)) {
if (jl_compile_hint((jl_tupletype_t *)*sig))
goto getnext; // success
sig = NULL;
}

getnext:
for (i = 0; i < tvarslen; i++) {
jl_tvar_t *tv = (jl_tvar_t*)env[2 * i];
if (jl_is_uniontype(tv->ub)) {
size_t l = jl_count_union_components(tv->ub);
size_t j = idx[i];
if (j == l) {
env[2 * i + 1] = jl_bottom_type;
idx[i] = 0;
}
else {
jl_value_t *ty = jl_nth_union_component(tv->ub, j);
if (!jl_is_concrete_type(ty))
ty = (jl_value_t*)jl_new_typevar(tv->name, tv->lb, ty);
env[2 * i + 1] = ty;
idx[i] = j + 1;
break;
}
}
else {
env[2 * i + 1] = (jl_value_t*)tv;
}
if (sig) {
roots[0] = sig;
if (jl_is_datatype(sig) && jl_has_concrete_subtype(sig))
all = all && jl_compile_hint((jl_tupletype_t*)sig);
else
all = 0;
}
}
JL_GC_POP();
return all;
}

// f(::Union{...}, ...) is a common pattern
// and expanding the Union may give a leaf function
static void _compile_all_union(jl_value_t *sig)
static int _compile_all_union(jl_value_t *sig)
{
jl_tupletype_t *sigbody = (jl_tupletype_t*)jl_unwrap_unionall(sig);
size_t count_unions = 0;
size_t union_size = 1;
size_t i, l = jl_svec_len(sigbody->parameters);
jl_svec_t *p = NULL;
jl_value_t *methsig = NULL;

for (i = 0; i < l; i++) {
jl_value_t *ty = jl_svecref(sigbody->parameters, i);
if (jl_is_uniontype(ty))
++count_unions;
else if (ty == jl_bottom_type)
return; // why does this method exist?
else if (jl_is_datatype(ty) && !jl_has_free_typevars(ty) &&
((!jl_is_kind(ty) && ((jl_datatype_t*)ty)->isconcretetype) ||
((jl_datatype_t*)ty)->name == jl_type_typename))
return; // no amount of union splitting will make this a leaftype signature
if (jl_is_uniontype(ty)) {
count_unions += 1;
union_size *= jl_count_union_components(ty);
}
else if (jl_is_datatype(ty) &&
((!((jl_datatype_t*)ty)->isconcretetype || jl_is_kind(ty)) &&
((jl_datatype_t*)ty)->name != jl_type_typename))
return 0; // no amount of union splitting will make this a dispatch signature
}

if (count_unions == 0 || count_unions >= 6) {
_compile_all_tvar_union(sig);
return;
if (union_size <= 1 || union_size > 8) {
return _compile_all_tvar_union(sig);
}

int *idx = (int*)alloca(sizeof(int) * count_unions);
for (i = 0; i < count_unions; i++) {
idx[i] = 0;
}

int all = 1;
JL_GC_PUSH2(&p, &methsig);
int idx_ctr = 0, incr = 0;
while (!incr) {
Expand Down Expand Up @@ -122,10 +132,12 @@ static void _compile_all_union(jl_value_t *sig)
}
methsig = jl_apply_tuple_type(p, 1);
methsig = jl_rewrap_unionall(methsig, sig);
_compile_all_tvar_union(methsig);
if (!_compile_all_tvar_union(methsig))
all = 0;
}

JL_GC_POP();
return all;
}

static int compile_all_collect__(jl_typemap_entry_t *ml, void *env)
Expand All @@ -147,29 +159,32 @@ static int compile_all_collect_(jl_methtable_t *mt, void *env)
return 1;
}

static void jl_compile_all_defs(jl_array_t *mis)
static void jl_compile_all_defs(jl_array_t *mis, int all)
{
jl_array_t *allmeths = jl_alloc_vec_any(0);
JL_GC_PUSH1(&allmeths);

jl_foreach_reachable_mtable(compile_all_collect_, allmeths);

size_t world = jl_atomic_load_acquire(&jl_world_counter);
size_t i, l = jl_array_nrows(allmeths);
for (i = 0; i < l; i++) {
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
// method has a single compilable specialization, e.g. its definition
// signature is concrete. in this case we can just hint it.
jl_compile_hint((jl_tupletype_t*)m->sig);
jl_compile_method_sig(m, m->sig, jl_emptysvec, world);
}
else {
// first try to create leaf signatures from the signature declaration and compile those
_compile_all_union(m->sig);

// finally, compile a fully generic fallback that can work for all arguments
jl_method_instance_t *unspec = jl_get_unspecialized(m);
if (unspec)
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
if (all) {
// finally, compile a fully generic fallback that can work for all arguments (even invoke)
jl_method_instance_t *unspec = jl_get_unspecialized(m);
if (unspec)
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
}
}
}

Expand Down Expand Up @@ -273,8 +288,7 @@ static void *jl_precompile(int all)
// array of MethodInstances and ccallable aliases to include in the output
jl_array_t *m = jl_alloc_vec_any(0);
JL_GC_PUSH1(&m);
if (all)
jl_compile_all_defs(m);
jl_compile_all_defs(m, all);
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
void *native_code = jl_precompile_(m, 0);
JL_GC_POP();
Expand Down
Loading

2 comments on commit aa05c98

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The package evaluation job you requested has completed - possible new issues were detected.
The full report is available.

Please sign in to comment.