Skip to content

Commit

Permalink
do union split and concrete compilation search
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Nov 7, 2024
1 parent 8593792 commit 79bea13
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 57 deletions.
8 changes: 7 additions & 1 deletion src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -3188,6 +3188,12 @@ JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tuplet
}
}

JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *env, size_t world)
{
jl_method_instance_t *mi = jl_specializations_get_linfo(m, types, env);
jl_compile_method_instance(mi, NULL, world);
}

JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
{
size_t world = jl_atomic_load_acquire(&jl_world_counter);
Expand All @@ -3197,7 +3203,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
if (mi == NULL)
return 0;
JL_GC_PROMISE_ROOTED(mi);
jl_compile_method_instance(mi, types, world);
jl_compile_method_instance(mi, NULL, world);
return 1;
}

Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ JL_DLLEXPORT jl_module_t *jl_debuginfo_module1(jl_value_t *debuginfo_def) JL_NOT
JL_DLLEXPORT const char *jl_debuginfo_name(jl_value_t *func) JL_NOTSAFEPOINT;

JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tupletype_t *types, size_t world);
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *sparams, size_t world);
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types);
JL_DLLEXPORT int jl_add_entrypoint(jl_tupletype_t *types);
jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *lam JL_PROPAGATES_ROOT, size_t world);
Expand Down
125 changes: 69 additions & 56 deletions src/precompile_utils.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// f{<:Union{...}}(...) is a common pattern
// and expanding the Union may give a leaf function
static void _compile_all_tvar_union(jl_value_t *methsig)
// This file is a part of Julia. License is MIT: https://julialang.org/license

// f(...) where {T<:Union{...}} is a common pattern
// and expanding the Union may give some leaf functions
static int _compile_all_tvar_union(jl_value_t *methsig)
{
int tvarslen = jl_subtype_env_size(methsig);
jl_value_t *sigbody = methsig;
Expand All @@ -13,86 +15,93 @@ static void _compile_all_tvar_union(jl_value_t *methsig)
assert(jl_is_unionall(sigbody));
idx[i] = 0;
env[2 * i] = (jl_value_t*)((jl_unionall_t*)sigbody)->var;
env[2 * i + 1] = jl_bottom_type; // initialize the list with Union{}, since T<:Union{} is always a valid option
jl_value_t *tv = env[2 * i];
while (jl_is_typevar(tv))
tv = ((jl_tvar_t*)tv)->ub;
if (jl_is_abstracttype(tv) && !jl_is_type_type(tv)) {
JL_GC_POP();
return 0; // Any as TypeVar is common and not useful here to try to analyze further
}
env[2 * i + 1] = tv;
sigbody = ((jl_unionall_t*)sigbody)->body;
}

