Skip to content

Commit

Permalink
WIP: Add ABIOverwrite type for def field
Browse files Browse the repository at this point in the history
Together with #54899, this PR is intending to replicate the functionality
of #54373, which allowed particular specializations to have a different
ABI signature than what would be suggested by the MethodInstance's
`specTypes` field. This PR handles that by adding a special `ABIOverwrite`
type, which, when placed in the `def` field of a `CodeInstance`
instructs the system to use the given signature instead.
  • Loading branch information
Keno committed Dec 12, 2024
1 parent aa05c98 commit 31ee1b4
Show file tree
Hide file tree
Showing 20 changed files with 115 additions and 73 deletions.
6 changes: 3 additions & 3 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function _add_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, mi_edge::Boo
mi = specialize_method(m) # don't allow `Method`-edge for this optimized format
edge = mi
else
mi = edge.def
mi = edge.def::MethodInstance
end
if mi.specTypes === m.spec_types
add_one_edge!(edges, edge)
Expand Down Expand Up @@ -100,7 +100,7 @@ end
function add_one_edge!(edges::Vector{Any}, edge::MethodInstance)
for i in 1:length(edges)
edgeᵢ = edges[i]
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def)
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def::MethodInstance)
edgeᵢ isa MethodInstance || continue
if edgeᵢ === edge && !(i > 1 && edges[i-1] isa Type)
return # found existing covered edge
Expand All @@ -112,7 +112,7 @@ end
function add_one_edge!(edges::Vector{Any}, edge::CodeInstance)
for i in 1:length(edges)
edgeᵢ_orig = edgeᵢ = edges[i]
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def)
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def::MethodInstance)
edgeᵢ isa MethodInstance || continue
if edgeᵢ === edge.def && !(i > 1 && edges[i-1] isa Type)
if edgeᵢ_orig isa MethodInstance
Expand Down
6 changes: 6 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ struct InitError <: WrappedException
error
end

struct ABIOverwrite
abi::Type
def::MethodInstance
ABIOverwrite(@nospecialize(abi::Type), def::MethodInstance) = new(abi, def)
end

struct PrecompilableError <: Exception end

String(s::String) = s # no constructor yet
Expand Down
8 changes: 7 additions & 1 deletion base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,13 @@ end
show(io::IO, mi::Core.MethodInstance) = show_mi(io, mi)
function show(io::IO, codeinst::Core.CodeInstance)
print(io, "CodeInstance for ")
show_mi(io, codeinst.def)
def = codeinst.def
if isa(def, Core.ABIOverwrite)
show_mi(io, def.def)
print(io, " (ABI Overriden)")
else
show_mi(io, def::MethodInstance)
end
end

