Skip to content

Commit

Permalink
Add ABIOverride type for def field (#56555)
Browse files Browse the repository at this point in the history
Together with #54899, this PR is intending to replicate the
functionality of #54373, which allowed particular specializations to
have a different ABI signature than what would be suggested by the
MethodInstance's `specTypes` field. This PR handles that by adding a
special `ABIOverwrite` type, which, when placed in the `owner` field of
a `CodeInstance` instructs the system to use the given signature
instead.
  • Loading branch information
Keno authored Dec 20, 2024
1 parent 6e04a0b commit 9bc27ad
Show file tree
Hide file tree
Showing 28 changed files with 177 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Compiler/src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ccall(:jl_set_module_uuid, Cvoid, (Any, NTuple{2, UInt64}), Compiler,

using Core.Intrinsics, Core.IR

using Core: Builtin, CodeInstance, IntrinsicFunction, MethodInstance, MethodMatch,
using Core: ABIOverride, Builtin, CodeInstance, IntrinsicFunction, MethodInstance, MethodMatch,
MethodTable, PartialOpaque, SimpleVector, TypeofVararg,
_apply_iterate, apply_type, compilerbarrier, donotdelete, memoryref_isassigned,
memoryrefget, memoryrefnew, memoryrefoffset, memoryrefset!, print, println, show, svec,
Expand Down
18 changes: 15 additions & 3 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2212,6 +2212,18 @@ function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{An
return CallMeta(ret, Any, Effects(EFFECTS_TOTAL; nothrow), call.info)
end

function ci_abi(ci::CodeInstance)
def = ci.def
isa(def, ABIOverride) && return def.abi
(def::MethodInstance).specTypes
end

function get_ci_mi(ci::CodeInstance)
def = ci.def
isa(def, ABIOverride) && return def.def
return def::MethodInstance
end

function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState)
argtypes = arginfo.argtypes
ft′ = argtype_by_index(argtypes, 2)
Expand All @@ -2223,12 +2235,12 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
if isa(method_or_ci, CodeInstance)
our_world = sv.world.this
argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
specsig = method_or_ci.def.specTypes
defdef = method_or_ci.def.def
specsig = ci_abi(method_or_ci)
defdef = get_ci_mi(method_or_ci).def
exct = method_or_ci.exctype
if !hasintersect(argtype, specsig)
return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
elseif !(argtype <: specsig) || (isa(defdef, Method) && !(argtype <: defdef.sig))
elseif !(argtype <: specsig) || ((!isa(method_or_ci.def, ABIOverride) && isa(defdef, Method)) && !(argtype <: defdef.sig))
exct = Union{exct, TypeError}
end
callee_valid_range = WorldRange(method_or_ci.min_world, method_or_ci.max_world)
Expand Down
8 changes: 7 additions & 1 deletion Compiler/src/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,17 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), code::Union{IRCode,Co
if !(mi isa Core.MethodInstance)
mi = (mi::Core.CodeInstance).def
end
if isa(mi, Core.ABIOverride)
abi = mi.abi
mi = mi.def
else
abi = mi.specTypes
end
show_unquoted(io, stmt.args[2], indent)
print(io, "(")
# XXX: this is wrong if `sig` is not a concretetype method
# more correct would be to use `fieldtype(sig, i)`, but that would obscure / discard Varargs information in show
sig = mi.specTypes == Tuple ? Core.svec() : Base.unwrap_unionall(mi.specTypes).parameters::Core.SimpleVector
sig = abi == Tuple ? Core.svec() : Base.unwrap_unionall(abi).parameters::Core.SimpleVector
print_arg(i) = sprint(; context=io) do io
show_unquoted(io, stmt.args[i], indent)
if (i - 1) <= length(sig)
Expand Down
8 changes: 4 additions & 4 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function _add_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, mi_edge::Boo
mi = specialize_method(m) # don't allow `Method`-edge for this optimized format
edge = mi
else
mi = edge.def
mi = edge.def::MethodInstance
end
if mi.specTypes === m.spec_types
add_one_edge!(edges, edge)
Expand Down Expand Up @@ -103,7 +103,7 @@ function add_one_edge!(edges::Vector{Any}, edge::MethodInstance)
while i <= length(edges)
edgeᵢ = edges[i]
edgeᵢ isa Int && (i += 2 + edgeᵢ; continue)
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def)
edgeᵢ isa CodeInstance && (edgeᵢ = get_ci_mi(edgeᵢ))
edgeᵢ isa MethodInstance || (i += 1; continue)
if edgeᵢ === edge && !(i > 1 && edges[i-1] isa Type)
return # found existing covered edge
Expand All @@ -118,7 +118,7 @@ function add_one_edge!(edges::Vector{Any}, edge::CodeInstance)
while i <= length(edges)
edgeᵢ_orig = edgeᵢ = edges[i]
edgeᵢ isa Int && (i += 2 + edgeᵢ; continue)
edgeᵢ isa CodeInstance && (edgeᵢ = edgeᵢ.def)
edgeᵢ isa CodeInstance && (edgeᵢ = get_ci_mi(edgeᵢ))
edgeᵢ isa MethodInstance || (i += 1; continue)
if edgeᵢ === edge.def && !(i > 1 && edges[i-1] isa Type)
if edgeᵢ_orig isa MethodInstance
Expand Down Expand Up @@ -385,7 +385,7 @@ function add_inlining_edge!(edges::Vector{Any}, edge::CodeInstance)
i += 1
end
# add_invoke_edge alone
push!(edges, (edge.def.def::Method).sig)
push!(edges, (get_ci_mi(edge).def::Method).sig)
push!(edges, edge)
nothing
end
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
i += 2
continue
elseif isa(callee, CodeInstance)
callee = callee.def
callee = get_ci_mi(callee)
end
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
i += 2
Expand Down
7 changes: 6 additions & 1 deletion Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4110,7 +4110,12 @@ end == [Union{Some{Float64}, Some{Int}, Some{UInt8}}]
mi = codeinst
else
codeinst::Core.CodeInstance
mi = codeinst.def
def = codeinst.def
if isa(def, Core.ABIOverride)
mi = def.def
else
mi = def::Core.MethodInstance
end
end
return mi
end == Core.MethodInstance
Expand Down
2 changes: 1 addition & 1 deletion base/arrayshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Accept keyword args `c` for alternate single character marker.
"""
function replace_with_centered_mark(s::AbstractString;c::AbstractChar = '')
N = textwidth(ANSIIterator(s))
return join(setindex!([" " for i=1:N],string(c),ceil(Int,N/2)))
return N == 0 ? string(c) : join(setindex!([" " for i=1:N],string(c),ceil(Int,N/2)))
end

const undef_ref_alignment = (3,3)
Expand Down
8 changes: 7 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,12 @@ struct InitError <: WrappedException
error
end

struct ABIOverride
abi::Type
def::MethodInstance
ABIOverride(@nospecialize(abi::Type), def::MethodInstance) = new(abi, def)
end

struct PrecompilableError <: Exception end

String(s::String) = s # no constructor yet
Expand Down Expand Up @@ -552,7 +558,7 @@ end


function CodeInstance(
mi::MethodInstance, owner, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
mi::Union{MethodInstance, ABIOverride}, owner, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
effects::UInt32, @nospecialize(analysis_results),
relocatability::UInt8, di::Union{DebugInfo,Nothing}, edges::SimpleVector)
Expand Down
8 changes: 7 additions & 1 deletion base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,13 @@ end
show(io::IO, mi::Core.MethodInstance) = show_mi(io, mi)
function show(io::IO, codeinst::Core.CodeInstance)
print(io, "CodeInstance for ")
show_mi(io, codeinst.def)
def = codeinst.def
if isa(def, Core.ABIOverride)
show_mi(io, def.def)
print(io, " (ABI Overridden)")
else
show_mi(io, def::MethodInstance)
end
end

function show_mi(io::IO, mi::Core.MethodInstance, from_stackframe::Bool=false)
Expand Down
16 changes: 8 additions & 8 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ jl_get_llvm_mis_impl(void *native_code, size_t *num_elements, jl_method_instance
assert(*num_elements == map.size());
size_t i = 0;
for (auto &ci : map) {
data[i++] = ci.first->def;
data[i++] = jl_get_ci_mi(ci.first);
}
}

Expand Down Expand Up @@ -455,14 +455,14 @@ static void compile_workqueue(jl_codegen_params_t &params, egal_set &method_root
if ((policy != CompilationPolicy::Default || params.params->trim) &&
jl_atomic_load_relaxed(&codeinst->inferred) == jl_nothing) {
// XXX: SOURCE_MODE_FORCE_SOURCE is wrong here (neither sufficient nor necessary)
codeinst = jl_type_infer(codeinst->def, jl_atomic_load_relaxed(&codeinst->max_world), SOURCE_MODE_FORCE_SOURCE);
codeinst = jl_type_infer(jl_get_ci_mi(codeinst), jl_atomic_load_relaxed(&codeinst->max_world), SOURCE_MODE_FORCE_SOURCE);
}
if (codeinst) {
orc::ThreadSafeModule result_m =
jl_create_ts_module(name_from_method_instance(codeinst->def),
jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
params.tsctx, params.DL, params.TargetTriple);
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
record_method_roots(method_roots, codeinst->def);
record_method_roots(method_roots, jl_get_ci_mi(codeinst));
if (result_m)
it = compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
Expand Down Expand Up @@ -501,7 +501,7 @@ static void compile_workqueue(jl_codegen_params_t &params, egal_set &method_root
proto.decl->setLinkage(GlobalVariable::InternalLinkage);
//protodecl->setAlwaysInline();
jl_init_function(proto.decl, params.TargetTriple);
jl_method_instance_t *mi = codeinst->def;
jl_method_instance_t *mi = jl_get_ci_mi(codeinst);
size_t nrealargs = jl_nparams(mi->specTypes); // number of actual arguments being passed
bool is_opaque_closure = jl_is_method(mi->def.value) && mi->def.method->is_for_opaque_closure;
// TODO: maybe this can be cached in codeinst->specfptr?
Expand Down Expand Up @@ -641,12 +641,12 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
data->jl_fvar_map[codeinst] = std::make_tuple((uint32_t)-3, (uint32_t)-3);
}
else {
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
JL_GC_PROMISE_ROOTED(codeinst->def); // analyzer seems confused
record_method_roots(method_roots, codeinst->def);
record_method_roots(method_roots, jl_get_ci_mi(codeinst));
if (result_m)
compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
else if (jl_options.trim != JL_TRIM_NO) {
Expand Down Expand Up @@ -2267,7 +2267,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
// To get correct names in the IR this needs to be at least 2
output.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
JL_GC_PUSH1(&output.temporary_roots);
auto decls = jl_emit_code(m, mi, src, output);
auto decls = jl_emit_code(m, mi, src, NULL, output);
output.temporary_roots = nullptr;
JL_GC_POP(); // GC the global_targets array contents now since reflection doesn't need it

Expand Down
16 changes: 12 additions & 4 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1589,11 +1589,19 @@ JL_CALLABLE(jl_f_invoke)
return jl_gf_invoke_by_method(m, args[0], &args[2], nargs - 1);
} else if (jl_is_code_instance(argtypes)) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)args[1];
jl_method_instance_t *mi = jl_get_ci_mi(codeinst);
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
// N.B.: specTypes need not be a subtype of the method signature. We need to check both.
if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)codeinst->def->specTypes) ||
(jl_is_method(codeinst->def->def.value) && !jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)codeinst->def->def.method->sig))) {
jl_type_error("invoke: argument type error", codeinst->def->specTypes, arg_tuple(args[0], &args[2], nargs - 1));
if (jl_is_abioverride(codeinst->def)) {
jl_datatype_t *abi = (jl_datatype_t*)((jl_abi_override_t*)(codeinst->def))->abi;
if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, abi)) {
jl_type_error("invoke: argument type error (ABI overwrite)", (jl_value_t*)abi, arg_tuple(args[0], &args[2], nargs - 1));
}
} else {
if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)mi->specTypes) ||
(jl_is_method(mi->def.value) && !jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)mi->def.method->sig))) {
jl_type_error("invoke: argument type error", mi->specTypes, arg_tuple(args[0], &args[2], nargs - 1));
}
}
if (jl_atomic_load_relaxed(&codeinst->min_world) > jl_current_task->world_age ||
jl_current_task->world_age > jl_atomic_load_relaxed(&codeinst->max_world)) {
Expand All @@ -1609,7 +1617,7 @@ JL_CALLABLE(jl_f_invoke)
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
return jl_invoke(args[0], &args[2], nargs - 1, codeinst->def);
return jl_invoke(args[0], &args[2], nargs - 1, mi);
}
}
if (!jl_is_tuple_type(jl_unwrap_unionall(argtypes)))
Expand Down
1 change: 1 addition & 0 deletions src/clangsa/GCChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ bool GCChecker::isGCTrackedType(QualType QT) {
Name.ends_with_insensitive("jl_vararg_t") ||
Name.ends_with_insensitive("jl_opaque_closure_t") ||
Name.ends_with_insensitive("jl_globalref_t") ||
Name.ends_with_insensitive("jl_abi_override_t") ||
// Probably not technically true for these, but let's allow it as a root
Name.ends_with_insensitive("jl_ircode_state") ||
Name.ends_with_insensitive("typemap_intersection_env") ||
Expand Down
Loading

0 comments on commit 9bc27ad

Please sign in to comment.