diff --git a/src/metal.jl b/src/metal.jl index 44ebccb9..2db25522 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -466,6 +466,17 @@ end # `InferAddressSpaces` run folds the entry cast away. The source need not be a constant # global; any pointer with a known address space qualifies, so any back-end can run it. # +# With typed pointers (older LLVM) a `Ptr` argument is lowered to an integer rather than a +# pointer, so the same boundary crossing arrives as `ptrtoint(addrspacecast( -> +# generic))` at the call site and `inttoptr` in the body. Such an integer parameter is +# retargeted the same way: it becomes a pointer in the agreed space, the call sites pass the +# bare source pointer, and entry rebuilds the original integer as `ptrtoint(addrspacecast( +# param -> generic))`. The cloned body's `inttoptr` then composes with that and the same +# `InferAddressSpaces` run folds the whole chain to a specific-space load. (Without this the +# leftover `ptrtoint(addrspacecast(...))` constant feeds a generic-space load that, e.g., the +# LLVM-16 Metal bitcode downgrade miscompiles into an invalid metallib — JuliaGPU/Metal.jl +# device exceptions on Julia 1.11.) +# # Narrowing one function makes its body forward an `addrspacecast`-from-specific to the # functions it calls, exposing them in turn. We therefore iterate to a fixed point so a # constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to @@ -486,6 +497,31 @@ function addrspacecast_to_generic_source(@nospecialize(v)) return src end +# The typed-pointer counterpart of `addrspacecast_to_generic_source`. With typed pointers a +# `Ptr` argument is lowered to an integer, so a specific-space pointer crossing a call boundary +# arrives as `ptrtoint(addrspacecast( -> generic))` rather than the +# bare cast. If `v` is that shape, return the specific-space source pointer; otherwise `nothing`. +function ptrtoint_of_generic_source(@nospecialize(v)) + (v isa LLVM.Instruction || v isa LLVM.ConstantExpr) || return nothing + opcode(v) == LLVM.API.LLVMPtrToInt || return nothing + return addrspacecast_to_generic_source(operands(v)[1]) +end + +# If integer parameter `arg` is consumed only by `inttoptr` (all uses agreeing on the result +# pointer type), return that type; otherwise `nothing`. Such a parameter is the integer image +# of a pointer that crossed the call boundary, and can be retargeted to a pointer the same way +# a generic pointer parameter is (see `propagate_argument_address_spaces!`, the integer case). +function integer_param_pointer_type(arg::LLVM.Argument) + ptrty = nothing + for use in uses(arg) + u = user(use) + (u isa LLVM.Instruction && opcode(u) == LLVM.API.LLVMIntToPtr) || return nothing + t = value_type(u) + ptrty === nothing ? (ptrty = t) : (ptrty == t || return nothing) + end + return ptrty +end + function propagate_argument_address_spaces!(mod::LLVM.Module) changed = false while propagate_argument_address_spaces_once!(mod) @@ -521,24 +557,39 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module) end (only_calls && !isempty(callsites)) || continue - # for each generic pointer parameter, find the address space its callers agree on + # for each narrowable parameter, find the address space its callers agree on. a + # generic pointer parameter is passed as `addrspacecast( -> generic)`; with + # typed pointers a `Ptr` argument is instead an integer passed as `ptrtoint` of that + # cast (`int_ptr_types[i]` records the pointer it is reconstructed to, marking Case B). new_addrspaces = fill(-1, length(param_types)) + int_ptr_types = Vector{Any}(nothing, length(param_types)) for (i, pty) in enumerate(param_types) - (pty isa LLVM.PointerType && addrspace(pty) == 0) || continue + extract = if pty isa LLVM.PointerType && addrspace(pty) == 0 + addrspacecast_to_generic_source + elseif pty isa LLVM.IntegerType && + integer_param_pointer_type(parameters(f)[i]) !== nothing + ptrtoint_of_generic_source + else + continue + end as = -1 for cs in callsites - src = addrspacecast_to_generic_source(arguments(cs)[i]) + src = extract(arguments(cs)[i]) if src === nothing as = -1; break end src_as = addrspace(value_type(src)) as == -1 ? (as = src_as) : (as == src_as || (as = -1; break)) end - as > 0 && (new_addrspaces[i] = as) + if as > 0 + new_addrspaces[i] = as + pty isa LLVM.IntegerType && + (int_ptr_types[i] = integer_param_pointer_type(parameters(f)[i])) + end end any(>=(0), new_addrspaces) || continue - narrow_pointer_parameters!(mod, f, new_addrspaces, callsites) + narrow_pointer_parameters!(mod, f, new_addrspaces, int_ptr_types, callsites) changed = true end return changed @@ -566,11 +617,22 @@ end # convention, operand bundles and attributes; replaces and erases the old call. function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst, new_f::LLVM.Function, new_ft::LLVM.FunctionType, - new_addrspaces::Vector{Int}) + new_addrspaces::Vector{Int}, int_ptr_types::Vector{Any}) position!(builder, cs) - new_args = LLVM.Value[new_addrspaces[i] >= 0 ? - addrspacecast_to_generic_source(arg) : arg - for (i, arg) in enumerate(arguments(cs))] + new_param_types = parameters(new_ft) + new_args = LLVM.Value[] + for (i, arg) in enumerate(arguments(cs)) + if new_addrspaces[i] < 0 + push!(new_args, arg) + elseif int_ptr_types[i] !== nothing + # Case B: strip the `ptrtoint` then the cast, and bitcast the bare specific-space + # pointer to the retargeted parameter's pointer type + src = addrspacecast_to_generic_source(operands(arg)[1]) + push!(new_args, bitcast!(builder, src, new_param_types[i])) + else + push!(new_args, addrspacecast_to_generic_source(arg)) + end + end new_call = call!(builder, new_ft, new_f, new_args, operand_bundles(cs)) callconv!(new_call, callconv(cs)) copy_callsite_attributes!(new_call, cs) @@ -584,14 +646,22 @@ end # back to generic on entry so the cloned body is unchanged. Rewrite `callsites` to pass the # un-casted source value for each retargeted argument; recursive self-calls are handled too. function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, - new_addrspaces::Vector{Int}, callsites) + new_addrspaces::Vector{Int}, int_ptr_types::Vector{Any}, + callsites) ft = function_type(f) - retarget(pty::LLVM.PointerType, as::Integer) = - supports_typed_pointers(context()) ? LLVM.PointerType(eltype(pty), as) : + # retarget a pointer to address space `as`, taking its pointee from `srcptr` (only needed + # for typed pointers; `eltype` is invalid on opaque ones, so keep it lazy) + retarget(as::Integer, srcptr::LLVM.PointerType) = + supports_typed_pointers(context()) ? LLVM.PointerType(eltype(srcptr), as) : LLVM.PointerType(as) - new_types = LLVM.LLVMType[new_addrspaces[i] >= 0 ? - retarget(param_typ::LLVM.PointerType, new_addrspaces[i]) : - param_typ + # the retargeted parameter type: for a pointer parameter (Case A) keep its pointee; for an + # integer parameter (Case B) use the pointee of the pointer its body reconstructs. + new_param_type(i, param_typ) = + new_addrspaces[i] < 0 ? param_typ : + int_ptr_types[i] !== nothing ? + retarget(new_addrspaces[i], int_ptr_types[i]::LLVM.PointerType) : + retarget(new_addrspaces[i], param_typ::LLVM.PointerType) + new_types = LLVM.LLVMType[new_param_type(i, param_typ) for (i, param_typ) in enumerate(parameters(ft))] new_ft = LLVM.FunctionType(return_type(ft), new_types) @@ -609,10 +679,17 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, position!(builder, entry) new_args = LLVM.Value[] for (i, param_typ) in enumerate(parameters(ft)) - if new_addrspaces[i] >= 0 - push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ)) - else + if new_addrspaces[i] < 0 push!(new_args, parameters(new_f)[i]) + elseif int_ptr_types[i] !== nothing + # Case B: rebuild the original integer as `ptrtoint(addrspacecast(param -> + # generic))`; the cloned body's `inttoptr` composes with it, and the following + # InferAddressSpaces run folds the whole chain to a specific-space load. + gen = addrspacecast!(builder, parameters(new_f)[i], + int_ptr_types[i]::LLVM.PointerType) + push!(new_args, ptrtoint!(builder, gen, param_typ)) + else + push!(new_args, addrspacecast!(builder, parameters(new_f)[i], param_typ)) end end @@ -626,10 +703,11 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, end # `clone_into!` copies a parameter's attributes only when it maps to a new argument; the - # retargeted ones map to the entry addrspacecast instead, so theirs are dropped. Reattach - # them; they stay valid on the narrowed pointer, and non-retargeted params keep theirs. + # retargeted ones map to the entry cast instead, so theirs are dropped. Reattach them for + # Case A (still valid on the narrowed pointer). Skip Case B: those were integer attributes + # (e.g. `zeroext`) that are invalid on the now-pointer parameter. for i in 1:length(new_addrspaces) - new_addrspaces[i] >= 0 || continue + (new_addrspaces[i] >= 0 && int_ptr_types[i] === nothing) || continue for attr in collect(parameter_attributes(f, i)) push!(parameter_attributes(new_f, i), attr) end @@ -646,10 +724,10 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, # rewrite call sites to pass the un-casted source value for each retargeted argument @dispose builder=IRBuilder() begin for cs in callsites - rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces) + rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces, int_ptr_types) end for cs in self_calls - rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces) + rewrite_narrowed_call!(builder, cs, new_f, new_ft, new_addrspaces, int_ptr_types) end end diff --git a/test/metal.jl b/test/metal.jl index ae837ccd..43c6a575 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -335,6 +335,75 @@ end @test (verify(mod); true) end + # typed-pointer form: a `Ptr` argument crosses the boundary as an integer, so the callee + # takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast( -> + # generic))`. the integer parameter is narrowed to a constant-space pointer just the same, + # and the call sites pass the bare global. (regression: the pass used to skip integer + # parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16 Metal + # bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.) + Context() do ctx + i8 = LLVM.Int8Type() + i64 = LLVM.Int64Type() + mod = LLVM.Module("test") + callee_ft = LLVM.FunctionType(i8, LLVM.LLVMType[i64]) + callee = LLVM.Function(mod, "callee", callee_ft) + linkage!(callee, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(callee, "entry")) + p = inttoptr!(builder, parameters(callee)[1], asptr(0)) + ret!(builder, load!(builder, i8, p)) + end + for n in 1:2 + g = GlobalVariable(mod, i8, "g$n", 2) + initializer!(g, ConstantInt(i8, n)); constant!(g, true) + caller = LLVM.Function(mod, "caller$n", LLVM.FunctionType(i8, LLVM.LLVMType[])) + linkage!(caller, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(caller, "entry")) + arg = const_ptrtoint(const_addrspacecast(g, asptr(0)), i64) + ret!(builder, call!(builder, callee_ft, callee, [arg])) + end + end + + @test GPUCompiler.propagate_argument_address_spaces!(mod) + param = parameters(function_type(functions(mod)["callee"]))[1] + @test param isa LLVM.PointerType && addrspace(param) == 2 + @test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType && + addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee")) + @test (verify(mod); true) + end + + # an integer parameter used as more than a pointer image (here also in arithmetic) is + # left alone: narrowing it would lose the integer's other uses + Context() do ctx + i8 = LLVM.Int8Type() + i64 = LLVM.Int64Type() + mod = LLVM.Module("test") + callee_ft = LLVM.FunctionType(i8, LLVM.LLVMType[i64]) + callee = LLVM.Function(mod, "callee", callee_ft) + linkage!(callee, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(callee, "entry")) + p = inttoptr!(builder, parameters(callee)[1], asptr(0)) + v = load!(builder, i8, p) + # a second, non-`inttoptr` use of the integer parameter + extra = trunc!(builder, add!(builder, parameters(callee)[1], parameters(callee)[1]), i8) + ret!(builder, add!(builder, v, extra)) + end + g = GlobalVariable(mod, i8, "g", 2) + initializer!(g, ConstantInt(i8, 1)); constant!(g, true) + caller = LLVM.Function(mod, "caller", LLVM.FunctionType(i8, LLVM.LLVMType[])) + linkage!(caller, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(caller, "entry")) + arg = const_ptrtoint(const_addrspacecast(g, asptr(0)), i64) + ret!(builder, call!(builder, callee_ft, callee, [arg])) + end + + @test !GPUCompiler.propagate_argument_address_spaces!(mod) + @test parameters(function_type(functions(mod)["callee"]))[1] == i64 + end + # a two-level delegation chain (caller -> mid -> leaf) needs the fixpoint: one sweep # narrows `mid` (its caller passes a constant global), which only then exposes `leaf`, # since `mid` now forwards an addrspacecast-from-constant. iterate until both narrow.