for (i = 0; i < tvarslen; /* incremented by inner loop */) {
int all = 1;
int incr = 0;
while (!incr) {
for (i = 0, incr = 1; i < tvarslen; i++) {
jl_value_t *tv = env[2 * i];
while (jl_is_typevar(tv))
tv = ((jl_tvar_t*)tv)->ub;
if (jl_is_uniontype(tv)) {
size_t l = jl_count_union_components(tv);
size_t j = idx[i];
env[2 * i + 1] = jl_nth_union_component(tv, j);
++j;
if (incr) {
if (j == l) {
idx[i] = 0;
}
else {
idx[i] = j;
incr = 0;
}
}
}
}
jl_value_t **sig = &roots[0];
JL_TRY {
// TODO: wrap in UnionAll for each tvar in env[2*i + 1] ?
// currently doesn't matter much, since jl_compile_hint doesn't work on abstract types
*sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
}
JL_CATCH {
goto getnext; // sigh, we found an invalid type signature. should we warn the user?
}
if (!jl_has_concrete_subtype(*sig))
goto getnext; // signature wouldn't be callable / is invalid -- skip it
if (jl_is_concrete_type(*sig)) {
if (jl_compile_hint((jl_tupletype_t *)*sig))
goto getnext; // success
*sig = NULL;
}

getnext:
for (i = 0; i < tvarslen; i++) {
jl_tvar_t *tv = (jl_tvar_t*)env[2 * i];
if (jl_is_uniontype(tv->ub)) {
size_t l = jl_count_union_components(tv->ub);
size_t j = idx[i];
if (j == l) {
env[2 * i + 1] = jl_bottom_type;
idx[i] = 0;
}
else {
jl_value_t *ty = jl_nth_union_component(tv->ub, j);
if (!jl_is_concrete_type(ty))
ty = (jl_value_t*)jl_new_typevar(tv->name, tv->lb, ty);
env[2 * i + 1] = ty;
idx[i] = j + 1;
break;
}
}
else {
env[2 * i + 1] = (jl_value_t*)tv;
}
if (*sig) {
if (jl_is_datatype(*sig) && jl_has_concrete_subtype(*sig))
all = all && jl_compile_hint((jl_tupletype_t*)*sig);
else
all = 0;
}
}
JL_GC_POP();
return all;
}

// f(::Union{...}, ...) is a common pattern
// and expanding the Union may give a leaf function
static void _compile_all_union(jl_value_t *sig)
static int _compile_all_union(jl_value_t *sig)
{
jl_tupletype_t *sigbody = (jl_tupletype_t*)jl_unwrap_unionall(sig);
size_t count_unions = 0;
size_t union_size = 1;
size_t i, l = jl_svec_len(sigbody->parameters);
jl_svec_t *p = NULL;
jl_value_t *methsig = NULL;

for (i = 0; i < l; i++) {
jl_value_t *ty = jl_svecref(sigbody->parameters, i);
if (jl_is_uniontype(ty))
++count_unions;
else if (ty == jl_bottom_type)
return; // why does this method exist?
else if (jl_is_datatype(ty) && !jl_has_free_typevars(ty) &&
((!jl_is_kind(ty) && ((jl_datatype_t*)ty)->isconcretetype) ||
((jl_datatype_t*)ty)->name == jl_type_typename))
return; // no amount of union splitting will make this a leaftype signature
if (jl_is_uniontype(ty)) {
count_unions += 1;
union_size *= jl_count_union_components(ty);
}
else if (jl_is_datatype(ty) &&
((!((jl_datatype_t*)ty)->isconcretetype || jl_is_kind(ty)) &&
((jl_datatype_t*)ty)->name != jl_type_typename))
return 0; // no amount of union splitting will make this a dispatch signature
}

if (count_unions == 0 || count_unions >= 6) {
_compile_all_tvar_union(sig);
return;
if (union_size <= 1 || union_size > 8) {
return _compile_all_tvar_union(sig);
}

int *idx = (int*)alloca(sizeof(int) * count_unions);
for (i = 0; i < count_unions; i++) {
idx[i] = 0;
}

int all = 1;
JL_GC_PUSH2(&p, &methsig);
int idx_ctr = 0, incr = 0;
while (!incr) {
Expand Down Expand Up @@ -122,10 +131,12 @@ static void _compile_all_union(jl_value_t *sig)
}
methsig = jl_apply_tuple_type(p, 1);
methsig = jl_rewrap_unionall(methsig, sig);
_compile_all_tvar_union(methsig);
if (!_compile_all_tvar_union(methsig))
all = 0;
}

JL_GC_POP();
return all;
}

static int compile_all_collect__(jl_typemap_entry_t *ml, void *env)
Expand All @@ -147,29 +158,32 @@ static int compile_all_collect_(jl_methtable_t *mt, void *env)
return 1;
}

static void jl_compile_all_defs(jl_array_t *mis)
static void jl_compile_all_defs(jl_array_t *mis, int all)
{
jl_array_t *allmeths = jl_alloc_vec_any(0);
JL_GC_PUSH1(&allmeths);

jl_foreach_reachable_mtable(compile_all_collect_, allmeths);

size_t world = jl_atomic_load_acquire(&jl_world_counter);
size_t i, l = jl_array_nrows(allmeths);
for (i = 0; i < l; i++) {
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
// method has a single compilable specialization, e.g. its definition
// signature is concrete. in this case we can just hint it.
jl_compile_hint((jl_tupletype_t*)m->sig);
jl_compile_method_sig(m, m->sig, jl_emptysvec, world);
}
else {
// first try to create leaf signatures from the signature declaration and compile those
_compile_all_union(m->sig);

// finally, compile a fully generic fallback that can work for all arguments
jl_method_instance_t *unspec = jl_get_unspecialized(m);
if (unspec)
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
if (all) {
// finally, compile a fully generic fallback that can work for all arguments (even invoke)
jl_method_instance_t *unspec = jl_get_unspecialized(m);
if (unspec)
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
}
}
}

Expand Down Expand Up @@ -273,8 +287,7 @@ static void *jl_precompile(int all)
// array of MethodInstances and ccallable aliases to include in the output
jl_array_t *m = jl_alloc_vec_any(0);
JL_GC_PUSH1(&m);
if (all)
jl_compile_all_defs(m);
jl_compile_all_defs(m, all);
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
void *native_code = jl_precompile_(m, 0);
JL_GC_POP();
Expand Down

0 comments on commit 79bea13

Please sign in to comment.