diff --git a/Project.toml b/Project.toml index a448cb24..b7c25593 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUCompiler" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "1.15.1" +version = "1.15.2" authors = ["Tim Besard "] [workspace] diff --git a/src/metal.jl b/src/metal.jl index 2db25522..e217cbda 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -466,23 +466,32 @@ end # `InferAddressSpaces` run folds the entry cast away. The source need not be a constant # global; any pointer with a known address space qualifies, so any back-end can run it. # -# With typed pointers (older LLVM) a `Ptr` argument is lowered to an integer rather than a -# pointer, so the same boundary crossing arrives as `ptrtoint(addrspacecast(<specific> -> -# generic))` at the call site and `inttoptr` in the body. Such an integer parameter is -# retargeted the same way: it becomes a pointer in the agreed space, the call sites pass the -# bare source pointer, and entry rebuilds the original integer as `ptrtoint(addrspacecast( -# param -> generic))`. The cloned body's `inttoptr` then composes with that and the same -# `InferAddressSpaces` run folds the whole chain to a specific-space load. (Without this the -# leftover `ptrtoint(addrspacecast(...))` constant feeds a generic-space load that, e.g., the +# TYPED-POINTER SHIM (Julia <= 1.11) -- gated on `supports_typed_pointers(context())` and +# removable, along with everything else tagged "typed-pointer shim", once 1.12 is the minimum. +# With typed pointers a `Ptr` argument is lowered to an integer rather than a pointer, so the +# same boundary crossing arrives as `ptrtoint(addrspacecast(<specific> -> generic))` at the call +# site, and the body either `inttoptr`s the integer (a leaf reporter that dereferences it) or +# forwards it on to a further call (a reporter that delegates, e.g. `report_exception_name` -> +# `report_exception`).
Either way the integer is the image of a specific-space pointer, and is +# retargeted the same way as a generic pointer parameter: it becomes a pointer in the agreed +# space, the call sites pass the bare source pointer, and entry rebuilds the original integer as +# `ptrtoint(addrspacecast(param -> generic))`. This is sound for the same reason as the pointer +# case -- it only relocates a value-preserving cast/round-trip across the boundary, so every use +# (deref or forward) sees the bit-identical integer. A leaf's cloned `inttoptr` then composes +# with the rebuilt integer and `InferAddressSpaces` folds the chain to a specific-space load; a +# delegating function instead re-exposes the `ptrtoint(addrspacecast(...))` shape at its +# forwarded call, which the next sweep narrows in turn (see the fixed point below). (Without +# this the leftover `ptrtoint(addrspacecast(...))` feeds a generic-space load that, e.g., the # LLVM-16 Metal bitcode downgrade miscompiles into an invalid metallib — JuliaGPU/Metal.jl # device exceptions on Julia 1.11.) # -# Narrowing one function makes its body forward an `addrspacecast`-from-specific to the -# functions it calls, exposing them in turn. We therefore iterate to a fixed point so a -# constant reaches an arbitrarily deep callee (e.g. an exception reporter that delegates to -# another) regardless of the order functions are visited in. This terminates: each sweep -# that changes anything strictly reduces the number of generic pointer parameters in the -# module, and narrowing never introduces a new one. +# Narrowing one function makes its body forward an `addrspacecast`-from-specific (Case A) or, in +# the shim, a `ptrtoint(addrspacecast(...))` (Case B) to the functions it calls, exposing them +# in turn. We therefore iterate to a fixed point so a constant reaches an arbitrarily deep callee +# (e.g. an exception reporter that delegates to another) regardless of the order functions are +# visited in. 
This terminates: each sweep that changes anything strictly reduces the number of +# narrowable (generic-pointer or, in the shim, integer-image) parameters in the module, and +# narrowing never introduces a new one. # If `v` is an `addrspacecast` (instruction or constant expression) of a pointer from a # specific (non-generic) address space to the generic one, return that source pointer; @@ -497,29 +506,48 @@ function addrspacecast_to_generic_source(@nospecialize(v)) return src end -# The typed-pointer counterpart of `addrspacecast_to_generic_source`. With typed pointers a -# `Ptr` argument is lowered to an integer, so a specific-space pointer crossing a call boundary -# arrives as `ptrtoint(addrspacecast(<specific-space pointer> -> generic))` rather than the -# bare cast. If `v` is that shape, return the specific-space source pointer; otherwise `nothing`. +# Typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. The typed-pointer +# counterpart of `addrspacecast_to_generic_source`: with typed pointers a `Ptr` argument is +# lowered to an integer, so a specific-space pointer crossing a call boundary arrives as +# `ptrtoint(addrspacecast(<specific-space pointer> -> generic))` rather than the bare cast. If +# `v` is that shape, return the specific-space source pointer; otherwise `nothing`. function ptrtoint_of_generic_source(@nospecialize(v)) (v isa LLVM.Instruction || v isa LLVM.ConstantExpr) || return nothing opcode(v) == LLVM.API.LLVMPtrToInt || return nothing return addrspacecast_to_generic_source(operands(v)[1]) end -# If integer parameter `arg` is consumed only by `inttoptr` (all uses agreeing on the result -# pointer type), return that type; otherwise `nothing`. Such a parameter is the integer image -# of a pointer that crossed the call boundary, and can be retargeted to a pointer the same way -# a generic pointer parameter is (see `propagate_argument_address_spaces!`, the integer case).
-function integer_param_pointer_type(arg::LLVM.Argument) +# Typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. Classify integer +# parameter `arg` as the integer image of a pointer that crossed a call boundary, returning the +# generic pointer type to reconstruct it to (so it can be retargeted like a generic pointer +# parameter; see `propagate_argument_address_spaces!`), or `nothing` if it is not safely a +# pointer image. It qualifies when every use is either +# * an `inttoptr` to the generic space -- the leaf shape, where the body dereferences it (all +# such uses must agree on the result type, which pins the reconstructed pointee); or +# * a call argument -- the delegation shape, where the body forwards it on unchanged. +# A purely-forwarding parameter has no `inttoptr` to pin the pointee, so a canonical generic +# `i8*` is used: every boundary is a `bitcast`/`ptrtoint`, so the choice only affects the +# bridging casts, not the reconstructed value. Any other use (arithmetic, comparison, storing +# the integer, ...) means it is genuinely an integer, so it is left alone -- narrowing it would +# be value-preserving but pointless, and we have no pointee to reconstruct to. +function integer_param_pointer_image_type(arg::LLVM.Argument) ptrty = nothing + forwarded = false for use in uses(arg) u = user(use) - (u isa LLVM.Instruction && opcode(u) == LLVM.API.LLVMIntToPtr) || return nothing - t = value_type(u) - ptrty === nothing ? (ptrty = t) : (ptrty == t || return nothing) + if u isa LLVM.Instruction && opcode(u) == LLVM.API.LLVMIntToPtr + t = value_type(u) + (t isa LLVM.PointerType && addrspace(t) == 0) || return nothing + ptrty === nothing ? 
(ptrty = t) : (ptrty == t || return nothing) + elseif u isa LLVM.CallInst + forwarded = true + else + return nothing + end end - return ptrty + ptrty !== nothing && return ptrty + forwarded && return LLVM.PointerType(LLVM.Int8Type()) + return nothing end function propagate_argument_address_spaces!(mod::LLVM.Module) @@ -566,8 +594,10 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module) for (i, pty) in enumerate(param_types) extract = if pty isa LLVM.PointerType && addrspace(pty) == 0 addrspacecast_to_generic_source - elseif pty isa LLVM.IntegerType && - integer_param_pointer_type(parameters(f)[i]) !== nothing + # typed-pointer shim (Julia <= 1.11) -- remove once 1.12 is the minimum. under opaque + # pointers a `Ptr` stays a generic pointer (Case A above), so this only matters here. + elseif supports_typed_pointers(context()) && pty isa LLVM.IntegerType && + integer_param_pointer_image_type(parameters(f)[i]) !== nothing ptrtoint_of_generic_source else continue @@ -583,8 +613,10 @@ function propagate_argument_address_spaces_once!(mod::LLVM.Module) end if as > 0 new_addrspaces[i] = as + # typed-pointer shim (Julia <= 1.11): an integer candidate only gets here under + # `supports_typed_pointers`, so record the pointer it reconstructs to (Case B). 
pty isa LLVM.IntegerType && - (int_ptr_types[i] = integer_param_pointer_type(parameters(f)[i])) + (int_ptr_types[i] = integer_param_pointer_image_type(parameters(f)[i])) end end any(>=(0), new_addrspaces) || continue @@ -625,8 +657,8 @@ function rewrite_narrowed_call!(builder::IRBuilder, cs::LLVM.CallInst, if new_addrspaces[i] < 0 push!(new_args, arg) elseif int_ptr_types[i] !== nothing - # Case B: strip the `ptrtoint` then the cast, and bitcast the bare specific-space - # pointer to the retargeted parameter's pointer type + # Case B (typed-pointer shim): strip the `ptrtoint` then the cast, and bitcast the + # bare specific-space pointer to the retargeted parameter's pointer type src = addrspacecast_to_generic_source(operands(arg)[1]) push!(new_args, bitcast!(builder, src, new_param_types[i])) else @@ -655,7 +687,8 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, supports_typed_pointers(context()) ? LLVM.PointerType(eltype(srcptr), as) : LLVM.PointerType(as) # the retargeted parameter type: for a pointer parameter (Case A) keep its pointee; for an - # integer parameter (Case B) use the pointee of the pointer its body reconstructs. + # integer parameter (Case B, typed-pointer shim) use the pointee of the pointer its body + # reconstructs (or the canonical `i8*` when it only forwards). new_param_type(i, param_typ) = new_addrspaces[i] < 0 ? param_typ : int_ptr_types[i] !== nothing ? @@ -682,9 +715,10 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, if new_addrspaces[i] < 0 push!(new_args, parameters(new_f)[i]) elseif int_ptr_types[i] !== nothing - # Case B: rebuild the original integer as `ptrtoint(addrspacecast(param -> - # generic))`; the cloned body's `inttoptr` composes with it, and the following - # InferAddressSpaces run folds the whole chain to a specific-space load. + # Case B (typed-pointer shim): rebuild the original integer as `ptrtoint( + # addrspacecast(param -> generic))`. 
the cloned body either `inttoptr`s it (a + # leaf -- InferAddressSpaces then folds the whole chain to a specific-space load) + # or forwards it on (a delegator -- the next sweep narrows the callee in turn). gen = addrspacecast!(builder, parameters(new_f)[i], int_ptr_types[i]::LLVM.PointerType) push!(new_args, ptrtoint!(builder, gen, param_typ)) @@ -704,8 +738,8 @@ function narrow_pointer_parameters!(mod::LLVM.Module, f::LLVM.Function, # `clone_into!` copies a parameter's attributes only when it maps to a new argument; the # retargeted ones map to the entry cast instead, so theirs are dropped. Reattach them for - # Case A (still valid on the narrowed pointer). Skip Case B: those were integer attributes - # (e.g. `zeroext`) that are invalid on the now-pointer parameter. + # Case A (still valid on the narrowed pointer). Skip Case B (typed-pointer shim): those were + # integer attributes (e.g. `zeroext`) that are invalid on the now-pointer parameter. for i in 1:length(new_addrspaces) (new_addrspaces[i] >= 0 && int_ptr_types[i] === nothing) || continue for attr in collect(parameter_attributes(f, i)) diff --git a/test/metal.jl b/test/metal.jl index 43c6a575..396fadd9 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -335,12 +335,14 @@ end @test (verify(mod); true) end - # typed-pointer form: a `Ptr` argument crosses the boundary as an integer, so the callee - # takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast(<global> -> - # generic))`. the integer parameter is narrowed to a constant-space pointer just the same, - # and the call sites pass the bare global. (regression: the pass used to skip integer - # parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16 Metal - # bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.)
+ # typed-pointer shim form: a `Ptr` argument crosses the boundary as an integer, so the + # callee takes an integer it `inttoptr`s and the callers pass `ptrtoint(addrspacecast(<global> + # -> generic))`. the integer parameter is narrowed to a constant-space pointer just + # the same, and the call sites pass the bare global. (regression: the pass used to skip + # integer parameters, leaving a `ptrtoint(addrspacecast(...))` constant that the LLVM-16 + # Metal bitcode downgrade miscompiles -- device exceptions crashed on Julia 1.11.) the shim + # only applies under typed pointers; under opaque pointers a `Ptr` stays a generic pointer + # (Case A), so an integer parameter is deliberately left alone. Context() do ctx i8 = LLVM.Int8Type() i64 = LLVM.Int64Type() @@ -365,16 +367,22 @@ end end end - @test GPUCompiler.propagate_argument_address_spaces!(mod) - param = parameters(function_type(functions(mod)["callee"]))[1] - @test param isa LLVM.PointerType && addrspace(param) == 2 - @test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType && - addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee")) + if supports_typed_pointers() + @test GPUCompiler.propagate_argument_address_spaces!(mod) + param = parameters(function_type(functions(mod)["callee"]))[1] + @test param isa LLVM.PointerType && addrspace(param) == 2 + @test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType && + addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "callee")) + else + @test !GPUCompiler.propagate_argument_address_spaces!(mod) + @test parameters(function_type(functions(mod)["callee"]))[1] == i64 + end @test (verify(mod); true) end # an integer parameter used as more than a pointer image (here also in arithmetic) is - # left alone: narrowing it would lose the integer's other uses + # left alone: narrowing it would lose the integer's other uses (and under opaque pointers + # the typed-pointer shim is off, so it is left alone regardless) Context() do ctx i8 = LLVM.Int8Type()
i64 = LLVM.Int64Type() @@ -448,6 +456,68 @@ end @test param_as("leaf") == 2 @test (verify(mod); true) end + + # typed-pointer shim: the delegation chain in its integer form (caller -> mid -> leaf, all + # taking a `Ptr` lowered to an integer). `mid` only *forwards* its integer to `leaf`, so the + # leaf-shaped "all uses are inttoptr" test rejects it; recognizing a forwarded integer as a + # pointer image lets the fixpoint narrow `mid` first (its caller passes the recognizable + # `ptrtoint(addrspacecast(<global> -> generic))`), which re-exposes that shape at the call to + # `leaf` so the next sweep narrows it too. (regression: with only the leaf form, `mid` never + # narrowed and the deduced-name read stayed a generic-space load -- device exceptions crashed + # on Julia 1.11 even though it worked under opaque pointers.) + Context() do ctx + i8 = LLVM.Int8Type() + i64 = LLVM.Int64Type() + mod = LLVM.Module("test") + int_ft = LLVM.FunctionType(i8, LLVM.LLVMType[i64]) + param_ty(name) = parameters(function_type(functions(mod)[name]))[1] + + # leaf: inttoptrs its integer parameter and loads through it + leaf = LLVM.Function(mod, "leaf", int_ft) + linkage!(leaf, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(leaf, "entry")) + p = inttoptr!(builder, parameters(leaf)[1], asptr(0)) + ret!(builder, load!(builder, i8, p)) + end + + # mid: forwards its integer parameter on to leaf (no inttoptr of its own) + mid = LLVM.Function(mod, "mid", int_ft) + linkage!(mid, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(mid, "entry")) + ret!(builder, call!(builder, int_ft, leaf, [parameters(mid)[1]])) + end + + # caller: passes a constant global (AS 2) as ptrtoint(addrspacecast(...
-> generic)) + g = GlobalVariable(mod, i8, "g", 2) + initializer!(g, ConstantInt(i8, 1)); constant!(g, true) + caller = LLVM.Function(mod, "caller", LLVM.FunctionType(i8, LLVM.LLVMType[])) + linkage!(caller, LLVM.API.LLVMInternalLinkage) + @dispose builder=IRBuilder() begin + position!(builder, BasicBlock(caller, "entry")) + arg = const_ptrtoint(const_addrspacecast(g, asptr(0)), i64) + ret!(builder, call!(builder, int_ft, mid, [arg])) + end + + if supports_typed_pointers() + # a single sweep narrows only the forwarding `mid`; the fixpoint must reach `leaf` + @test GPUCompiler.propagate_argument_address_spaces_once!(mod) + @test param_ty("mid") isa LLVM.PointerType && addrspace(param_ty("mid")) == 2 + @test param_ty("leaf") == i64 + + @test GPUCompiler.propagate_argument_address_spaces!(mod) + @test param_ty("leaf") isa LLVM.PointerType && addrspace(param_ty("leaf")) == 2 + # callers end up passing the bare global, not a ptrtoint(addrspacecast(...)) + @test all(c -> value_type(arguments(c)[1]) isa LLVM.PointerType && + addrspace(value_type(arguments(c)[1])) == 2, calls_to(mod, "mid")) + else + # shim off under opaque pointers: both stay integers, nothing narrows + @test !GPUCompiler.propagate_argument_address_spaces!(mod) + @test param_ty("mid") == i64 && param_ty("leaf") == i64 + end + @test (verify(mod); true) + end end end