Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "1.15.1"
version = "1.15.2"
authors = ["Tim Besard <tim.besard@gmail.com>"]

[workspace]
Expand Down
110 changes: 72 additions & 38 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,23 +466,32 @@ end
# `InferAddressSpaces` run folds the entry cast away. The source need not be a constant
# global; any pointer with a known address space qualifies, so any back-end can run it.
#
# With typed pointers (older LLVM) a `Ptr` argument is lowered to an integer rather than a
# pointer, so the same boundary crossing arrives as `ptrtoint(addrspacecast(<specific> ->
# generic))` at the call site and `inttoptr` in the body. Such an integer parameter is
# retargeted the same way: it becomes a pointer in the agreed space, the call sites pass the
# bare source pointer, and entry rebuilds the original integer as `ptrtoint(addrspacecast(
# param -> generic))`. The cloned body's `inttoptr` then composes with that and the same
# `InferAddressSpaces` run folds the whole chain to a specific-space load. (Without this the
# leftover `ptrtoint(addrspacecast(...))` constant feeds a generic-space load that, e.g., the
# TYPED-POINTER SHIM (Julia <= 1.11) -- gated on `supports_typed_pointers(context())` and
# removable, along with everything else tagged "typed-pointer shim", once 1.12 is the minimum.
# With typed pointers a `Ptr` argument is lowered to an integer rather than a pointer, so the
# same boundary crossing arrives as `ptrtoint(addrspacecast(<specific> -> generic))` at the call
# site, and the body either `inttoptr`s the integer (a leaf reporter that dereferences it) or
# forwards it on to a further call (a reporter that delegates, e.g. `report_exception_name` ->
# `report_exception`). Either way the integer is the image of a specific-space pointer, and is
# retargeted the same way as a generic pointer parameter: it becomes a pointer in the agreed
# space, the call sites pass the bare source pointer, and entry rebuilds the original integer as
# `ptrtoint(addrspacecast(param -> generic))`. This is sound for the same reason as the pointer
# case -- it only relocates a value-preserving cast/round-trip across the boundary, so every use
# (deref or forward) sees the bit-identical integer. A leaf's cloned `inttoptr` then composes
# with the rebuilt integer and `InferAddressSpaces` folds the chain to a specific-space load; a
# delegating function instead re-exposes the `ptrtoint(addrspacecast(...))` shape at its
# forwarded call, which the next sweep narrows in turn (see the fixed point below). (Without
# this the leftover `ptrtoint(addrspacecast(...))` feeds a generic-space load that, e.g., the
# LLVM-16 Metal bitcode downgrade miscompiles into an invalid metallib — JuliaGPU/Metal.jl
# device exceptions on Julia 1.11.)
#
# Narrowing one function makes its body forward an `addrspacecast`-from-specific to the
# functions it calls, exposing them in turn. We therefore iterate to a fixed point so a
# constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to
# another) regardless of the order functions are visited in. This terminates: each sweep
# that changes anything strictly reduces the number of generic pointer parameters in the
# module, and narrowing never introduces a new one.
# Narrowing one function makes its body forward an `addrspacecast`-from-specific (Case A) or, in
# the shim, a `ptrtoint(addrspacecast(...))` (Case B) to the functions it calls, exposing them
# in turn. We therefore iterate to a fixed point so a constant reaches an arbitrarily deep callee
# (e.g. an exception reporter that delegates to another) regardless of the order functions are
# visited in. This terminates: each sweep that changes anything strictly reduces the number of
# narrowable (generic-pointer or, in the shim, integer-image) parameters in the module, and
# narrowing never introduces a new one.

# If `v` is an `addrspacecast` (instruction or constant expression) of a pointer from a
# specific (non-generic) address space to the generic one, return that source pointer;
Expand All @@ -497,29 +506,48 @@ function addrspacecast_to_generic_source(@nospecialize(v))
return src
end

# The typed-pointer counterpart of `addrspacecast_to_generic_source`. With typed pointers a
# `Ptr` argument is lowered to an integer, so a specific-space pointer crossing a call boundary
# arrives as `ptrtoint(addrspacecast(<ptr in a specific space> -> generic))` rather than the
# bare cast. If `v` is that shape, return the specific-space source pointer; otherwise `nothing`.
# Typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. The typed-pointer
# counterpart of `addrspacecast_to_generic_source`: with typed pointers a `Ptr` argument is
# lowered to an integer, so a specific-space pointer crossing a call boundary arrives as
# `ptrtoint(addrspacecast(<ptr in a specific space> -> generic))` rather than the bare cast. If
# `v` is that shape, return the specific-space source pointer; otherwise `nothing`.
function ptrtoint_of_generic_source(@nospecialize(v))
    # only instructions and constant expressions have an opcode to inspect
    if !(v isa LLVM.Instruction) && !(v isa LLVM.ConstantExpr)
        return nothing
    end
    # anything other than a `ptrtoint` is not the integer image of a pointer
    if opcode(v) != LLVM.API.LLVMPtrToInt
        return nothing
    end
    # peel the `ptrtoint`; the pointer-form helper decides whether its operand is an
    # `addrspacecast` from a specific space (and yields `nothing` when it is not)
    casted = operands(v)[1]
    return addrspacecast_to_generic_source(casted)
