Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using Enzyme
Expand Down Expand Up @@ -257,6 +258,27 @@ for f in (:svd_compact!, :svd_full!)
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
return (nothing, nothing, nothing)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{TA},
USVᴴ::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT, TA}
$f(A.val, USVᴴ.val, alg.val)
!isa(A, Const) && !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval)
make_zero!(A.dval)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return USVᴴ
elseif EnzymeRules.needs_primal(config)
return USVᴴ.val
elseif EnzymeRules.needs_shadow(config)
return USVᴴ.dval
else
return nothing
end
end
end
end

Expand Down Expand Up @@ -467,5 +489,32 @@ function EnzymeRules.reverse(
!isa(S, Const) && !A_is_arg && make_zero!(S.dval)
return (nothing, nothing, nothing)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(svd_vals!)},
::Type{RT},
A::Annotation{TA},
S::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT, TA}
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval
U, S_, Vᴴ = svd_compact!(A.val, alg.val)
if !isa(A, Const) && !isa(S, Const)
ΔS = A_is_arg ? make_zero(S.dval) : S.dval
svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS)
A_is_arg && (S.dval .= ΔS)
end
!A_is_arg && make_zero!(A.dval)
copyto!(S.val, diagview(S_))
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return S
elseif EnzymeRules.needs_primal(config)
return S.val
elseif EnzymeRules.needs_shadow(config)
return S.dval
else
return nothing
end
end

end
52 changes: 48 additions & 4 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: svd_pushforward!, svd_trunc_pushforward!, svd_vals_pushforward!
using MatrixAlgebraKit: TruncatedAlgorithm
using LinearAlgebra

Expand Down Expand Up @@ -511,7 +512,7 @@ for (f!, f) in (
(:svd_compact!, :svd_compact),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
Expand All @@ -535,7 +536,18 @@ for (f!, f) in (
end
return USVᴴ_dUSVᴴ, svd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual)
A, dA = arrayify(A_dA)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
$f!(A, USVᴴ, Mooncake.primal(alg_dalg))
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
return USVᴴ_dUSVᴴ
end
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
Expand All @@ -558,10 +570,23 @@ for (f!, f) in (
end
return USVᴴ_codual, svd_adjoint
end
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
A, dA = arrayify(A_dA)
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
dUSVᴴ = Mooncake.zero_tangent(USVᴴ)
USVᴴ_dual = Dual(USVᴴ, dUSVᴴ)
U, S, Vᴴ = Mooncake.primal(USVᴴ_dual)
dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_dual)
U, dU = arrayify(U, dU_)
S, dS = arrayify(S, dS_)
Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_)
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
return USVᴴ_dual
end
end
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -577,8 +602,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
end
return S_dS, svd_vals_adjoint
end
function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual)
# compute primal
A, dA = arrayify(A_dA)
S, dS = arrayify(S_dS)
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
copy!(S, diagview(USVᴴ[2]))
svd_vals_pushforward!(dA, A, USVᴴ, dS)
return S_dS
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -597,6 +631,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
end
return S_codual, svd_vals_adjoint
end
function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual)
# compute primal
A, dA = arrayify(A_dA)
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
S = diagview(USVᴴ[2])
S_dual = Dual(S, Mooncake.zero_tangent(S))
S_, dS = arrayify(S_dual)
svd_vals_pushforward!(dA, A, USVᴴ, dS)
return S_dual
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
Expand Down
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("pushforwards/polar.jl")
include("pushforwards/svd.jl")

include("precompile.jl")

Expand Down
91 changes: 91 additions & 0 deletions src/pushforwards/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = default_pullback_rank_atol(A), kwargs...)
U, Smat, Vᴴ = USVᴴ
m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
minmn = min(m, n)
S = diagview(Smat)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
r = svd_rank(S; rank_atol)

vΔS = view(ΔS, 1:r, 1:r)

vU = view(U, :, 1:r)
vS = view(S, 1:r)
vSmat = view(Smat, 1:r, 1:r)
vVᴴ = view(Vᴴ, 1:r, :)

# compact region
vV = adjoint(vVᴴ)
UΔAV = vU' * ΔA * vV
copyto!(diagview(vΔS), real.(diagview(UΔAV)))
F = inv_safe.(transpose(vS) .- vS)
G = inv_safe.(transpose(vS) .+ vS)
hUΔAV = F .* (UΔAV + UΔAV') ./ 2
aUΔAV = G .* (UΔAV - UΔAV') ./ 2
K̇ = hUΔAV + aUΔAV
Ṁ = hUΔAV - aUΔAV

# check gauge condition
@assert isantihermitian(K̇)
@assert isantihermitian(Ṁ)
K̇diag = diagview(K̇)

∂U = vU * K̇
∂V = vV * Ṁ
# full component
if size(U, 2) > minmn && size(Vᴴ, 1) > minmn
Uperp = view(U, :, (minmn + 1):m)
Vᴴperp = view(Vᴴ, (minmn + 1):n, :)

aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)

UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
fill!(UÃÃV, 0)
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U)
superKM = -sylvester(UÃÃV, Smat, rhs)
K̇perp = view(superKM, 1:size(aUAV, 2))
Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2)))
∂U .+= Uperp * K̇perp
∂V .+= Vᴴperp * Ṁperp
else
ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU')
ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ)
upper = ImUU * ΔA * vV
lower = ImVV * ΔA' * vU
rhs = vcat(upper, lower)

à = ImUU * A * ImVV
ÃÃ = similar(A, (m + n, m + n))
fill!(ÃÃ, 0)
view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
view(ÃÃ, m .+ (1:n), 1:m) .= Ã'

superLN = -sylvester(ÃÃ, vSmat, rhs)
∂U += view(superLN, 1:size(upper, 1), :)
∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :)
end
if !iszerotangent(ΔU)
vΔU = view(ΔU, :, 1:r)
copyto!(vΔU, ∂U)
end
if !iszerotangent(ΔVᴴ)
vΔVᴴ = view(ΔVᴴ, 1:r, :)
adjoint!(vΔVᴴ, ∂V)
end
return (ΔU, ΔS, ΔVᴴ)
end

function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
# TODO
end

function svd_vals_pushforward!(
ΔA, A, USVᴴ, ΔS, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
)
ΔUSVᴴ = (nothing, diagonal(ΔS), nothing)
return svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
end
2 changes: 1 addition & 1 deletion test/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
48 changes: 40 additions & 8 deletions test/testsuite/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,80 @@ function test_enzyme_svd(T::Type, sz; kwargs...)
end
end

"""
test_enzyme_svd_compact(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant.
"""
function test_enzyme_svd_compact(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_compact, A)
USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A)
test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
if eltype(T) <: Real
test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end
end

"""
test_enzyme_svd_full(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The
gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent.
"""
function test_enzyme_svd_full(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_full, A)
USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A)
test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
if eltype(T) <: Real
test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end
end

"""
test_enzyme_svd_vals(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant.
"""
function test_enzyme_svd_vals(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_vals, A)
S, ΔS = ad_svd_vals_setup(A)
test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end

"""
test_enzyme_svd_trunc(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their
in-place variants, over a range of truncation ranks and a tolerance-based truncation.
"""
function test_enzyme_svd_trunc(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
Expand All @@ -64,15 +96,15 @@ function test_enzyme_svd_trunc(
trunc = truncrank(r)
truncalg = TruncatedAlgorithm(alg, trunc)
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
end
@testset "trunctol" begin
S = svd_vals(A, alg)
trunc = trunctol(atol = maximum(S) / 2)
truncalg = TruncatedAlgorithm(alg, trunc)
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
end
end
Expand Down
Loading
Loading