function show_mi(io::IO, mi::Core.MethodInstance, from_stackframe::Bool=false)
Expand Down
12 changes: 6 additions & 6 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void jl_get_llvm_mis_impl(void *native_code, arraylist_t* MIs)
jl_native_code_desc_t *data = (jl_native_code_desc_t*)native_code;
auto map = data->jl_fvar_map;
for (auto &ci : map) {
jl_method_instance_t *mi = ci.first->def;
jl_method_instance_t *mi = jl_get_ci_mi(ci.first);
arraylist_push(MIs, mi);
}
}
Expand Down Expand Up @@ -344,11 +344,11 @@ static void compile_workqueue(jl_codegen_params_t &params, CompilationPolicy pol
if ((policy != CompilationPolicy::Default || params.params->trim) &&
jl_atomic_load_relaxed(&codeinst->inferred) == jl_nothing) {
// XXX: SOURCE_MODE_FORCE_SOURCE is wrong here (neither sufficient nor necessary)
codeinst = jl_type_infer(codeinst->def, jl_atomic_load_relaxed(&codeinst->max_world), SOURCE_MODE_FORCE_SOURCE);
codeinst = jl_type_infer(jl_get_ci_mi(codeinst), jl_atomic_load_relaxed(&codeinst->max_world), SOURCE_MODE_FORCE_SOURCE);
}
if (codeinst) {
orc::ThreadSafeModule result_m =
jl_create_ts_module(name_from_method_instance(codeinst->def),
jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
params.tsctx, params.DL, params.TargetTriple);
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
if (result_m)
Expand Down Expand Up @@ -389,7 +389,7 @@ static void compile_workqueue(jl_codegen_params_t &params, CompilationPolicy pol
proto.decl->setLinkage(GlobalVariable::InternalLinkage);
//protodecl->setAlwaysInline();
jl_init_function(proto.decl, params.TargetTriple);
jl_method_instance_t *mi = codeinst->def;
jl_method_instance_t *mi = jl_get_ci_mi(codeinst);
size_t nrealargs = jl_nparams(mi->specTypes); // number of actual arguments being passed
bool is_opaque_closure = jl_is_method(mi->def.value) && mi->def.method->is_for_opaque_closure;
// TODO: maybe this can be cached in codeinst->specfptr?
Expand Down Expand Up @@ -527,7 +527,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
}
else {
JL_GC_PROMISE_ROOTED(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
Expand Down Expand Up @@ -2158,7 +2158,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
// This would also be nice, but it seems to cause OOMs on the windows32 builder
// To get correct names in the IR this needs to be at least 2
output.debug_level = params.debug_info_level;
auto decls = jl_emit_code(m, mi, src, output);
auto decls = jl_emit_code(m, mi, src, NULL, output);

Function *F = NULL;
if (m) {
Expand Down
69 changes: 40 additions & 29 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3046,9 +3046,8 @@ static bool uses_specsig(jl_value_t *sig, bool needsparams, jl_value_t *rettype,
return false; // jlcall sig won't require any box allocations
}

static std::pair<bool, bool> uses_specsig(jl_method_instance_t *lam, jl_value_t *rettype, bool prefer_specsig)
static std::pair<bool, bool> uses_specsig(jl_value_t *abi, jl_method_instance_t *lam, jl_value_t *rettype, bool prefer_specsig)
{
jl_value_t *sig = lam->specTypes;
bool needsparams = false;
if (jl_is_method(lam->def.method)) {
if ((size_t)jl_subtype_env_size(lam->def.method->sig) != jl_svec_len(lam->sparam_vals))
Expand All @@ -3058,7 +3057,7 @@ static std::pair<bool, bool> uses_specsig(jl_method_instance_t *lam, jl_value_t
needsparams = true;
}
}
return std::make_pair(uses_specsig(sig, needsparams, rettype, prefer_specsig), needsparams);
return std::make_pair(uses_specsig(abi, needsparams, rettype, prefer_specsig), needsparams);
}


Expand Down Expand Up @@ -4373,6 +4372,7 @@ static jl_llvm_functions_t
orc::ThreadSafeModule &TSM,
jl_method_instance_t *lam,
jl_code_info_t *src,
jl_value_t *abi,
jl_value_t *rettype,
jl_codegen_params_t &params);

Expand Down Expand Up @@ -5490,6 +5490,13 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_t *rt)
return emit_invoke(ctx, lival, argv, nargs, rt);
}

static jl_value_t *get_ci_abi(jl_code_instance_t *ci)
{
if (jl_typeof(ci->def) == (jl_value_t*)jl_abioverwrite_type)
return ((jl_abi_overwrite_t*)ci->def)->abi;
return jl_get_ci_mi(ci)->specTypes;
}

static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, ArrayRef<jl_cgval_t> argv, size_t nargs, jl_value_t *rt)
{
++EmittedInvokes;
Expand Down Expand Up @@ -5525,7 +5532,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, ArrayR
}
else if (invoke != jl_fptr_sparam_addr) {
bool specsig, needsparams;
std::tie(specsig, needsparams) = uses_specsig(mi, codeinst->rettype, ctx.params->prefer_specsig);
std::tie(specsig, needsparams) = uses_specsig(get_ci_abi(codeinst), mi, codeinst->rettype, ctx.params->prefer_specsig);
std::string name;
StringRef protoname;
bool need_to_emit = true;
Expand Down Expand Up @@ -7200,7 +7207,7 @@ Function *emit_tojlinvoke(jl_code_instance_t *codeinst, StringRef theFptrName, M
jl_init_function(f, params.TargetTriple);
if (trim_may_error(params.params->trim)) {
// TODO: Debuginfo!
push_frames(ctx, ctx.linfo, codeinst->def, 1);
push_frames(ctx, ctx.linfo, jl_get_ci_mi(codeinst), 1);
}
jl_name_jlfunc_args(params, f);
//f->setAlwaysInline();
Expand All @@ -7217,7 +7224,7 @@ Function *emit_tojlinvoke(jl_code_instance_t *codeinst, StringRef theFptrName, M
}
else {
theFunc = prepare_call(jlinvoke_func);
theFarg = literal_pointer_val(ctx, (jl_value_t*)codeinst->def);
theFarg = literal_pointer_val(ctx, (jl_value_t*)jl_get_ci_mi(codeinst));
}
theFarg = track_pjlvalue(ctx, theFarg);
auto args = f->arg_begin();
Expand Down Expand Up @@ -8327,13 +8334,13 @@ get_specsig_di(jl_codectx_t &ctx, jl_debugcache_t &debuginfo, jl_value_t *rt, jl
}

/* aka Core.Compiler.tuple_tfunc */
static jl_datatype_t *compute_va_type(jl_method_instance_t *lam, size_t nreq)
static jl_datatype_t *compute_va_type(jl_value_t *sig, size_t nreq)
{
size_t nvargs = jl_nparams(lam->specTypes)-nreq;
size_t nvargs = jl_nparams(sig)-nreq;
jl_svec_t *tupargs = jl_alloc_svec(nvargs);
JL_GC_PUSH1(&tupargs);
for (size_t i = nreq; i < jl_nparams(lam->specTypes); ++i) {
jl_value_t *argType = jl_nth_slot_type(lam->specTypes, i);
for (size_t i = nreq; i < jl_nparams(sig); ++i) {
jl_value_t *argType = jl_nth_slot_type(sig, i);
// n.b. specTypes is required to be a datatype by construction for specsig
if (is_uniquerep_Type(argType))
argType = jl_typeof(jl_tparam0(argType));
Expand Down Expand Up @@ -8373,6 +8380,7 @@ static jl_llvm_functions_t
orc::ThreadSafeModule &TSM,
jl_method_instance_t *lam,
jl_code_info_t *src,
jl_value_t *abi,
jl_value_t *jlrettype,
jl_codegen_params_t &params)
{
Expand Down Expand Up @@ -8453,7 +8461,7 @@ static jl_llvm_functions_t
int n_ssavalues = jl_is_long(src->ssavaluetypes) ? jl_unbox_long(src->ssavaluetypes) : jl_array_nrows(src->ssavaluetypes);
size_t vinfoslen = jl_array_dim0(src->slotflags);
ctx.slots.resize(vinfoslen, jl_varinfo_t(ctx.builder.getContext()));
assert(lam->specTypes); // the specTypes field should always be assigned
assert(abi); // the specTypes field should always be assigned


// create SAvalue locations for SSAValue objects
Expand All @@ -8462,7 +8470,7 @@ static jl_llvm_functions_t
ctx.ssavalue_usecount.assign(n_ssavalues, 0);

bool specsig, needsparams;
std::tie(specsig, needsparams) = uses_specsig(lam, jlrettype, params.params->prefer_specsig);
std::tie(specsig, needsparams) = uses_specsig(abi, lam, jlrettype, params.params->prefer_specsig);

// step 3. some variable analysis
size_t i;
Expand All @@ -8472,7 +8480,7 @@ static jl_llvm_functions_t
jl_sym_t *argname = slot_symbol(ctx, i);
if (argname == jl_unused_sym)
continue;
jl_value_t *ty = jl_nth_slot_type(lam->specTypes, i);
jl_value_t *ty = jl_nth_slot_type(abi, i);
// TODO: jl_nth_slot_type should call jl_rewrap_unionall
// specTypes is required to be a datatype by construction for specsig, but maybe not otherwise
// OpaqueClosure implicitly loads the env
Expand All @@ -8490,7 +8498,7 @@ static jl_llvm_functions_t
if (va && ctx.vaSlot != -1) {
jl_varinfo_t &varinfo = ctx.slots[ctx.vaSlot];
varinfo.isArgument = true;
vatyp = specsig ? compute_va_type(lam, nreq) : (jl_tuple_type);
vatyp = specsig ? compute_va_type(abi, nreq) : (jl_tuple_type);
varinfo.value = mark_julia_type(ctx, (Value*)NULL, false, vatyp);
}

Expand Down Expand Up @@ -8542,7 +8550,7 @@ static jl_llvm_functions_t
ArgNames[i] = name;
}
}
returninfo = get_specsig_function(ctx, M, NULL, declarations.specFunctionObject, lam->specTypes,
returninfo = get_specsig_function(ctx, M, NULL, declarations.specFunctionObject, abi,
jlrettype, ctx.is_opaque_closure, JL_FEAT_TEST(ctx,gcstack_arg),
ArgNames, nreq);
f = cast<Function>(returninfo.decl.getCallee());
Expand Down Expand Up @@ -8576,7 +8584,7 @@ static jl_llvm_functions_t
std::string wrapName;
raw_string_ostream(wrapName) << "jfptr_" << ctx.name << "_" << jl_atomic_fetch_add_relaxed(&globalUniqueGeneratedNames, 1);
declarations.functionObject = wrapName;
size_t nparams = jl_nparams(lam->specTypes);
size_t nparams = jl_nparams(abi);
gen_invoke_wrapper(lam, jlrettype, returninfo, nparams, retarg, ctx.is_opaque_closure, declarations.functionObject, M, ctx.emission_context);
// TODO: add attributes: maybe_mark_argument_dereferenceable(Arg, argType)
// TODO: add attributes: dereferenceable<sizeof(void*) * nreq>
Expand All @@ -8596,10 +8604,10 @@ static jl_llvm_functions_t
declarations.functionObject = needsparams ? "jl_fptr_sparam" : "jl_fptr_args";
}

if (ctx.emission_context.debug_level >= 2 && lam->def.method && jl_is_method(lam->def.method) && lam->specTypes != (jl_value_t*)jl_emptytuple_type) {
if (ctx.emission_context.debug_level >= 2 && lam->def.method && jl_is_method(lam->def.method) && abi != (jl_value_t*)jl_emptytuple_type) {
ios_t sigbuf;
ios_mem(&sigbuf, 0);
jl_static_show_func_sig((JL_STREAM*) &sigbuf, (jl_value_t*)lam->specTypes);
jl_static_show_func_sig((JL_STREAM*) &sigbuf, (jl_value_t*)abi);
f->addFnAttr("julia.fsig", StringRef(sigbuf.buf, sigbuf.size));
ios_close(&sigbuf);
}
Expand Down Expand Up @@ -8656,7 +8664,7 @@ static jl_llvm_functions_t
else if (!specsig)
subrty = debugcache.jl_di_func_sig;
else
subrty = get_specsig_di(ctx, debugcache, jlrettype, lam->specTypes, dbuilder);
subrty = get_specsig_di(ctx, debugcache, jlrettype, abi, dbuilder);
SP = dbuilder.createFunction(nullptr
,dbgFuncName // Name
,f->getName() // LinkageName
Expand Down Expand Up @@ -8987,7 +8995,7 @@ static jl_llvm_functions_t
nullptr, nullptr, /*isboxed*/true, AtomicOrdering::NotAtomic, false, sizeof(void*));
}
else {
jl_value_t *argType = jl_nth_slot_type(lam->specTypes, i);
jl_value_t *argType = jl_nth_slot_type(abi, i);
// TODO: jl_nth_slot_type should call jl_rewrap_unionall?
// specTypes is required to be a datatype by construction for specsig, but maybe not otherwise
bool isboxed = deserves_argbox(argType);
Expand Down Expand Up @@ -9061,10 +9069,10 @@ static jl_llvm_functions_t
assert(vi.boxroot == NULL);
}
else if (specsig) {
ctx.nvargs = jl_nparams(lam->specTypes) - nreq;
ctx.nvargs = jl_nparams(abi) - nreq;
SmallVector<jl_cgval_t, 0> vargs(ctx.nvargs);
for (size_t i = nreq; i < jl_nparams(lam->specTypes); ++i) {
jl_value_t *argType = jl_nth_slot_type(lam->specTypes, i);
for (size_t i = nreq; i < jl_nparams(abi); ++i) {
jl_value_t *argType = jl_nth_slot_type(abi, i);
// n.b. specTypes is required to be a datatype by construction for specsig
bool isboxed = deserves_argbox(argType);
Type *llvmArgType = isboxed ? ctx.types().T_prjlvalue : julia_type_to_llvm(ctx, argType);
Expand Down Expand Up @@ -10011,6 +10019,7 @@ jl_llvm_functions_t jl_emit_code(
orc::ThreadSafeModule &m,
jl_method_instance_t *li,
jl_code_info_t *src,
jl_value_t *abi,
jl_codegen_params_t &params)
{
JL_TIMING(CODEGEN, CODEGEN_LLVM);
Expand All @@ -10019,8 +10028,10 @@ jl_llvm_functions_t jl_emit_code(
assert((params.params == &jl_default_cgparams /* fast path */ || !params.cache ||
compare_cgparams(params.params, &jl_default_cgparams)) &&
"functions compiled with custom codegen params must not be cached");
if (!abi)
abi = li->specTypes;
JL_TRY {
decls = emit_function(m, li, src, src->rettype, params);
decls = emit_function(m, li, src, abi, src->rettype, params);
auto stream = *jl_ExecutionEngine->get_dump_emitted_mi_name_stream();
if (stream) {
jl_printf(stream, "%s\t", decls.specFunctionObject.c_str());
Expand Down Expand Up @@ -10089,11 +10100,11 @@ jl_llvm_functions_t jl_emit_codeinst(
jl_codegen_params_t &params)
{
JL_TIMING(CODEGEN, CODEGEN_Codeinst);
jl_timing_show_method_instance(codeinst->def, JL_TIMING_DEFAULT_BLOCK);
jl_timing_show_method_instance(jl_get_ci_mi(codeinst), JL_TIMING_DEFAULT_BLOCK);
JL_GC_PUSH1(&src);
if (!src) {
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
jl_method_instance_t *mi = codeinst->def;
jl_method_instance_t *mi = jl_get_ci_mi(codeinst);
jl_method_t *def = mi->def.method;
// Check if this is the generic method for opaque closure wrappers -
// if so, this must compile specptr such that it holds the specptr -> invoke wrapper
Expand All @@ -10111,15 +10122,15 @@ jl_llvm_functions_t jl_emit_codeinst(
}
}
assert(jl_egal((jl_value_t*)jl_atomic_load_relaxed(&codeinst->debuginfo), (jl_value_t*)src->debuginfo) && "trying to generate code for a codeinst for an incompatible src");
jl_llvm_functions_t decls = jl_emit_code(m, codeinst->def, src, params);
jl_llvm_functions_t decls = jl_emit_code(m, jl_get_ci_mi(codeinst), src, get_ci_abi(codeinst), params);

const std::string &specf = decls.specFunctionObject;
const std::string &f = decls.functionObject;
if (params.cache && !f.empty()) {
// Prepare debug info to receive this function
// record that this function name came from this linfo,
// so we can build a reverse mapping for debug-info.
bool toplevel = !jl_is_method(codeinst->def->def.method);
bool toplevel = !jl_is_method(jl_get_ci_mi(codeinst)->def.method);
if (!toplevel) {
//Safe b/c params holds context lock
const DataLayout &DL = m.getModuleUnlocked()->getDataLayout();
Expand All @@ -10135,7 +10146,7 @@ jl_llvm_functions_t jl_emit_codeinst(
jl_value_t *inferred = jl_atomic_load_relaxed(&codeinst->inferred);
// don't change inferred state
if (inferred) {
jl_method_t *def = codeinst->def->def.method;
jl_method_t *def = jl_get_ci_mi(codeinst)->def.method;
if (// keep code when keeping everything
!(JL_DELETE_NON_INLINEABLE) ||
// aggressively keep code when debugging level >= 2
Expand Down
4 changes: 2 additions & 2 deletions src/debuginfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void jl_add_code_in_flight(StringRef name, jl_code_instance_t *codeinst, const D
// Non-opaque-closure MethodInstances are considered globally rooted
// through their methods, but for OC, we need to create a global root
// here.
jl_method_instance_t *mi = codeinst->def;
jl_method_instance_t *mi = jl_get_ci_mi(codeinst);
if (jl_is_method(mi->def.value) && mi->def.method->is_for_opaque_closure)
jl_as_global_root((jl_value_t*)mi, 1);
getJITDebugRegistry().add_code_in_flight(name, codeinst, DL);
Expand Down Expand Up @@ -374,7 +374,7 @@ void JITDebugInfoRegistry::registerJITObject(const object::ObjectFile &Object,
jl_method_instance_t *mi = NULL;
if (codeinst) {
JL_GC_PROMISE_ROOTED(codeinst);
mi = codeinst->def;
mi = jl_get_ci_mi(codeinst);
}
jl_profile_atomic([&]() JL_NOTSAFEPOINT {
if (mi)
Expand Down
2 changes: 1 addition & 1 deletion src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void jl_engine_fulfill(jl_code_instance_t *ci, jl_code_info_t *src)
{
jl_task_t *ct = jl_current_task;
std::unique_lock lock(engine_lock);
auto record = Reservations.find(InferKey{ci->def, ci->owner});
auto record = Reservations.find(InferKey{jl_get_ci_mi(ci), ci->owner});
if (record == Reservations.end() || record->second.ci != ci)
return;
assert(jl_atomic_load_relaxed(&ct->tid) == record->second.tid);
Expand Down
Loading

0 comments on commit 31ee1b4

Please sign in to comment.