Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 101 additions & 23 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,17 @@ end
# `InferAddressSpaces` run folds the entry cast away. The source need not be a constant
# global; any pointer with a known address space qualifies, so any back-end can run it.
#
# With typed pointers (older LLVM) a `Ptr` argument is lowered to an integer rather than a
# pointer, so the same boundary crossing arrives as `ptrtoint(addrspacecast(<specific> ->
# generic))` at the call site and `inttoptr` in the body. Such an integer parameter is
# retargeted the same way: it becomes a pointer in the agreed space, the call sites pass the
# bare source pointer, and entry rebuilds the original integer as `ptrtoint(addrspacecast(
# param -> generic))`. The cloned body's `inttoptr` then composes with that and the same
# `InferAddressSpaces` run folds the whole chain to a specific-space load. (Without this the
# leftover `ptrtoint(addrspacecast(...))` constant feeds a generic-space load that, e.g., the
# LLVM-16 Metal bitcode downgrade miscompiles into an invalid metallib — JuliaGPU/Metal.jl
# device exceptions on Julia 1.11.)
#
# Narrowing one function makes its body forward an `addrspacecast`-from-specific to the
# functions it calls, exposing them in turn. We therefore iterate to a fixed point so a
# constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to
Expand All @@ -486,6 +497,31 @@ function addrspacecast_to_generic_source(@nospecialize(v))
return src
end

# Typed-pointer sibling of `addrspacecast_to_generic_source`. When a `Ptr` argument is lowered
# to an integer, a specific-space pointer handed across a call shows up as
# `ptrtoint(addrspacecast(<ptr in a specific space> -> generic))` instead of the bare cast.
# Given such a value `v`, peel off the `ptrtoint` and return whatever
# `addrspacecast_to_generic_source` finds underneath (the specific-space source pointer, or
# `nothing`); any other kind of value yields `nothing`.
function ptrtoint_of_generic_source(@nospecialize(v))
    # only instructions and constant expressions carry an opcode we can inspect
    if !(v isa LLVM.Instruction) && !(v isa LLVM.ConstantExpr)
        return nothing
    end
    if opcode(v) != LLVM.API.LLVMPtrToInt
        return nothing
    end
    # the sole operand of `ptrtoint` is the pointer being converted
    converted = operands(v)[1]
    return addrspacecast_to_generic_source(converted)
end

# If integer parameter `arg` is consumed only by `inttoptr` instructions that all produce the
# same *generic* (address space 0) pointer type, return that type; otherwise `nothing`. A
# parameter with no uses also yields `nothing`. Such a parameter is the integer image of a
# pointer that crossed the call boundary, and can be retargeted to a pointer the same way a
# generic pointer parameter is (see `propagate_argument_address_spaces!`, the integer case).
#
# The generic requirement matters: narrowing rebuilds the original integer on entry as
# `ptrtoint(addrspacecast(param -> <returned type>))`, and the callers passed
# `ptrtoint(addrspacecast(<specific> -> generic))`. If the body reconstructed a pointer in some
# other address space the rebuilt integer would not match what the callers passed, and if it
# reconstructed one in the agreed specific space the entry cast would be a same-space
# `addrspacecast`, which is invalid IR. Rejecting those shapes leaves the parameter untouched.
function integer_param_pointer_type(arg::LLVM.Argument)
    ptrty = nothing
    for use in uses(arg)
        u = user(use)
        # any non-`inttoptr` consumer means the integer is more than a pointer image
        (u isa LLVM.Instruction && opcode(u) == LLVM.API.LLVMIntToPtr) || return nothing
        t = value_type(u)
        # only a generic-space reconstruction can be replaced by the entry cast chain
        (t isa LLVM.PointerType && addrspace(t) == 0) || return nothing
        if ptrty === nothing
            ptrty = t
        elseif ptrty != t
            # the uses disagree on the reconstructed pointer type
            return nothing
        end
    end
    return ptrty
end

