From f7b78558163cf6132422dde3383ac4a3711fb8b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Pawela?= Date: Thu, 28 May 2026 17:32:16 +0200 Subject: [PATCH] fix pointer types in `Hgemm` --- lib/cublas/src/libcublas.jl | 12 ++--- lib/cublas/test/level3/gemm.jl | 88 ++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 48 deletions(-) diff --git a/lib/cublas/src/libcublas.jl b/lib/cublas/src/libcublas.jl index 7cd1d3d9e6..7835028a40 100644 --- a/lib/cublas/src/libcublas.jl +++ b/lib/cublas/src/libcublas.jl @@ -6199,9 +6199,9 @@ end initialize_context() @ccall libcublas.cublasHgemm(handle::cublasHandle_t, transa::cublasOperation_t, transb::cublasOperation_t, m::Cint, n::Cint, k::Cint, - alpha::Ptr{Float16}, A::Ptr{Float16}, lda::Cint, - B::Ptr{Float16}, ldb::Cint, beta::Ptr{Float16}, - C::Ptr{Float16}, ldc::Cint)::cublasStatus_t + alpha::CuRef{Float16}, A::CuPtr{Float16}, lda::Cint, + B::CuPtr{Float16}, ldb::Cint, beta::CuRef{Float16}, + C::CuPtr{Float16}, ldc::Cint)::cublasStatus_t end @checked function cublasHgemm_64(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, @@ -6210,9 +6210,9 @@ end initialize_context() @ccall libcublas.cublasHgemm_64(handle::cublasHandle_t, transa::cublasOperation_t, transb::cublasOperation_t, m::Int64, n::Int64, k::Int64, - alpha::Ptr{Float16}, A::Ptr{Float16}, lda::Int64, - B::Ptr{Float16}, ldb::Int64, beta::Ptr{Float16}, - C::Ptr{Float16}, ldc::Int64)::cublasStatus_t + alpha::CuRef{Float16}, A::CuPtr{Float16}, lda::Int64, + B::CuPtr{Float16}, ldb::Int64, beta::CuRef{Float16}, + C::CuPtr{Float16}, ldc::Int64)::cublasStatus_t end @checked function cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, diff --git a/lib/cublas/test/level3/gemm.jl b/lib/cublas/test/level3/gemm.jl index 65a299bada..b37f07dfd6 100644 --- a/lib/cublas/test/level3/gemm.jl +++ b/lib/cublas/test/level3/gemm.jl @@ -16,7 +16,7 @@ n = 35 k = 13 @testset "level 3" begin - @testset for elty in [Float32, Float64, ComplexF32, ComplexF64] + @testset for elty in [Float16, Float32, Float64, ComplexF32, ComplexF64] @testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty) C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5) @@ -42,12 +42,14 @@ k = 13 @test Array(dC) ≈ C end - @testset "hermitian" begin - C, A, B = rand(elty, 5, 5), Hermitian(rand(elty, 5, 5)), rand(elty, 5, 5) - dC, dA, dB = CuArray(C), Hermitian(CuArray(A)), CuArray(B) - mul!(dC, dA, dB) - mul!(C, A, B) - @test Array(dC) ≈ C + if !(elty <: Float16) + @testset "hermitian" begin + C, A, B = rand(elty, 5, 5), Hermitian(rand(elty, 5, 5)), rand(elty, 5, 5) + dC, dA, dB = CuArray(C), Hermitian(CuArray(A)), CuArray(B) + mul!(dC, dA, dB) + mul!(C, A, B) + @test Array(dC) ≈ C + end end @testset "gemm!" begin @@ -130,42 +132,44 @@ k = 13 @test C1 ≈ h_C1 @test C1 ≈ h_C2 end - @testset "symm!" begin - alpha = rand(elty) - beta = rand(elty) - sA = rand(elty,m,m) - sA = sA + transpose(sA) - dsA = CuArray(sA) - B = rand(elty,m,n) - C = rand(elty,m,n) - Bbad = rand(elty,m+1,n+1) - d_B = CuArray(B) - d_C = CuArray(C) - d_Bbad = CuArray(Bbad) - cuBLAS.symm!('L','U',alpha,dsA,d_B,beta,d_C) - C = (alpha*sA)*B + beta*C - # compare - h_C = Array(d_C) - @test C ≈ h_C - @test_throws DimensionMismatch cuBLAS.symm!('L','U',alpha,dsA,d_Bbad,beta,d_C) - end - @testset "symm" begin - sA = rand(elty,m,m) - sA = sA + transpose(sA) - dsA = CuArray(sA) - B = rand(elty,m,n) - C = rand(elty,m,n) - Bbad = rand(elty,m+1,n+1) - d_B = CuArray(B) - d_C = CuArray(C) - d_Bbad = CuArray(Bbad) - d_C = cuBLAS.symm('L','U',dsA,d_B) - C = sA*B - # compare - h_C = Array(d_C) - @test C ≈ h_C - @test_throws DimensionMismatch cuBLAS.symm('L','U',dsA,d_Bbad) + if !(elty <: Float16) + @testset "symm!" begin + alpha = rand(elty) + beta = rand(elty) + sA = rand(elty,m,m) + sA = sA + transpose(sA) + dsA = CuArray(sA) + B = rand(elty,m,n) + C = rand(elty,m,n) + Bbad = rand(elty,m+1,n+1) + d_B = CuArray(B) + d_C = CuArray(C) + d_Bbad = CuArray(Bbad) + cuBLAS.symm!('L','U',alpha,dsA,d_B,beta,d_C) + C = (alpha*sA)*B + beta*C + # compare + h_C = Array(d_C) + @test C ≈ h_C + @test_throws DimensionMismatch cuBLAS.symm!('L','U',alpha,dsA,d_Bbad,beta,d_C) + end + @testset "symm" begin + sA = rand(elty,m,m) + sA = sA + transpose(sA) + dsA = CuArray(sA) + B = rand(elty,m,n) + C = rand(elty,m,n) + Bbad = rand(elty,m+1,n+1) + d_B = CuArray(B) + d_C = CuArray(C) + d_Bbad = CuArray(Bbad) + d_C = cuBLAS.symm('L','U',dsA,d_B) + C = sA*B + # compare + h_C = Array(d_C) + @test C ≈ h_C + @test_throws DimensionMismatch cuBLAS.symm('L','U',dsA,d_Bbad) + end end if elty <: Complex