end

# If integer parameter `arg` is consumed only by `inttoptr` (all uses agreeing on the result
# pointer type), return that type; otherwise `nothing`. Such a parameter is the integer image
# of a pointer that crossed the call boundary, and can be retargeted to a pointer the same way
# a generic pointer parameter is (see `propagate_argument_address_spaces!`, the integer case).
function integer_param_pointer_type(arg::LLVM.Argument)
# Typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. Classify integer
# parameter `arg` as the integer image of a pointer that crossed a call boundary, returning the
# generic pointer type to reconstruct it to (so it can be retargeted like a generic pointer
# parameter; see `propagate_argument_address_spaces!`), or `nothing` if it is not safely a
# pointer image. It qualifies when every use is either
# * an `inttoptr` to the generic space -- the leaf shape, where the body dereferences it (all
# such uses must agree on the result type, which pins the reconstructed pointee); or
# * a call argument -- the delegation shape, where the body forwards it on unchanged.
# A purely-forwarding parameter has no `inttoptr` to pin the pointee, so a canonical generic
# `i8*` is used: every boundary is a `bitcast`/`ptrtoint`, so the choice only affects the
# bridging casts, not the reconstructed value. Any other use (arithmetic, comparison, storing
# the integer, ...) means it is genuinely an integer, so it is left alone -- narrowing it would
# be value-preserving but pointless, and we have no pointee to reconstruct to.
function integer_param_pointer_image_type(arg::LLVM.Argument)
    ptrty = nothing     # result type shared by every `inttoptr` use seen so far
    forwarded = false   # set once any use passes the integer on as a call argument
    for use in uses(arg)
        u = user(use)
        if u isa LLVM.Instruction && opcode(u) == LLVM.API.LLVMIntToPtr
            # leaf shape: the body turns the integer back into a pointer. only a generic
            # (address space 0) result qualifies, and all such uses must agree on the type.
            t = value_type(u)
            (t isa LLVM.PointerType && addrspace(t) == 0) || return nothing
            if ptrty === nothing
                ptrty = t
            elseif ptrty != t
                return nothing
            end
        elseif u isa LLVM.CallInst
            # delegation shape: the integer is handed on to another call. this must be
            # checked per use rather than rejected up front, or a purely-forwarding
            # parameter would never be classified as a pointer image.
            forwarded = true
        else
            # any other use (arithmetic, comparison, store, ...): genuinely an integer
            return nothing
        end
    end
    # an `inttoptr` use pins the pointer type to reconstruct to
    ptrty !== nothing && return ptrty
    # forwarded only: nothing pins the pointee, so fall back to a generic `i8*`
    forwarded && return LLVM.PointerType(LLVM.Int8Type())
    # no uses at all: nothing to retarget
    return nothing
end

