Skip to content

Commit

Permalink
Fix issues in jitlayers and aotcompile
Browse files Browse the repository at this point in the history
  • Loading branch information
qinsoon committed Jun 7, 2024
1 parent a223989 commit c129682
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 20 deletions.
40 changes: 30 additions & 10 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ static void makeSafeName(GlobalObject &G)
G.setName(StringRef(SafeName.data(), SafeName.size()));
}

static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance_t *mi, size_t world, jl_code_instance_t **ci_out, jl_code_info_t **src_out)
static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance_t *mi JL_REQUIRE_PIN, size_t world, jl_code_instance_t **ci_out, jl_code_info_t **src_out)
{
++CICacheLookups;
jl_value_t *ci = cgparams.lookup(mi, world, world);
Expand Down Expand Up @@ -273,6 +273,7 @@ void replaceUsesWithLoad(Function &F, function_ref<GlobalVariable *(Instruction
extern "C" JL_DLLEXPORT
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage, size_t _world)
{
PTR_PIN(methods);
++CreateNativeCalls;
CreateNativeMax.updateMax(jl_array_len(methods));
if (cgparams == NULL)
Expand Down Expand Up @@ -320,12 +321,16 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
// each item in this list is either a MethodInstance indicating something
// to compile, or an svec(rettype, sig) describing a C-callable alias to create.
jl_value_t *item = jl_array_ptr_ref(methods, i);
PTR_PIN(item);
if (jl_is_simplevector(item)) {
if (worlds == 1)
jl_compile_extern_c(wrap(&clone), &params, NULL, jl_svecref(item, 0), jl_svecref(item, 1));
PTR_UNPIN(item);
continue;
}
PTR_UNPIN(item);
mi = (jl_method_instance_t*)item;
PTR_PIN(mi);
src = NULL;
// if this method is generally visible to the current compilation world,
// and this is either the primary world, or not applicable in the primary world
Expand All @@ -337,20 +342,24 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
if (src && !emitted.count(codeinst)) {
// now add it to our compilation results
JL_GC_PROMISE_ROOTED(codeinst->rettype);
PTR_PIN(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.imaging,
clone.getModuleUnlocked()->getDataLayout(),
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_code(result_m, mi, src, codeinst->rettype, params);
PTR_UNPIN(codeinst->rettype);
if (result_m)
emitted[codeinst] = {std::move(result_m), std::move(decls)};
}
}
PTR_UNPIN(mi);
}

// finally, make sure all referenced methods also get compiled or fixed up
jl_compile_workqueue(emitted, *clone.getModuleUnlocked(), params, policy);
}
PTR_UNPIN(methods);
JL_UNLOCK(&jl_codegen_lock); // Might GC
JL_GC_POP();

Expand Down Expand Up @@ -1048,31 +1057,38 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
return;
}

jl_method_t *method = mi->def.method;
PTR_PIN(mi);
PTR_PIN(method);
// get the source code for this function
jl_value_t *jlrettype = (jl_value_t*)jl_any_type;
jl_code_info_t *src = NULL;
JL_GC_PUSH2(&src, &jlrettype);
if (jl_is_method(mi->def.method) && mi->def.method->source != NULL && jl_ir_flag_inferred((jl_array_t*)mi->def.method->source)) {
src = (jl_code_info_t*)mi->def.method->source;
if (jl_is_method(method) && method->source != NULL && jl_ir_flag_inferred((jl_array_t*)method->source)) {
src = (jl_code_info_t*)method->source;
if (src && !jl_is_code_info(src))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
src = jl_uncompress_ir(method, NULL, (jl_array_t*)src);
} else {
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
if (ci != jl_nothing) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(method)) {
PTR_PIN(codeinst);
src = jl_uncompress_ir(method, codeinst, (jl_array_t*)src);
PTR_UNPIN(codeinst);
}
jlrettype = codeinst->rettype;

}
if (!src || (jl_value_t*)src == jl_nothing) {
src = jl_type_infer(mi, world, 0);
if (src)
jlrettype = src->rettype;
else if (jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
else if (jl_is_method(method)) {
src = method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)method->source;
if (src && !jl_is_code_info(src) && jl_is_method(method))
src = jl_uncompress_ir(method, NULL, (jl_array_t*)src);
}
// TODO: use mi->uninferred
}
Expand Down Expand Up @@ -1132,10 +1148,14 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
if (F) {
dump->TSM = wrap(new orc::ThreadSafeModule(std::move(m)));
dump->F = wrap(F);
PTR_UNPIN(mi);
PTR_UNPIN(method);
return;
}
}

