diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 4d2dad41..0efc5a5e 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -451,7 +451,6 @@ Adapt.adapt_storage(::Type{<:JLArray{T}}, xs::AbstractArray) where {T} = # adapt back to the CPU Adapt.adapt_storage(::Type{Array}, xs::JLArray) = convert(Array, xs) - ## conversions Base.convert(::Type{T}, x::T) where T <: JLArray = x diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a46b85d2..999c7226 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -1,6 +1,6 @@ # integration with LinearAlgebra stdlib -using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal +using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal, AbstractTriangular ## transpose and adjoint @@ -487,6 +487,43 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) end end + +# triangular × triangular matmul: C = α·(A·B) + β·C, with both A and B triangular. +# Shared by the version-specific LinearAlgebra entry points below, which differ only in how +# α/β arrive (raw scalars on ≥1.12, a MulAddMul on older). The wrapper type — not the data — +# fixes which triangle holds the nonzeros, so we read it with a plain `isa` check. 
+function _triangular_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) where {R}
+    # Validate shapes on the host before any kernel is launched.
+    size(A, 2) == size(B, 1) ||
+        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
+    (size(C, 1) == size(A, 1) && size(C, 2) == size(B, 2)) ||
+        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
+    # Empty operand: zero-fill C and return without launching the kernel.
+    (isempty(A) || isempty(B)) && return fill!(C, zero(R))
+    # Captured by the kernel below: true when A is an upper-triangular wrapper.
+    a_is_upper = A isa Union{UnitUpperTriangular, UpperTriangular}
+    @kernel function trimatmul(C, A, B, alpha, beta)
+        # Each global linear index maps to one entry of C; the trailing `1` supplies `j` for a vector C.
+        lin = @index(Global, Linear)
+        assume.(size(C) .> 0)
+        i, j = @inbounds Tuple(CartesianIndices(C)[lin])..., 1
+        l, m, n = size(A, 1), size(B, 1), size(B, 2)
+        @inbounds if i <= l && j <= n
+            # Accumulator type promotes R with a sum-of-products type; seed it with the diagonal term.
+            z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
+            acc = convert(promote_type(R, typeof(z2)), z2) + A[i,i] * B[i,j]
+            # Off-diagonal terms of row i: k > i when A is upper, k < i otherwise.
+            for k in (a_is_upper ? ((i + 1):m) : (1:(i - 1)))
+                acc += A[i,k] * B[k,j]
+            end
+            # treat C as write-only when beta is zero (it may hold NaN/Inf)
+            C[i,j] = iszero(beta) ? alpha * acc : alpha * acc + beta * C[i,j]
+        end
+    end
+    trimatmul(get_backend(C))(C, A, B, alpha, beta; ndrange = length(C))
+    return C
+end
+
 @static if VERSION ≥ v"1.12.0-rc"
     # we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method
     using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag
@@ -518,6 +555,18 @@ end
         @. 
C = s * X * alpha + C * beta
         return C
     end
+    LinearAlgebra._generic_matmatmul_nonadjtrans!(C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular, alpha::Number, beta::Number) =
+        _triangular_matmatmul!(C, A, B, alpha, beta)  # α/β are forwarded unchanged
+end
+
+@static if VERSION < v"1.11.0-rc"  # this signature carries tA/tB; they are accepted but not forwarded
+    LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat, tA::Char, tB::Char, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) =
+        _triangular_matmatmul!(C, A, B, add.alpha, add.beta)  # unpack α/β from the MulAddMul
+end
+
+@static if v"1.11.0-rc" ≤ VERSION < v"1.12.0-rc"  # inclusive lower bound: no version falls between the gates
+    LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) =
+        _triangular_matmatmul!(C, A, B, add.alpha, add.beta)  # unpack α/β from the MulAddMul
 end
 function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl
index 8875dd54..1b404d96 100644
--- a/test/testsuite/linalg.jl
+++ b/test/testsuite/linalg.jl
@@ -185,6 +185,39 @@
             mul!(C, collect(A), f(TR(collect(B))))
             @test collect(Ct) ≈ C
         end
+        @testset "matmul! with nonzero β ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
+            if !(T in eltypes)
+                continue
+            end
+            n = 128
+            A = AT(rand(T, n, n))
+            B = AT(rand(T, n, n))
+            Ct = AT(rand(T, n, n))
+            C = collect(Ct)
+            mul!(Ct, TR1(A), TR2(B), 1, -1)
+            mul!(C, TR1(collect(A)), TR2(collect(B)), 1, -1)
+            @test collect(Ct) ≈ C
+            # general case: α≠1 and β≠0 together
+            Et = AT(rand(T, n, n))
+            E = collect(Et)
+            mul!(Et, TR1(A), TR2(B), 3, 2)
+            mul!(E, TR1(collect(A)), TR2(collect(B)), 3, 2)
+            @test collect(Et) ≈ E
+        end
+        @testset "matmul! is write-only for β=0 ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
+            if !(T in eltypes)
+                continue
+            end
+            n = 128
+            A = AT(rand(T, n, n))
+            B = AT(rand(T, n, n))
+            # C starts as NaN: with β=0 it must be ignored (not multiplied), α≠1 to reach this path
+            Ct = AT(fill(T(NaN), n, n))
+            C = collect(Ct)
+            mul!(Ct, TR1(A), TR2(B), 2, 0)
+            mul!(C, TR1(collect(A)), TR2(collect(B)), 2, 0)
+            @test collect(Ct) ≈ C
+        end
     end
 end