function propagate_argument_address_spaces!(mod::LLVM.Module)
changed = false
while propagate_argument_address_spaces_once!(mod)
Expand Down Expand Up @@ -521,24 +557,39 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module)
end
(only_calls && !isempty(callsites)) || continue

# for each generic pointer parameter, find the address space its callers agree on
# for each narrowable parameter, find the address space its callers agree on. a
# generic pointer parameter is passed as `addrspacecast(<specific> -> generic)`; with
# typed pointers a `Ptr` argument is instead an integer passed as `ptrtoint` of that
# cast (`int_ptr_types[i]` records the pointer it is reconstructed to, marking Case B).
new_addrspaces = fill(-1, length(param_types))
int_ptr_types = Vector{Any}(nothing, length(param_types))
for (i, pty) in enumerate(param_types)
(pty isa LLVM.PointerType && addrspace(pty) == 0) || continue
extract = if pty isa LLVM.PointerType && addrspace(pty) == 0
addrspacecast_to_generic_source
elseif pty isa LLVM.IntegerType &&
integer_param_pointer_type(parameters(f)[i]) !== nothing
ptrtoint_of_generic_source
else
continue
end
as = -1
for cs in callsites
src = addrspacecast_to_generic_source(arguments(cs)[i])
src = extract(arguments(cs)[i])
if src === nothing
as = -1; break
end
src_as = addrspace(value_type(src))
as == -1 ? (as = src_as) : (as == src_as || (as = -1; break))
end
as > 0 && (new_addrspaces[i] = as)
if as > 0
new_addrspaces[i] = as
pty isa LLVM.IntegerType &&
(int_ptr_types[i] = integer_param_pointer_type(parameters(f)[i]))
end
end
any(>=(0), new_addrspaces) || continue

narrow_pointer_parameters!(mod, f, new_addrspaces, callsites)
narrow_pointer_parameters!(mod, f, new_addrspaces, int_ptr_types, callsites)
changed = true
end
return changed
Expand Down Expand Up @@ -566,11 +617,22 @@ end
# convention, operand bundles and attributes; replaces and erases the old call.
function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst,
new_f::LLVM.Function, new_ft::LLVM.FunctionType,
new_addrspaces::Vector{Int})
new_addrspaces::Vector{Int}, int_ptr_types::Vector{Any})
position!(builder, cs)
new_args = LLVM.Value[new_addrspaces[i] >= 0 ?
addrspacecast_to_generic_source(arg) : arg
for (i, arg) in enumerate(arguments(cs))]
new_param_types = parameters(new_ft)
new_args = LLVM.Value[]
for (i, arg) in enumerate(arguments(cs))
if new_addrspaces[i] < 0
push!(new_args, arg)
elseif int_ptr_types[i] !== nothing
# Case B: strip the `ptrtoint` then the cast, and bitcast the bare specific-space
# pointer to the retargeted parameter's pointer type
src = addrspacecast_to_generic_source(operands(arg)[1])
push!(new_args, bitcast!(builder, src, new_param_types[i]))
else
push!(new_args, addrspacecast_to_generic_source(arg))
end
end
new_call = call!(builder, new_ft, new_f, new_args, operand_bundles(cs))
callconv!(new_call, callconv(cs))
copy_callsite_attributes!(new_call, cs)
Expand All @@ -584,14 +646,22 @@ end
# back to generic on entry so the cloned body is unchanged. Rewrite `callsites` to pass the
# un-casted source value for each retargeted argument; recursive self-calls are handled too.
function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
new_addrspaces::Vector{Int}, callsites)
new_addrspaces::Vector{Int}, int_ptr_types::Vector{Any},
callsites)
ft = function_type(f)
retarget(pty::LLVM.PointerType, as::Integer) =
supports_typed_pointers(context()) ? LLVM.PointerType(eltype(pty), as) :
# retarget a pointer to address space `as`, taking its pointee from `srcptr` (only needed
# for typed pointers; `eltype` is invalid on opaque ones, so keep it lazy)
retarget(as::Integer, srcptr::LLVM.PointerType) =
supports_typed_pointers(context()) ? LLVM.PointerType(eltype(srcptr), as) :
LLVM.PointerType(as)
new_types = LLVM.LLVMType[new_addrspaces[i] >= 0 ?
retarget(param_typ::LLVM.PointerType, new_addrspaces[i]) :
param_typ
# the retargeted parameter type: for a pointer parameter (Case A) keep its pointee; for an
# integer parameter (Case B) use the pointee of the pointer its body reconstructs.
new_param_type(i, param_typ) =
new_addrspaces[i] < 0 ? param_typ :
int_ptr_types[i] !== nothing ?
retarget(new_addrspaces[i], int_ptr_types[i]::LLVM.PointerType) :
retarget(new_addrspaces[i], param_typ::LLVM.PointerType)
new_types = LLVM.LLVMType[new_param_type(i, param_typ)
for (i, param_typ) in enumerate(parameters(ft))]
new_ft = LLVM.FunctionType(return_type(ft), new_types)

Expand All @@ -609,10 +679,17 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
position!(builder, entry)
new_args = LLVM.Value[]
for (i, param_typ) in enumerate(parameters(ft))
if new_addrspaces[i] >= 0
push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ))
else
if new_addrspaces[i] < 0
push!(new_args, parameters(new_f)[i])
elseif int_ptr_types[i] !== nothing
# Case B: rebuild the original integer as `ptrtoint(addrspacecast(param ->
# generic))`; the cloned body's `inttoptr` composes with it, and the following
# InferAddressSpaces run folds the whole chain to a specific-space load.
gen = addrspacecast!(builder, parameters(new_f)[i],
int_ptr_types[i]::LLVM.PointerType)
push!(new_args, ptrtoint!(builder, gen, param_typ))
else
push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ))
end
end