function propagate_argument_address_spaces!(mod::LLVM.Module)
Expand Down Expand Up @@ -566,8 +594,10 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module)
for (i, pty) in enumerate(param_types)
extract = if pty isa LLVM.PointerType && addrspace(pty) == 0
addrspacecast_to_generic_source
elseif pty isa LLVM.IntegerType &&
integer_param_pointer_type(parameters(f)[i]) !== nothing
# typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. under opaque
# pointers a `Ptr` stays a generic pointer (Case A above), so this only matters here.
elseif supports_typed_pointers(context()) && pty isa LLVM.IntegerType &&
integer_param_pointer_image_type(parameters(f)[i]) !== nothing
ptrtoint_of_generic_source
else
continue
Expand All @@ -583,8 +613,10 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module)
end
if as > 0
new_addrspaces[i] = as
# typed-pointer shim (Julia <= 1.11): an integer candidate only gets here under
# `supports_typed_pointers`, so record the pointer it reconstructs to (Case B).
pty isa LLVM.IntegerType &&
(int_ptr_types[i] = integer_param_pointer_type(parameters(f)[i]))
(int_ptr_types[i] = integer_param_pointer_image_type(parameters(f)[i]))
end
end
any(>=(0), new_addrspaces) || continue
Expand Down Expand Up @@ -625,8 +657,8 @@ function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst,
if new_addrspaces[i] < 0
push!(new_args, arg)
elseif int_ptr_types[i] !== nothing
# Case B: strip the `ptrtoint` then the cast, and bitcast the bare specific-space
# pointer to the retargeted parameter's pointer type
# Case B (typed-pointer shim): strip the `ptrtoint` then the cast, and bitcast the
# bare specific-space pointer to the retargeted parameter's pointer type
src = addrspacecast_to_generic_source(operands(arg)[1])
push!(new_args, bitcast!(builder, src, new_param_types[i]))
else
Expand Down Expand Up @@ -655,7 +687,8 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
supports_typed_pointers(context()) ? LLVM.PointerType(eltype(srcptr), as) :
LLVM.PointerType(as)
# the retargeted parameter type: for a pointer parameter (Case A) keep its pointee; for an
# integer parameter (Case B) use the pointee of the pointer its body reconstructs.
# integer parameter (Case B, typed-pointer shim) use the pointee of the pointer its body
# reconstructs (or the canonical `i8*` when it only forwards).
new_param_type(i, param_typ) =
new_addrspaces[i] < 0 ? param_typ :
int_ptr_types[i] !== nothing ?
Expand All @@ -682,9 +715,10 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,
if new_addrspaces[i] < 0
push!(new_args, parameters(new_f)[i])
elseif int_ptr_types[i] !== nothing
# Case B: rebuild the original integer as `ptrtoint(addrspacecast(param ->
# generic))`; the cloned body's `inttoptr` composes with it, and the following
# InferAddressSpaces run folds the whole chain to a specific-space load.
# Case B (typed-pointer shim): rebuild the original integer as `ptrtoint(
# addrspacecast(param -> generic))`. the cloned body either `inttoptr`s it (a
# leaf -- InferAddressSpaces then folds the whole chain to a specific-space load)
# or forwards it on (a delegator -- the next sweep narrows the callee in turn).
gen = addrspacecast!(builder, parameters(new_f)[i],
int_ptr_types[i]::LLVM.PointerType)
push!(new_args, ptrtoint!(builder, gen, param_typ))
Expand All @@ -704,8 +738,8 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function,

# `clone_into!` copies a parameter's attributes only when it maps to a new argument; the
# retargeted ones map to the entry cast instead, so theirs are dropped. Reattach them for
# Case A (still valid on the narrowed pointer). Skip Case B: those were integer attributes
# (e.g. `zeroext`) that are invalid on the now-pointer parameter.
# Case A (still valid on the narrowed pointer). Skip Case B (typed-pointer shim): those were
# integer attributes (e.g. `zeroext`) that are invalid on the now-pointer parameter.
for i in 1:length(new_addrspaces)
(new_addrspaces[i] >= 0 && int_ptr_types[i] === nothing) || continue
for attr in collect(parameter_attributes(f, i))
Expand Down
94 changes: 82 additions & 12 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,14 @@ end
@test (verify(mod); true)
end

# typed-pointer form: a `Ptr` argument crosses the boundary as an integer, so the callee
# takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast(<global> ->
# generic))`. the integer parameter is narrowed to a constant-space pointer just the same,
# and the call sites pass the bare global. (regression: the pass used to skip integer
# parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16 Metal
# bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.)
# typed-pointer shim form: a `Ptr` argument crosses the boundary as an integer, so the
# callee takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast(
# <global> -> generic))`. the integer parameter is narrowed to a constant-space pointer just
# the same, and the call sites pass the bare global. (regression: the pass used to skip
# integer parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16
# Metal bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.) the shim
# only applies under typed pointers; under opaque pointers a `Ptr` stays a generic pointer
# (Case A), so an integer parameter is deliberately left alone.
Context() do ctx
i8 = LLVM.Int8Type()
i64 = LLVM.Int64Type()
Expand All @@ -365,16 +367,22 @@ end
end
end

@test GPUCompiler.propagate_argument_address_spaces!(mod)
param = parameters(function_type(functions(mod)["callee"]))[1]
@test param isa LLVM.PointerType && addrspace(param) == 2
@test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType &&
addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee"))
if supports_typed_pointers()
@test GPUCompiler.propagate_argument_address_spaces!(mod)
param = parameters(function_type(functions(mod)["callee"]))[1]
@test param isa LLVM.PointerType && addrspace(param) == 2
@test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType &&
addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee"))
else
@test !GPUCompiler.propagate_argument_address_spaces!(mod)
@test parameters(function_type(functions(mod)["callee"]))[1] == i64
end
@test (verify(mod); true)
end