const char *mname = name_from_method_instance(mi);
PTR_UNPIN(mi);
PTR_UNPIN(method);
jl_errorf("unable to compile source for function %s", mname);
}
62 changes: 52 additions & 10 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ static jl_callptr_t _jl_compile_codeinst(
size_t world,
orc::ThreadSafeContext context)
{
PTR_PIN(codeinst);
PTR_PIN(src);
// caller must hold codegen_lock
// and have disabled finalizers
uint64_t start_time = 0;
Expand Down Expand Up @@ -246,6 +248,7 @@ static jl_callptr_t _jl_compile_codeinst(
IndirectCodeinsts += emitted.size() - 1;
}
JL_TIMING(LLVM_MODULE_FINISH);
PTR_UNPIN(src);

for (auto &def : emitted) {
jl_code_instance_t *this_code = def.first;
Expand Down Expand Up @@ -307,6 +310,7 @@ static jl_callptr_t _jl_compile_codeinst(
jl_printf(stream, "\"\n");
}
}
PTR_UNPIN(codeinst);
return fptr;
}

Expand All @@ -316,6 +320,8 @@ const char *jl_generate_ccallable(LLVMOrcThreadSafeModuleRef llvmmod, void *sysi
extern "C" JL_DLLEXPORT
int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
{
PTR_PIN(declrt);
PTR_PIN(sigt);
auto ct = jl_current_task;
ct->reentrant_timing++;
uint64_t compiler_start_time = 0;
Expand All @@ -339,6 +345,8 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
pparams = &params;
assert(pparams->tsctx.getContext() == into->getContext().getContext());
const char *name = jl_generate_ccallable(wrap(into), sysimg, declrt, sigt, *pparams);
PTR_UNPIN(declrt);
PTR_UNPIN(sigt);
bool success = true;
if (!sysimg) {
if (jl_ExecutionEngine->getGlobalValueAddress(name)) {
Expand Down Expand Up @@ -368,13 +376,18 @@ int jl_compile_extern_c_impl(LLVMOrcThreadSafeModuleRef llvmmod, void *p, void *
extern "C" JL_DLLEXPORT
void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
{
PTR_PIN(declrt);
PTR_PIN(sigt);
jl_svec_t *params = ((jl_datatype_t*)(sigt))->parameters;
PTR_PIN(params);
// validate arguments. try to do as many checks as possible here to avoid
// throwing errors later during codegen.
JL_TYPECHK(@ccallable, type, declrt);
if (!jl_is_tuple_type(sigt))
jl_type_error("@ccallable", (jl_value_t*)jl_anytuple_type_type, (jl_value_t*)sigt);
// check that f is a guaranteed singleton type
jl_datatype_t *ft = (jl_datatype_t*)jl_tparam0(sigt);
PTR_PIN(ft);
if (!jl_is_datatype(ft) || ft->instance == NULL)
jl_error("@ccallable: function object must be a singleton");

Expand All @@ -385,12 +398,13 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
jl_error("@ccallable: return type doesn't correspond to a C type");

// validate method signature
size_t i, nargs = jl_nparams(sigt);
size_t i, nargs = jl_svec_len(params);
for (i = 1; i < nargs; i++) {
jl_value_t *ati = jl_tparam(sigt, i);
jl_value_t *ati = jl_svecref(params, i);
if (!jl_is_concrete_type(ati) || jl_is_kind(ati) || !jl_type_mappable_to_c(ati))
jl_error("@ccallable: argument types must be concrete");
}
PTR_UNPIN(params);

// save a record of this so that the alias is generated when we write an object file
jl_method_t *meth = (jl_method_t*)jl_methtable_lookup(ft->name->mt, (jl_value_t*)sigt, jl_atomic_load_acquire(&jl_world_counter));
Expand All @@ -403,6 +417,9 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)

// create the alias in the current runtime environment
int success = jl_compile_extern_c(NULL, NULL, NULL, declrt, (jl_value_t*)sigt);
PTR_UNPIN(declrt);
PTR_UNPIN(sigt);
PTR_UNPIN(ft);
if (!success)
jl_error("@ccallable was already defined for this method name");
}
Expand All @@ -411,6 +428,9 @@ void jl_extern_c_impl(jl_value_t *declrt, jl_tupletype_t *sigt)
extern "C" JL_DLLEXPORT
jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world)
{
PTR_PIN(mi);
jl_method_t * method = mi->def.method;
PTR_PIN(method);
auto ct = jl_current_task;
ct->reentrant_timing++;
uint64_t compiler_start_time = 0;
Expand All @@ -425,19 +445,20 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
jl_code_instance_t *codeinst = (ci == jl_nothing ? NULL : (jl_code_instance_t*)ci);
if (codeinst) {
PTR_PIN(codeinst);
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
if ((jl_value_t*)src == jl_nothing)
src = NULL;
else if (jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
else if (jl_is_method(method))
src = jl_uncompress_ir(method, codeinst, (jl_array_t*)src);
}
else {
// identify whether this is an invalidated method that is being recompiled
is_recompile = jl_atomic_load_relaxed(&mi->cache) != NULL;
}
if (src == NULL && jl_is_method(mi->def.method) &&
jl_symbol_name(mi->def.method->name)[0] != '@') {
if (mi->def.method->source != jl_nothing) {
if (src == NULL && jl_is_method(method) &&
jl_symbol_name(method->name)[0] != '@') {
if (method->source != jl_nothing) {
// If the caller didn't provide the source and IR is available,
// see if it is inferred, or try to infer it for ourself.
// (but don't bother with typeinf on macros or toplevel thunks)
Expand All @@ -446,22 +467,28 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
}
jl_code_instance_t *compiled = jl_method_compiled(mi, world);
if (compiled) {
if (codeinst) PTR_UNPIN(codeinst);
codeinst = compiled;
PTR_PIN(codeinst);
}
else if (src && jl_is_code_info(src)) {
if (!codeinst) {
codeinst = jl_get_method_inferred(mi, src->rettype, src->min_world, src->max_world);
PTR_PIN(codeinst);
if (src->inferred) {
jl_value_t *null = nullptr;
jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
}
}
++SpecFPtrCount;
_jl_compile_codeinst(codeinst, src, world, *jl_ExecutionEngine->getContext());
if (jl_atomic_load_relaxed(&codeinst->invoke) == NULL)
if (jl_atomic_load_relaxed(&codeinst->invoke) == NULL) {
if (codeinst) PTR_UNPIN(codeinst);
codeinst = NULL;
}
}
else {
if (codeinst) PTR_UNPIN(codeinst);
codeinst = NULL;
}
JL_UNLOCK(&jl_codegen_lock);
Expand All @@ -473,6 +500,9 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, t_comp);
}
JL_GC_POP();
PTR_UNPIN(mi);
PTR_UNPIN(method);
if(codeinst) PTR_UNPIN(codeinst);
return codeinst;
}

Expand All @@ -482,6 +512,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
if (jl_atomic_load_relaxed(&unspec->invoke) != NULL) {
return;
}
PTR_PIN(unspec);
auto ct = jl_current_task;
ct->reentrant_timing++;
uint64_t compiler_start_time = 0;
Expand All @@ -494,6 +525,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
JL_GC_PUSH1(&src);
jl_method_t *def = unspec->def->def.method;
if (jl_is_method(def)) {
PTR_PIN(def);
src = (jl_code_info_t*)def->source;
if (src == NULL) {
// TODO: this is wrong
Expand All @@ -503,6 +535,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
}
if (src && (jl_value_t*)src != jl_nothing)
src = jl_uncompress_ir(def, NULL, (jl_array_t*)src);
PTR_UNPIN(def);
}
else {
src = (jl_code_info_t*)unspec->def->uninferred;
Expand All @@ -515,6 +548,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
jl_atomic_cmpswap(&unspec->invoke, &null, jl_fptr_interpret_call_addr);
JL_GC_POP();
}
PTR_UNPIN(unspec);
JL_UNLOCK(&jl_codegen_lock); // Might GC
if (!--ct->reentrant_timing && measure_compile_time_enabled) {
auto end = jl_hrtime();
Expand All @@ -528,12 +562,15 @@ extern "C" JL_DLLEXPORT
jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
char raw_mc, char getwrapper, const char* asm_variant, const char *debuginfo, char binary)
{
PTR_PIN(mi);
// printing via disassembly
jl_code_instance_t *codeinst = jl_generate_fptr(mi, world);
if (codeinst) {
uintptr_t fptr = (uintptr_t)jl_atomic_load_acquire(&codeinst->invoke);
if (getwrapper)
if (getwrapper) {
PTR_UNPIN(mi);
return jl_dump_fptr_asm(fptr, raw_mc, asm_variant, debuginfo, binary);
}
uintptr_t specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (fptr == (uintptr_t)jl_fptr_const_return_addr && specfptr == 0) {
// normally we prevent native code from being generated for these functions,
Expand All @@ -545,6 +582,7 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
compiler_start_time = jl_hrtime();
PTR_PIN(codeinst);
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (specfptr == 0) {
Expand All @@ -569,19 +607,23 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
}
JL_GC_POP();
}
PTR_UNPIN(codeinst);
JL_UNLOCK(&jl_codegen_lock);
if (!--ct->reentrant_timing && measure_compile_time_enabled) {
auto end = jl_hrtime();
jl_atomic_fetch_add_relaxed(&jl_cumulative_compile_time, end - compiler_start_time);
}
}
if (specfptr != 0)
if (specfptr != 0) {
PTR_UNPIN(mi);
return jl_dump_fptr_asm(specfptr, raw_mc, asm_variant, debuginfo, binary);
}
}

// whatever, that didn't work - use the assembler output instead
jl_llvmf_dump_t llvmf_dump;
jl_get_llvmf_defn(&llvmf_dump, mi, world, getwrapper, true, jl_default_cgparams);
PTR_UNPIN(mi);
if (!llvmf_dump.F)
return jl_an_empty_string;
return jl_dump_function_asm(&llvmf_dump, raw_mc, asm_variant, debuginfo, binary);
Expand Down

0 comments on commit c129682

Please sign in to comment.