Expand All @@ -626,10 +703,11 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
end

# `clone_into!` copies a parameter's attributes only when it maps to a new argument; the
# retargeted ones map to the entry addrspacecast instead, so theirs are dropped. Reattach
# them; they stay valid on the narrowed pointer, and non-retargeted params keep theirs.
# retargeted ones map to the entry cast instead, so theirs are dropped. Reattach them for
# Case A (still valid on the narrowed pointer). Skip Case B: those were integer attributes
# (e.g. `zeroext`) that are invalid on the now-pointer parameter.
for i in 1:length(new_addrspaces)
new_addrspaces[i] >= 0 || continue
(new_addrspaces[i] >= 0 && int_ptr_types[i] === nothing) || continue
for attr in collect(parameter_attributes(f, i))
push!(parameter_attributes(new_f, i), attr)
end
Expand All @@ -646,10 +724,10 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
# rewrite call sites to pass the un-casted source value for each retargeted argument
@dispose builder=IRBuilder() begin
for cs in callsites
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces, int_ptr_types)
end
for cs in self_calls
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces)
rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces, int_ptr_types)
end
end

Expand Down
69 changes: 69 additions & 0 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,75 @@ end
@test (verify(mod); true)
end

# typed-pointer form: a `Ptr` argument crosses the boundary as an integer, so the callee
# takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast(<global> ->
# generic))`. the integer parameter is narrowed to a constant-space pointer just the same,
# and the call sites pass the bare global. (regression: the pass used to skip integer
# parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16 Metal
# bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.)
Context() do ctx
    byte_ty = LLVM.Int8Type()
    word_ty = LLVM.Int64Type()
    mod = LLVM.Module("test")

    # callee: takes the pointer as an integer, rebuilds a generic pointer and loads through it
    callee_ft = LLVM.FunctionType(byte_ty, LLVM.LLVMType[word_ty])
    callee = LLVM.Function(mod, "callee", callee_ft)
    linkage!(callee, LLVM.API.LLVMInternalLinkage)
    @dispose builder=IRBuilder() begin
        position!(builder, BasicBlock(callee, "entry"))
        rebuilt = inttoptr!(builder, parameters(callee)[1], asptr(0))
        ret!(builder, load!(builder, byte_ty, rebuilt))
    end

    # two callers, each handing over its own address-space-2 constant global as
    # `ptrtoint(addrspacecast(<global> -> generic))`
    for n in 1:2
        gv = GlobalVariable(mod, byte_ty, "g$n", 2)
        initializer!(gv, ConstantInt(byte_ty, n))
        constant!(gv, true)
        caller_ft = LLVM.FunctionType(byte_ty, LLVM.LLVMType[])
        caller = LLVM.Function(mod, "caller$n", caller_ft)
        linkage!(caller, LLVM.API.LLVMInternalLinkage)
        @dispose builder=IRBuilder() begin
            position!(builder, BasicBlock(caller, "entry"))
            as_generic = const_addrspacecast(gv, asptr(0))
            arg = const_ptrtoint(as_generic, word_ty)
            ret!(builder, call!(builder, callee_ft, callee, [arg]))
        end
    end

    # the pass must report a change and turn the integer parameter into an AS-2 pointer
    @test GPUCompiler.propagate_argument_address_spaces!(mod)
    narrowed = parameters(function_type(functions(mod)["callee"]))[1]
    @test narrowed isa LLVM.PointerType && addrspace(narrowed) == 2

    # every rewritten call site passes a bare address-space-2 pointer
    function passes_as2_pointer(cs)
        argty = value_type(arguments(cs)[1])
        return argty isa LLVM.PointerType && addrspace(argty) == 2
    end
    @test all(passes_as2_pointer, calls_to(mod, "callee"))
    @test (verify(mod); true)
