# Patch summary (reviewer annotation; the patch text below is unchanged).
#
# Project.toml
#   - version: "11.5.5" -> "11.5.6".
#   - [compat] Adapt: "4.0" -> "4.6.1".
#
# src/host/linalg.jl
#   - Adds a LinearAlgebra.generic_matmatmul_wrapper! method for
#     val::LinearAlgebra.BlasFlag.SymmHemmGeneric with C::AbstractGPUMatrix{T} and
#     A, B::AbstractGPUVecOrMat{T}. It forwards to
#     LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta), the same body as
#     the SyrkHerkGemm method shown as context directly above it.
#   - Rationale, per the comment added in the hunk: without this method,
#     Symmetric/Hermitian inputs with BLAS eltypes would dispatch to
#     BLAS.symm!/hemm!.
#   - NOTE(review): A and B must share the eltype T; the mixed complex/real variant
#     in the trailing context is still commented out -- confirm that is intended.
#
# test/testsuite/linalg.jl
#   - Adds the testset "mul! + Symmetric/Hermitian", looping over
#     T in (Float32, ComplexF32), W in (Symmetric, Hermitian), uplo in (:U, :L);
#     combinations whose T is not in `eltypes` are skipped via `continue`.
#   - With n = 128 it compares mul! on AT(...) arrays against the same call on
#     collect(...)-ed copies using ≈, for: the wrapper on the left, the wrapper on the
#     right, the 5-argument form with (3, 2) on a non-zero destination, and a
#     matrix-vector product.
#   - NOTE(review): inputs come from unseeded rand(...) -- presumably fine for an
#     ≈ comparison, but confirm against the test suite's conventions.
#
# NOTE(review): in this copy the patch's line breaks appear to be collapsed into
# spaces; presumably it must be re-exported before it can be applied -- confirm.
diff --git a/Project.toml b/Project.toml index 4fff4b82..c35adc5c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUArrays" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "11.5.5" +version = "11.5.6" [workspace] projects = ["lib/GPUArraysCore", "lib/JLArrays", "test", "docs"] @@ -26,7 +26,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JLD2Ext = "JLD2" [compat] -Adapt = "4.0" +Adapt = "4.6.1" GPUArraysCore = "= 0.2.0" JLD2 = "0.4, 0.5, 0.6" KernelAbstractions = "0.9.28, 0.10" diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 999c7226..f216386f 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -530,6 +530,11 @@ end function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T} LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) end + # Symmetric/Hermitian inputs with BLAS eltypes would otherwise dispatch to BLAS.symm!/ + # hemm!: GPU arrays are DenseArrays, so they match the StridedMatrix{<:BlasFloat} methods + function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SymmHemmGeneric) where {T} + LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) + end # need to support mixed complex/real types too #function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{Complex{T}}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::V) where {T<:BlasReal, V<:LinearAlgebra.BlasFlag.SyrkHerkGemm} # LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 1b404d96..d8e587e0 100644 --- 
a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -219,6 +219,45 @@ @test collect(Ct) ≈ C end end + + @testset "mul! + Symmetric/Hermitian" begin + # with BLAS eltypes these dispatch through generic_matmatmul_wrapper!'s + # SymmHemmGeneric path, which must not end up in BLAS.symm!/hemm! + @testset "$W{$T}, uplo=$uplo" for T in (Float32, ComplexF32), W in (Symmetric, Hermitian), uplo in (:U, :L) + if !(T in eltypes) + continue + end + n = 128 + A = AT(rand(T, n, n)) + B = AT(rand(T, n, n)) + b = AT(rand(T, n)) + + # matrix-matrix, wrapped on either side + Ct = AT(zeros(T, n, n)) + C = zeros(T, n, n) + mul!(Ct, W(A, uplo), B) + mul!(C, W(collect(A), uplo), collect(B)) + @test collect(Ct) ≈ C + + mul!(Ct, B, W(A, uplo)) + mul!(C, collect(B), W(collect(A), uplo)) + @test collect(Ct) ≈ C + + # general case: α≠1 and β≠0 together + Et = AT(rand(T, n, n)) + E = collect(Et) + mul!(Et, W(A, uplo), B, 3, 2) + mul!(E, W(collect(A), uplo), collect(B), 3, 2) + @test collect(Et) ≈ E + + # matrix-vector + ct = AT(zeros(T, n)) + c = zeros(T, n) + mul!(ct, W(A, uplo), b) + mul!(c, W(collect(A), uplo), collect(b)) + @test collect(ct) ≈ c + end + end end @testset "diagm" begin