From 26ea6329e90cd26120be1316179278f16a918945 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:00:54 -0400 Subject: [PATCH] Faster matmul --- Project.toml | 2 +- src/GPUArrays.jl | 1 + src/host/linalg.jl | 77 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index bb6e0219..9f04923e 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ JLD2Ext = "JLD2" Adapt = "4.0" GPUArraysCore = "= 0.2.0" JLD2 = "0.4, 0.5, 0.6" -KernelAbstractions = "0.9.28, 0.10" +KernelAbstractions = "0.10" LLVM = "3.9, 4, 5, 6, 7, 8, 9" LinearAlgebra = "1" Printf = "1" diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 17a4d897..ffde2025 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -16,6 +16,7 @@ using Reexport @reexport using GPUArraysCore using KernelAbstractions +import KernelAbstractions.KernelIntrinsics as KI # device functionality include("device/abstractarray.jl") diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a46b85d2..532c506a 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -436,7 +436,6 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat, B end - ## matrix multiplication # legacy method generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) = @@ -470,6 +469,82 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end +# Higher performance kernel for Matrices +# XXX: figure out how to do dynamically +MAX_TILE_DIM = 16 +function coalesced_matmul_kernel!( + output, input1, input2, N, Q, M, + add::MulAddMul, ::Val{TILE_DIM}, ::Val{BANK} = Val(1), + ) where {TILE_DIM, BANK} + R = eltype(output) + + grow, gcol, _ = KI.get_group_id() + tile_row, tile_col, _ = KI.get_local_id() + + I = (grow - 1) * TILE_DIM + tile_row + J = (gcol - 1) * TILE_DIM + tile_col + + # +1 to avoid bank conflicts on shared memory + tile1 = KI.localmemory(R, (TILE_DIM + BANK, TILE_DIM)) + tile2 = KI.localmemory(R, (TILE_DIM + BANK, TILE_DIM)) + + # private variable for tile output + outval = -zero(R) + + # number of tiles depends on inner dimension + NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM) + + # loop over all tiles needed for this calculation + for t in 0:(NUM_TILES - 1) + # load inputs into tiles, with bounds checking for non-square matrices + if I <= N && t * TILE_DIM + tile_col <= Q + @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col] + else + @inbounds tile1[tile_row, tile_col] = zero(R) + end + if J <= M && t * TILE_DIM + tile_row <= Q + @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J] + else + @inbounds tile2[tile_row, tile_col] = zero(R) + end + + # wait for all tiles to be loaded + KI.barrier() + + # calculate value of spot in output, use temporary value to allow for vectorization + out = zero(R) + @simd for k in 1:TILE_DIM + @inbounds out += tile1[tile_row, k] * tile2[k, tile_col] + end + outval += out + + KI.barrier() + end + + # save if inbounds + if I <= N && J <= M + @inbounds output[I, J] = add(outval, output[I, J]) + end + return +end +function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number} + N = size(A,1) + Q = size(A,2) + M = size(B,2) + if Q != size(B,1) + throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != N || size(C,2) != M + throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + KI.@kernel get_backend(C) workgroupsize=(MAX_TILE_DIM, MAX_TILE_DIM) numworkgroups=map(x -> cld(x, MAX_TILE_DIM), size(C)) coalesced_matmul_kernel!(C, A, B, N, Q, M, add, Val(MAX_TILE_DIM)) + C +end + @static if !isdefined(LinearAlgebra, Symbol("@stable_muladdmul")) # @stable_muladdmul was added in 1.12 function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) generic_matmatmul!(C, wrap(A, tA), B, _add)