end

# an integer parameter used as more than a pointer image (here also in arithmetic) is
# left alone: narrowing it would lose the integer's other uses
Context() do ctx
    byte_ty = LLVM.Int8Type()
    word_ty = LLVM.Int64Type()
    mod = LLVM.Module("test")

    # callee: uses its integer parameter both as a pointer image and in plain arithmetic
    callee_ft = LLVM.FunctionType(byte_ty, LLVM.LLVMType[word_ty])
    callee = LLVM.Function(mod, "callee", callee_ft)
    linkage!(callee, LLVM.API.LLVMInternalLinkage)
    @dispose builder=IRBuilder() begin
        position!(builder, BasicBlock(callee, "entry"))
        raw = parameters(callee)[1]
        rebuilt = inttoptr!(builder, raw, asptr(0))
        loaded = load!(builder, byte_ty, rebuilt)
        # the extra, non-`inttoptr` use that must block narrowing
        doubled = add!(builder, raw, raw)
        low_bits = trunc!(builder, doubled, byte_ty)
        ret!(builder, add!(builder, loaded, low_bits))
    end

    # single caller passing an address-space-2 constant global as an integer
    gv = GlobalVariable(mod, byte_ty, "g", 2)
    initializer!(gv, ConstantInt(byte_ty, 1))
    constant!(gv, true)
    caller_ft = LLVM.FunctionType(byte_ty, LLVM.LLVMType[])
    caller = LLVM.Function(mod, "caller", caller_ft)
    linkage!(caller, LLVM.API.LLVMInternalLinkage)
    @dispose builder=IRBuilder() begin
        position!(builder, BasicBlock(caller, "entry"))
        as_generic = const_addrspacecast(gv, asptr(0))
        arg = const_ptrtoint(as_generic, word_ty)
        ret!(builder, call!(builder, callee_ft, callee, [arg]))
    end

    # nothing may change: the pass reports no work and the parameter stays an i64
    @test !GPUCompiler.propagate_argument_address_spaces!(mod)
    @test parameters(function_type(functions(mod)["callee"]))[1] == word_ty
end

# a two-level delegation chain (caller -> mid -> leaf) needs the fixpoint: one sweep
# narrows `mid` (its caller passes a constant global), which only then exposes `leaf`,
# since `mid` now forwards an addrspacecast-from-constant. iterate until both narrow.
Expand Down
Loading