# an integer parameter used as more than a pointer image (here also in arithmetic) is
# left alone: narrowing it would lose the integer's other uses
# left alone: it is genuinely an integer, so there is no pointee to reconstruct it to (and
# under opaque pointers the typed-pointer shim is off, so it is left alone regardless)
Context() do ctx
i8 = LLVM.Int8Type()
i64 = LLVM.Int64Type()
Expand Down Expand Up @@ -448,6 +456,68 @@ end
@test param_as("leaf") == 2
@test (verify(mod); true)
end

# typed-pointer shim: the delegation chain in its integer form (caller -> mid -> leaf, all
# taking a `Ptr` lowered to an integer). `mid` only *forwards* its integer to `leaf`, so the
# leaf-shaped "all uses are inttoptr" test rejects it; recognizing a forwarded integer as a
# pointer image lets the fixpoint narrow `mid` first (its caller passes the recognizable
# `ptrtoint(addrspacecast(<global> -> generic))`), which re-exposes that shape at the call to
# `leaf` so the next sweep narrows it too. (regression: with only the leaf form, `mid` never
# narrowed and the deduced-name read stayed a generic-space load -- device exceptions crashed
# on Julia 1.11 even though it worked under opaque pointers.)
Context() do ctx
# hand-build a module with the chain caller -> mid -> leaf, then check how far the pass
# narrows the integer parameters, first after one sweep and then at the fixed point
i8 = LLVM.Int8Type()
i64 = LLVM.Int64Type()
mod = LLVM.Module("test")
# shared signature of `mid` and `leaf`: one i64 parameter, i8 result
int_ft = LLVM.FunctionType(i8, LLVM.LLVMType[i64])
# type of the first parameter of the named function; looked up by name on every call, so
# it can be re-queried after the pass has run
param_ty(name) = parameters(function_type(functions(mod)[name]))[1]

# leaf: inttoptrs its integer parameter and loads through it
# NOTE(review): `asptr` is defined outside this block; presumably `asptr(0)` is a pointer
# type in address space 0 (generic) -- confirm against the helper
leaf = LLVM.Function(mod, "leaf", int_ft)
linkage!(leaf, LLVM.API.LLVMInternalLinkage)
@dispose builder=IRBuilder() begin
position!(builder, BasicBlock(leaf, "entry"))
p = inttoptr!(builder, parameters(leaf)[1], asptr(0))
ret!(builder, load!(builder, i8, p))
end

# mid: forwards its integer parameter on to leaf (no inttoptr of its own)
mid = LLVM.Function(mod, "mid", int_ft)
linkage!(mid, LLVM.API.LLVMInternalLinkage)
@dispose builder=IRBuilder() begin
position!(builder, BasicBlock(mid, "entry"))
ret!(builder, call!(builder, int_ft, leaf, [parameters(mid)[1]]))
end

# caller: passes a constant global (AS 2) as ptrtoint(addrspacecast(... -> generic))
g = GlobalVariable(mod, i8, "g", 2)
initializer!(g, ConstantInt(i8, 1)); constant!(g, true)
caller = LLVM.Function(mod, "caller", LLVM.FunctionType(i8, LLVM.LLVMType[]))
linkage!(caller, LLVM.API.LLVMInternalLinkage)
@dispose builder=IRBuilder() begin
position!(builder, BasicBlock(caller, "entry"))
# the argument is a constant expression, not an instruction
arg = const_ptrtoint(const_addrspacecast(g, asptr(0)), i64)
ret!(builder, call!(builder, int_ft, mid, [arg]))
end

# the two branches assert opposite outcomes: narrowing with the shim on, no change with it off
if supports_typed_pointers()
# a single sweep narrows only the forwarding `mid`; the fixpoint must reach `leaf`
@test GPUCompiler.propagate_argument_address_spaces_once!(mod)
@test param_ty("mid") isa LLVM.PointerType && addrspace(param_ty("mid")) == 2
@test param_ty("leaf") == i64

# now run to the fixed point: `leaf` must end up narrowed as well
@test GPUCompiler.propagate_argument_address_spaces!(mod)
@test param_ty("leaf") isa LLVM.PointerType && addrspace(param_ty("leaf")) == 2
# callers end up passing the bare global, not a ptrtoint(addrspacecast(...))
# NOTE(review): `calls_to` is defined outside this block; presumably it collects the
# call instructions in `mod` that target the named function -- confirm
@test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType &&
addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "mid"))
else
# shim off under opaque pointers: both stay integers, nothing narrows
@test !GPUCompiler.propagate_argument_address_spaces!(mod)
@test param_ty("mid") == i64 && param_ty("leaf") == i64
end
# in either branch the module must still pass the LLVM verifier
@test (verify(mod); true)
end
end

end
Loading