From c15b97fcec71a2d0df21d5db38fb5ce2bd2a6e72 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 2 Jun 2026 13:28:31 -0400 Subject: [PATCH 1/9] Support and test mul(C, A, B, alpha, beta) for triangular mats --- src/host/linalg.jl | 33 ++++++++++++++++++++++++++++++++- test/testsuite/linalg.jl | 13 +++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a46b85d2..eb630902 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -1,6 +1,6 @@ # integration with LinearAlgebra stdlib -using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal +using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal, AbstractTriangular ## transpose and adjoint @@ -518,6 +518,37 @@ end @. C = s * X * alpha + C * beta return C end + function LinearAlgebra._generic_matmatmul_nonadjtrans!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) where {R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} + upperB = B isa Union{UnitUpperTriangular, UpperTriangular} + # this function is ONLY reached if beta is not zero + @kernel function trimatmul(C, A, B, alpha, beta) + idx = @index(Global, Linear) + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + Cij = beta * C[i,j] + Cij += A[i,i] * B[i,j] * alpha + for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) + Cij += alpha * A[i,k] * B[k,j] + end + C[i,j] = Cij + end + end + trimatmul(get_backend(C))(C, A, B, alpha, beta; ndrange = length(C)) + return C + end end function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 8875dd54..4efd36ab 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -185,6 +185,19 @@ mul!(C, collect(A), f(TR(collect(B)))) @test collect(Ct) ≈ C end + @testset "matmul! with nonzero β ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + if !(T in eltypes) + continue + end + n = 128 + A = AT(rand(T, n, n)) + B = AT(rand(T, n, n)) + Ct = AT(rand(T, n, n)) + C = collect(Ct) + mul!(Ct, TR1(A), TR2(B), 1, -1) + mul!(C, TR1(collect(A)), TR2(collect(B)), 1, -1) + @test collect(Ct) ≈ C + end end end From fb645e43933f923ab41d1cd56be665454828316a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 04:54:24 -0400 Subject: [PATCH 2/9] Add support for older Julias too --- src/host/linalg.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index eb630902..81e12f07 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -551,6 +551,40 @@ end end end +@static if VERSION < v"1.12.0-rc" + function LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat{R}, tA::Char, tB::Char, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) where {R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} + upperB = B isa Union{UnitUpperTriangular, UpperTriangular} + @kernel function trimatmul(C, A, B) + idx = @index(Global, Linear) + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,i] * B[i,j] + for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) + Cij += A[i,k] * B[k,j] + end + C[i,j] = add(Cij, C[i, j]) + end + end + trimatmul(get_backend(C))(C, A, B; ndrange = length(C)) + return C + end +end + function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} if size(A,2) != size(B,1) throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) From 2034e0565a2a75a2a85c3dcac4a7ddb653a3a319 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 05:36:09 -0400 Subject: [PATCH 3/9] Add istriu/istril for JLArrays and use in methods --- lib/JLArrays/src/JLArrays.jl | 10 ++++++++++ src/host/linalg.jl | 38 +++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 4d2dad41..a608f6bd 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -451,6 +451,16 @@ Adapt.adapt_storage(::Type{<:JLArray{T}}, xs::AbstractArray) where {T} = # adapt back to the CPU Adapt.adapt_storage(::Type{Array}, xs::JLArray) = convert(Array, xs) +## triangular checks +LinearAlgebra.istriu(A::LinearAlgebra.UnitUpperTriangular{T, JLMatrix{T}}) where {T} = true +LinearAlgebra.istriu(A::LinearAlgebra.UpperTriangular{T, JLMatrix{T}}) where {T} = true +LinearAlgebra.istril(A::LinearAlgebra.UnitUpperTriangular{T, JLMatrix{T}}) where {T} = false +LinearAlgebra.istril(A::LinearAlgebra.UpperTriangular{T, JLMatrix{T}}) where {T} = false + +LinearAlgebra.istriu(A::LinearAlgebra.UnitLowerTriangular{T, JLMatrix{T}}) where {T} = false +LinearAlgebra.istriu(A::LinearAlgebra.LowerTriangular{T, JLMatrix{T}}) where {T} = false +LinearAlgebra.istril(A::LinearAlgebra.UnitLowerTriangular{T, JLMatrix{T}}) where {T} = true +LinearAlgebra.istril(A::LinearAlgebra.LowerTriangular{T, JLMatrix{T}}) where {T} = true ## conversions diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 81e12f07..88e7e8ab 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -551,8 +551,44 @@ end end end -@static if VERSION < v"1.12.0-rc" +@static if VERSION < v"1.11.0-rc" function LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat{R}, tA::Char, tB::Char, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) where {R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + wA = LinearAlgebra.wrap(A, tA) + wB = LinearAlgebra.wrap(B, tB) + upperA = istriu(A) + upperB = istriu(B) + @kernel function trimatmul(C, A, B) + idx = @index(Global, Linear) + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,i] * B[i,j] + for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) + Cij += A[i,k] * B[k,j] + end + C[i,j] = add(Cij, C[i, j]) + end + end + trimatmul(get_backend(C))(C, A, B; ndrange = length(C)) + return C + end +end + +@static if v"1.11.0-rc" < VERSION < v"1.12.0-rc" + function LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) where {R} if size(A,2) != size(B,1) throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) end From d8d5db0775142fa452f84a390c817db6f451e122 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 14:42:22 +0200 Subject: [PATCH 4/9] Generalize triangular istriu/istril to all GPU arrays. The redundant `=true` definitions are dropped (base already short-circuits the matching direction), and the load-bearing opposite-triangle queries move from JLArrays to core, keyed on the GPU array parent so every backend gets them. The matmul kernels now use istriu/istril uniformly. Co-Authored-By: Claude Opus 4.8 (1M context) --- lib/JLArrays/src/JLArrays.jl | 11 ----------- src/host/linalg.jl | 15 +++++++++++---- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index a608f6bd..0efc5a5e 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -451,17 +451,6 @@ Adapt.adapt_storage(::Type{<:JLArray{T}}, xs::AbstractArray) where {T} = # adapt back to the CPU Adapt.adapt_storage(::Type{Array}, xs::JLArray) = convert(Array, xs) -## triangular checks -LinearAlgebra.istriu(A::LinearAlgebra.UnitUpperTriangular{T, JLMatrix{T}}) where {T} = true -LinearAlgebra.istriu(A::LinearAlgebra.UpperTriangular{T, JLMatrix{T}}) where {T} = true -LinearAlgebra.istril(A::LinearAlgebra.UnitUpperTriangular{T, JLMatrix{T}}) where {T} = false -LinearAlgebra.istril(A::LinearAlgebra.UpperTriangular{T, JLMatrix{T}}) where {T} = false - -LinearAlgebra.istriu(A::LinearAlgebra.UnitLowerTriangular{T, JLMatrix{T}}) where {T} = false -LinearAlgebra.istriu(A::LinearAlgebra.LowerTriangular{T, JLMatrix{T}}) where {T} = false -LinearAlgebra.istril(A::LinearAlgebra.UnitLowerTriangular{T, JLMatrix{T}}) where {T} = true -LinearAlgebra.istril(A::LinearAlgebra.LowerTriangular{T, JLMatrix{T}}) where {T} = true - ## conversions Base.convert(::Type{T}, x::T) where T <: JLArray = x diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 88e7e8ab..586dac8f 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -264,6 +264,13 @@ function LinearAlgebra.istril(A::AbstractGPUMatrix, k::Integer = 0) mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true) end +# the structure of a triangular wrapper is encoded in its type, so we can answer the +# opposite-triangle queries in O(1) instead of scanning (which would scalar-index the parent) +LinearAlgebra.istriu(::LowerTriangular{<:Any, <:AbstractGPUMatrix}) = false +LinearAlgebra.istriu(::UnitLowerTriangular{<:Any, <:AbstractGPUMatrix}) = false +LinearAlgebra.istril(::UpperTriangular{<:Any, <:AbstractGPUMatrix}) = false +LinearAlgebra.istril(::UnitUpperTriangular{<:Any, <:AbstractGPUMatrix}) = false + ## diagonal @@ -528,8 +535,8 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - upperA = A isa Union{UnitUpperTriangular, UpperTriangular} - upperB = B isa Union{UnitUpperTriangular, UpperTriangular} + upperA = istriu(A) + upperB = istriu(B) # this function is ONLY reached if beta is not zero @kernel function trimatmul(C, A, B, alpha, beta) idx = @index(Global, Linear) @@ -598,8 +605,8 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - upperA = A isa Union{UnitUpperTriangular, UpperTriangular} - upperB = B isa Union{UnitUpperTriangular, UpperTriangular} + upperA = istriu(A) + upperB = istriu(B) @kernel function trimatmul(C, A, B) idx = @index(Global, Linear) assume.(size(C) .> 0) From 350191ad574c4bc8fa187f7cccee4d49a1413f26 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 14:49:45 +0200 Subject: [PATCH 5/9] =?UTF-8?q?Treat=20C=20as=20write-only=20in=20triangul?= =?UTF-8?q?ar=20matmul=20when=20=CE=B2=3D0.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `mul!(C, A, B, α, 0)` with α≠1 reaches the kernel, where `β * C[i,j]` turned an uninitialized NaN/Inf into NaN. Guard the β term and add a regression test that initializes C with NaN. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/host/linalg.jl | 4 ++-- test/testsuite/linalg.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 586dac8f..fa0d1cc9 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -537,7 +537,6 @@ end end upperA = istriu(A) upperB = istriu(B) - # this function is ONLY reached if beta is not zero @kernel function trimatmul(C, A, B, alpha, beta) idx = @index(Global, Linear) assume.(size(C) .> 0) @@ -545,7 +544,8 @@ end l, m, n = size(A, 1), size(B, 1), size(B, 2) @inbounds if i <= l && j <= n - Cij = beta * C[i,j] + # treat C as write-only when beta is zero (it may hold NaN/Inf) + Cij = iszero(beta) ? zero(R) : beta * C[i,j] Cij += A[i,i] * B[i,j] * alpha for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) Cij += alpha * A[i,k] * B[k,j] diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 4efd36ab..8307ecf8 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -198,6 +198,20 @@ mul!(C, TR1(collect(A)), TR2(collect(B)), 1, -1) @test collect(Ct) ≈ C end + @testset "matmul! is write-only for β=0 ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + if !(T in eltypes) + continue + end + n = 128 + A = AT(rand(T, n, n)) + B = AT(rand(T, n, n)) + # C starts as NaN: with β=0 it must be ignored (not multiplied), α≠1 to reach this path + Ct = AT(fill(T(NaN), n, n)) + C = collect(Ct) + mul!(Ct, TR1(A), TR2(B), 2, 0) + mul!(C, TR1(collect(A)), TR2(collect(B)), 2, 0) + @test collect(Ct) ≈ C + end end end From 6523615992ccd0e202a91d39fb1c71cc2b772be7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 14:51:28 +0200 Subject: [PATCH 6/9] Remove dead variables from triangular matmul kernels. upperB was never read (B's structure comes from its getindex), and the 1.10 kernel's wA/wB wraps were computed but never passed to the kernel. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/host/linalg.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index fa0d1cc9..61ab8997 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -536,7 +536,6 @@ end return fill!(C, zero(R)) end upperA = istriu(A) - upperB = istriu(B) @kernel function trimatmul(C, A, B, alpha, beta) idx = @index(Global, Linear) assume.(size(C) .> 0) @@ -569,10 +568,7 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - wA = LinearAlgebra.wrap(A, tA) - wB = LinearAlgebra.wrap(B, tB) upperA = istriu(A) - upperB = istriu(B) @kernel function trimatmul(C, A, B) idx = @index(Global, Linear) assume.(size(C) .> 0) @@ -606,7 +602,6 @@ end return fill!(C, zero(R)) end upperA = istriu(A) - upperB = istriu(B) @kernel function trimatmul(C, A, B) idx = @index(Global, Linear) assume.(size(C) .> 0) From 6980ddb14cec406fadbb7e41443b0e16771d63fb Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 14:56:57 +0200 Subject: [PATCH 7/9] Determine triangle from wrapper type, not istriu/istril. The istriu/istril overrides were unsound: UpperTriangular(Diagonal(...)) is also lower triangular, so a type-only `istril = false` lies. The kernels only need the wrapper's structural guarantee, which a plain `isa` check gives directly and without scalar indexing. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/host/linalg.jl | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 61ab8997..626c8494 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -264,13 +264,6 @@ function LinearAlgebra.istril(A::AbstractGPUMatrix, k::Integer = 0) mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true) end -# the structure of a triangular wrapper is encoded in its type, so we can answer the -# opposite-triangle queries in O(1) instead of scanning (which would scalar-index the parent) -LinearAlgebra.istriu(::LowerTriangular{<:Any, <:AbstractGPUMatrix}) = false -LinearAlgebra.istriu(::UnitLowerTriangular{<:Any, <:AbstractGPUMatrix}) = false -LinearAlgebra.istril(::UpperTriangular{<:Any, <:AbstractGPUMatrix}) = false -LinearAlgebra.istril(::UnitUpperTriangular{<:Any, <:AbstractGPUMatrix}) = false - ## diagonal @@ -535,13 +528,13 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - upperA = istriu(A) + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} @kernel function trimatmul(C, A, B, alpha, beta) idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) - + @inbounds if i <= l && j <= n # treat C as write-only when beta is zero (it may hold NaN/Inf) Cij = iszero(beta) ? zero(R) : beta * C[i,j] @@ -568,13 +561,13 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - upperA = istriu(A) + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} @kernel function trimatmul(C, A, B) idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) - + @inbounds if i <= l && j <= n z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) Cij = convert(promote_type(R, typeof(z2)), z2) @@ -601,13 +594,13 @@ end if isempty(A) || isempty(B) return fill!(C, zero(R)) end - upperA = istriu(A) + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} @kernel function trimatmul(C, A, B) idx = @index(Global, Linear) assume.(size(C) .> 0) i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 l, m, n = size(A, 1), size(B, 1), size(B, 2) - + @inbounds if i <= l && j <= n z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) Cij = convert(promote_type(R, typeof(z2)), z2) From c07fd9f4da9a60922c5c27efd9d94f75e8f4eec9 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 15:05:23 +0200 Subject: [PATCH 8/9] Deduplicate triangular matmul kernels into one helper. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The three version-gated tri×tri methods shared a near-identical body and kernel. Extract _triangular_matmatmul!(C, A, B, α, β) and reduce each to a thin forwarder (the pre-1.12 ones unpack α/β from the MulAddMul). Net -47 lines; verified on Julia 1.10, 1.11 and 1.12. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/host/linalg.jl | 133 +++++++++++++++------------------------------ 1 file changed, 43 insertions(+), 90 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 626c8494..999c7226 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -487,6 +487,43 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) end end + +# triangular × triangular matmul: C = α·(A·B) + β·C, with both A and B triangular. +# Shared by the version-specific LinearAlgebra entry points below, which differ only in how +# α/β arrive (raw scalars on ≥1.12, a MulAddMul on older). The wrapper type — not the data — +# fixes which triangle holds the nonzeros, so we read it with a plain `isa` check. +function _triangular_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) where {R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + upperA = A isa Union{UnitUpperTriangular, UpperTriangular} + @kernel function trimatmul(C, A, B, alpha, beta) + idx = @index(Global, Linear) + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,i] * B[i,j] + for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) + Cij += A[i,k] * B[k,j] + end + # treat C as write-only when beta is zero (it may hold NaN/Inf) + C[i,j] = iszero(beta) ? alpha * Cij : alpha * Cij + beta * C[i,j] + end + end + trimatmul(get_backend(C))(C, A, B, alpha, beta; ndrange = length(C)) + return C +end + @static if VERSION ≥ v"1.12.0-rc" # we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag @@ -518,102 +555,18 @@ end @. C = s * X * alpha + C * beta return C end - function LinearAlgebra._generic_matmatmul_nonadjtrans!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) where {R} - if size(A,2) != size(B,1) - throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) - end - if size(C,1) != size(A,1) || size(C,2) != size(B,2) - throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) - end - if isempty(A) || isempty(B) - return fill!(C, zero(R)) - end - upperA = A isa Union{UnitUpperTriangular, UpperTriangular} - @kernel function trimatmul(C, A, B, alpha, beta) - idx = @index(Global, Linear) - assume.(size(C) .> 0) - i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 - l, m, n = size(A, 1), size(B, 1), size(B, 2) - - @inbounds if i <= l && j <= n - # treat C as write-only when beta is zero (it may hold NaN/Inf) - Cij = iszero(beta) ? zero(R) : beta * C[i,j] - Cij += A[i,i] * B[i,j] * alpha - for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) - Cij += alpha * A[i,k] * B[k,j] - end - C[i,j] = Cij - end - end - trimatmul(get_backend(C))(C, A, B, alpha, beta; ndrange = length(C)) - return C - end + LinearAlgebra._generic_matmatmul_nonadjtrans!(C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) = + _triangular_matmatmul!(C, A, B, alpha, beta) end @static if VERSION < v"1.11.0-rc" - function LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat{R}, tA::Char, tB::Char, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) where {R} - if size(A,2) != size(B,1) - throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) - end - if size(C,1) != size(A,1) || size(C,2) != size(B,2) - throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) - end - if isempty(A) || isempty(B) - return fill!(C, zero(R)) - end - upperA = A isa Union{UnitUpperTriangular, UpperTriangular} - @kernel function trimatmul(C, A, B) - idx = @index(Global, Linear) - assume.(size(C) .> 0) - i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 - l, m, n = size(A, 1), size(B, 1), size(B, 2) - - @inbounds if i <= l && j <= n - z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,i] * B[i,j] - for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) - Cij += A[i,k] * B[k,j] - end - C[i,j] = add(Cij, C[i, j]) - end - end - trimatmul(get_backend(C))(C, A, B; ndrange = length(C)) - return C - end + LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat, tA::Char, tB::Char, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) = + _triangular_matmatmul!(C, A, B, add.alpha, add.beta) end @static if v"1.11.0-rc" < VERSION < v"1.12.0-rc" - function LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) where {R} - if size(A,2) != size(B,1) - throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) - end - if size(C,1) != size(A,1) || size(C,2) != size(B,2) - throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) - end - if isempty(A) || isempty(B) - return fill!(C, zero(R)) - end - upperA = A isa Union{UnitUpperTriangular, UpperTriangular} - @kernel function trimatmul(C, A, B) - idx = @index(Global, Linear) - assume.(size(C) .> 0) - i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 - l, m, n = size(A, 1), size(B, 1), size(B, 2) - - @inbounds if i <= l && j <= n - z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,i] * B[i,j] - for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) - Cij += A[i,k] * B[k,j] - end - C[i,j] = add(Cij, C[i, j]) - end - end - trimatmul(get_backend(C))(C, A, B; ndrange = length(C)) - return C - end + LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) = + _triangular_matmatmul!(C, A, B, add.alpha, add.beta) end function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} From 6119dbce513d5e9c5ed7818922615dde0634d76e Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 3 Jun 2026 15:08:33 +0200 Subject: [PATCH 9/9] =?UTF-8?q?Test=20triangular=20matmul=20with=20=CE=B1?= =?UTF-8?q?=E2=89=A01=20and=20=CE=B2=E2=89=A00=20together.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Existing cases only covered α=1 (scaling is a no-op) or β=0 (the β·C term drops), so the general α·(A·B) + β·C path was untested. Co-Authored-By: Claude Opus 4.8 (1M context) --- test/testsuite/linalg.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 8307ecf8..1b404d96 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -197,6 +197,12 @@ mul!(Ct, TR1(A), TR2(B), 1, -1) mul!(C, TR1(collect(A)), TR2(collect(B)), 1, -1) @test collect(Ct) ≈ C + # general case: α≠1 and β≠0 together + Et = AT(rand(T, n, n)) + E = collect(Et) + mul!(Et, TR1(A), TR2(B), 3, 2) + mul!(E, TR1(collect(A)), TR2(collect(B)), 3, 2) + @test collect(Et) ≈ E end @testset "matmul! is write-only for β=0 ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) if !(T in eltypes)