Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

add_argument_metadata!(job, mod, entry)

add_globals_metadata!(job, mod)

add_module_metadata!(job, mod)
end

Expand Down Expand Up @@ -910,6 +912,87 @@ function argument_type_name(typ)
end
end

# global metadata generation
#
# module metadata is used to identify global buffers that are used as kernel arguments.
function add_globals_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
# Iterate through arguments and create metadata for them
globs = globals(mod)

i = 1
for gv in globs
gv_typ = global_value_type(gv)
(isconstant(gv) && gv_typ isa LLVM.PointerType && addrspace(gv_typ) == 3) || continue
# if job.config.optimize
# @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType
# else
# parameters(entry_ft)[arg.idx] isa LLVM.PointerType || continue
# end

# # NOTE: we emit the bare minimum of argument metadata to support
# # bindless argument encoding. Actually using the argument encoder
# # APIs (deprecated in Metal 3) turned out too difficult, given the
# # undocumented nature of the argument metadata, and the complex
# # arguments we encounter with typical Julia kernels.
global_infos = Metadata[]

push!(global_infos, MDString("air.global_binding"))
push!(global_infos, Metadata(gv))

md = Metadata[]

# argument index
push!(md, Metadata(ConstantInt(Int32(-1))))

push!(md, MDString("air.buffer"))

push!(md, MDString("air.location_index"))
push!(md, Metadata(ConstantInt(Int32(i-1))))

# XXX: unknown
push!(md, Metadata(ConstantInt(Int32(1))))

push!(md, MDString("air.read_write")) # TODO: Check for const array

push!(md, MDString("air.address_space"))
push!(md, Metadata(ConstantInt(Int32(addrspace(global_value_type(gv))))))

# val_type = global_value_type(gv)
# val_type = if value_type(gv) <: Core.LLVMPtr
# arg.typ.parameters[1]
# else
# arg.typ
# end

# @show gv_typ
# @show isconstant(gv)
# @show isconstant(gv_typ)
# @show Int32(alignment(gv))

push!(md, MDString("air.arg_type_size"))
push!(md, Metadata(ConstantInt(Int32(4))))

push!(md, MDString("air.arg_type_align_size"))
push!(md, Metadata(ConstantInt(Int32(alignment(gv)))))

push!(md, MDString("air.arg_type_name"))
# XXX: Figure out how to get type
push!(md, MDString("float"))
# push!(md, MDString(repr(arg.typ)))

push!(md, MDString("air.arg_name"))
push!(md, MDString(String(LLVM.name(gv))))

push!(global_infos, MDNode(md))

push!(metadata(mod)["air.global_bindings"], MDNode(global_infos))

i += 1
end

return
end

# argument metadata generation
#
# module metadata is used to identify buffers that are passed as kernel arguments.
Expand All @@ -925,7 +1008,7 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize)
i = 1
for arg in args
arg.idx === nothing && continue
arg.idx === nothing && continue
if job.config.optimize
@assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType
else
Expand Down
Loading