Skip to content
1 change: 0 additions & 1 deletion lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,6 @@ Adapt.adapt_storage(::Type{<:JLArray{T}}, xs::AbstractArray) where {T} =
# Adapting a JLArray to plain `Array` storage brings it back to the CPU; the work is
# delegated to `convert`.
function Adapt.adapt_storage(::Type{Array}, xs::JLArray)
    return convert(Array, xs)
end


## conversions

# Converting a JLArray to its own type is the identity: the very same object is returned.
function Base.convert(::Type{T}, x::T) where {T<:JLArray}
    return x
end
Expand Down
51 changes: 50 additions & 1 deletion src/host/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# integration with LinearAlgebra stdlib

using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal
using LinearAlgebra: MulAddMul, wrap, diagm, BlasReal, AbstractTriangular

## transpose and adjoint

Expand Down Expand Up @@ -487,6 +487,43 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs
LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b))
end
end

# Product of two triangular wrappers written into a GPU array: C = α·(A·B) + β·C.
# The version-specific LinearAlgebra entry points below all forward here; they differ only
# in how α/β are delivered (plain scalars on ≥1.12, packed in a MulAddMul on older
# releases). Which triangle of A carries the nonzeros is determined by the wrapper type,
# not by the stored data, so a plain `isa` test is enough to find out.
#
# Throws `DimensionMismatch` when the inner dimensions of A and B disagree or when C does
# not have size (size(A,1), size(B,2)). Returns C.
function _triangular_matmatmul!(
        C::AbstractGPUVecOrMat{R}, A::AbstractTriangular, B::AbstractTriangular,
        alpha::Number, beta::Number) where {R}
    # shape validation happens on the host, before any kernel is launched
    size(A, 2) == size(B, 1) ||
        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    (size(C, 1) == size(A, 1) && size(C, 2) == size(B, 2)) ||
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))

    # an empty operand leaves nothing to multiply
    (isempty(A) || isempty(B)) && return fill!(C, zero(R))

    # true when A's wrapper marks the upper triangle as the populated one
    upperA = A isa UpperTriangular || A isa UnitUpperTriangular

    # one work-item per entry of the output
    @kernel function trimatmul(out, lhs, rhs, α, β)
        lin = @index(Global, Linear)
        assume.(size(out) .> 0)
        # the trailing `1` supplies a column index when `out` is a vector
        row, col = @inbounds Tuple(CartesianIndices(out)[lin])..., 1
        nrows, inner, ncols = size(lhs, 1), size(rhs, 1), size(rhs, 2)

        @inbounds if row <= nrows && col <= ncols
            # accumulator typed from a sample product, promoted with C's element type
            proto = zero(lhs[row, 1]*rhs[1, col] + lhs[row, 1]*rhs[1, col])
            acc = convert(promote_type(R, typeof(proto)), proto)
            # diagonal term first, read through the wrapper
            acc += lhs[row, row] * rhs[row, col]
            # then only the off-diagonal columns on the populated side of row `row`
            kfirst, klast = upperA ? (row + 1, inner) : (1, row - 1)
            for k in kfirst:klast
                acc += lhs[row, k] * rhs[k, col]
            end
            # with β == 0 the old contents of C are never read (they may be NaN/Inf)
            if iszero(β)
                out[row, col] = α * acc
            else
                out[row, col] = α * acc + β * out[row, col]
            end
        end
    end

    trimatmul(get_backend(C))(C, A, B, alpha, beta; ndrange = length(C))
    return C
end

@static if VERSION ≥ v"1.12.0-rc"
# we need to use the generic wrapper to avoid dispatch to the 2x2or3x3 method
using LinearAlgebra: generic_matmatmul_wrapper!, BlasFlag
Expand Down Expand Up @@ -518,6 +555,18 @@ end
@. C = s * X * alpha + C * beta
return C
end
# triangular × triangular with a GPU destination: α/β arrive as plain scalars here, so
# they are handed straight to the shared implementation
function LinearAlgebra._generic_matmatmul_nonadjtrans!(
        C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular,
        alpha::Number, beta::Number)
    return _triangular_matmatmul!(C, A, B, alpha, beta)
end
end

# Before 1.11.0-rc: `_generic_matmatmul!` is called with transpose flags and a MulAddMul.
# The flags `tA`/`tB` are accepted but not used; α and β are unpacked from `add` and
# forwarded to the shared triangular implementation.
@static if VERSION < v"1.11.0-rc"
    function LinearAlgebra._generic_matmatmul!(
            C::AbstractGPUVecOrMat, tA::Char, tB::Char,
            A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul)
        α, β = add.alpha, add.beta
        return _triangular_matmatmul!(C, A, B, α, β)
    end
end

# From 1.11.0-rc up to (but excluding) 1.12.0-rc: `_generic_matmatmul!` is called without
# transpose flags and with α/β bundled in a MulAddMul; unpack them and forward to the
# shared triangular implementation.
# The lower bound is inclusive so that this gate, the `VERSION < v"1.11.0-rc"` gate and
# the `VERSION ≥ v"1.12.0-rc"` gate together cover every version with no gap at the
# boundary.
@static if v"1.11.0-rc" ≤ VERSION < v"1.12.0-rc"
    LinearAlgebra._generic_matmatmul!(C::AbstractGPUVecOrMat, A::AbstractTriangular, B::AbstractTriangular, add::LinearAlgebra.MulAddMul) =
        _triangular_matmatmul!(C, A, B, add.alpha, add.beta)
end

function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
Expand Down
33 changes: 33 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@
mul!(C, collect(A), f(TR(collect(B))))
@test collect(Ct) ≈ C
end
# Compare the 5-argument `mul!` on device arrays against the same call on collected
# host copies, for every pairing of triangular wrappers.
@testset "matmul! with nonzero β ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
    # element types not listed in `eltypes` are skipped
    T in eltypes || continue
    n = 128
    lhs = AT(rand(T, n, n))
    rhs = AT(rand(T, n, n))
    # (1, -1): β≠0 with unit α; (3, 2): general case, α≠1 and β≠0 together
    for (α, β) in ((1, -1), (3, 2))
        dev = AT(rand(T, n, n))
        ref = collect(dev)
        mul!(dev, TR1(lhs), TR2(rhs), α, β)
        mul!(ref, TR1(collect(lhs)), TR2(collect(rhs)), α, β)
        @test collect(dev) ≈ ref
    end
end
# With β = 0 the destination must be overwritten, never read: a NaN-filled output is
# multiplied into and the result compared against the same call on host copies.
@testset "matmul! is write-only for β=0 ($TR1 x $TR2)" for T in (Float32, ComplexF32), TR1 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TR2 in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
    # element types not listed in `eltypes` are skipped
    T in eltypes || continue
    n = 128
    lhs = AT(rand(T, n, n))
    rhs = AT(rand(T, n, n))
    # destination starts out as all-NaN; α≠1 is used so this path is actually reached
    dev = AT(fill(T(NaN), n, n))
    ref = collect(dev)
    mul!(dev, TR1(lhs), TR2(rhs), 2, 0)
    mul!(ref, TR1(collect(lhs)), TR2(collect(rhs)), 2, 0)
    @test collect(dev) ≈ ref
end
end
end

Expand Down