Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using MatrixAlgebraKit: diagview, inv_safe, truncate
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward!
using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
Expand Down Expand Up @@ -119,22 +120,28 @@ for (f, pb) in (
end

for (f, pf) in (
(left_polar!, left_polar_pushforward!),
(right_polar!, right_polar_pushforward!),
(:right_polar!, :right_polar_pushforward!),
(:left_polar!, :left_polar_pushforward!),
(:eigh_full!, :eigh_pushforward!),
(:eig_full!, :eig_pushforward!),
)
@eval begin
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation,
arg::Annotation{TA},
arg::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT, TA}
) where {RT}
A_is_arg1 = !isa(A, Const) && A.val === arg.val[1]
A_is_arg2 = !isa(A, Const) && A.val === arg.val[2]
A_is_arg = A_is_arg1 || A_is_arg2
$f(A.val, arg.val, alg.val)
if !isa(A, Const) && !isa(arg, Const)
$pf(A.dval, A.val, arg.val, arg.dval)
end
!A_is_arg && make_zero!(A.dval)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return arg
elseif EnzymeRules.needs_primal(config)
Expand Down Expand Up @@ -367,9 +374,9 @@ for (f, trunc_f, full_f, pb) in (
end
end

for (f!, f_full!, pb!) in (
(eig_vals!, eig_full!, eig_vals_pullback!),
(eigh_vals!, eigh_full!, eigh_vals_pullback!),
for (f!, f_full!, pb!, pf!) in (
(:eig_vals!, :eig_full!, :eig_vals_pullback!, :eig_vals_pushforward!),
(:eigh_vals!, :eigh_full!, :eigh_vals_pullback!, :eigh_vals_pushforward!),
)
@eval begin
function EnzymeRules.augmented_primal(
Expand Down Expand Up @@ -418,6 +425,34 @@ for (f!, f_full!, pb!) in (
!isa(D, Const) && !A_is_arg && make_zero!(D.dval)
return (nothing, nothing, nothing)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
A::Annotation{TA},
D::Annotation,
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {RT, TA}
A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval
DV = $f_full!(A.val, alg.val)
Dval, V = DV
if !isa(A, Const) && !isa(D, Const)
ΔD = A_is_arg ? make_zero(D.dval) : D.dval
$pf!(A.dval, A.val, (Diagonal(diagview(Dval)), V), ΔD)
A_is_arg && (D.dval .= ΔD)
end
copyto!(D.val, diagview(Dval))
!isa(A, Const) && !A_is_arg && make_zero!(A.dval)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return D
elseif EnzymeRules.needs_primal(config)
return D.val
elseif EnzymeRules.needs_shadow(config)
return D.dval
else
return nothing
end
end
end
end

Expand Down
37 changes: 32 additions & 5 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
using MatrixAlgebraKit: eig_pushforward!, eig_vals_pushforward!
using MatrixAlgebraKit: eigh_pushforward!, eigh_vals_pushforward!
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward!
Expand Down Expand Up @@ -113,6 +115,8 @@ end
for (f!, f, pf) in (
(:left_polar!, :left_polar, :left_polar_pushforward!),
(:right_polar!, :right_polar, :right_polar_pushforward!),
(:eig_full!, :eig_full, :eig_pushforward!),
(:eigh_full!, :eigh_full, :eigh_pushforward!),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
Expand Down Expand Up @@ -177,12 +181,12 @@ for (f!, f, pb, adj) in (
end
end

for (f!, f, f_full, pb, adj) in (
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint),
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
for (f!, f, f_full, pb, pf, adj) in (
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_pushforward!, :eig_vals_adjoint),
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_pushforward!, :eigh_vals_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand Down Expand Up @@ -210,7 +214,18 @@ for (f!, f, f_full, pb, adj) in (
end
return D_dD, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual)
# compute primal
A, dA = arrayify(A_dA)
D, dD = arrayify(D_dD)
# update primal
DV = $f_full(A, Mooncake.primal(alg_dalg))
V = DV[2]
copyto!(D, diagview(DV[1]))
$pf(dA, A, (D, V), dD)
return D_dD
end
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -227,6 +242,18 @@ for (f!, f, f_full, pb, adj) in (
end
return output_codual, $adj
end
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
# compute primal
A, dA = arrayify(A_dA)
# update primal
DV = $f_full(A, Mooncake.primal(alg_dalg))
V = DV[2]
output = diagview(DV[1])
output_dual = Dual(output, Mooncake.zero_tangent(output))
D, dD = arrayify(output_dual)
$pf(dA, A, DV, dD)
return output_dual
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("pushforwards/polar.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")

include("precompile.jl")

Expand Down
22 changes: 22 additions & 0 deletions src/pushforwards/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function eig_pushforward!(
ΔA, A, DV, ΔDV;
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)
D, V = DV
ΔD, ΔV = ΔDV
ΔAV = mul!(ΔV, ΔA, V) # reusing ΔV memory
∂K = V \ ΔAV
if !iszerotangent(ΔD)
diagview(ΔD) .= diagview(∂K)
end
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
mul!(ΔV, V, ∂K, 1, 0)
end
return ΔDV
end

function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...)
return eig_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...)
end
22 changes: 22 additions & 0 deletions src/pushforwards/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function eigh_pushforward!(
ΔA, A, DV, ΔDV;
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)
D, V = DV
ΔD, ΔV = ΔDV
ΔAV = mul!(ΔV, ΔA, V)
∂K = V' * ΔAV
if !iszerotangent(ΔD)
diagview(ΔD) .= real.(diagview(∂K))
end
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
ΔV = mul!(ΔV, V, ∂K)
end
return (ΔD, ΔV)
end

function eigh_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...)
return eigh_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...)
end
2 changes: 1 addition & 1 deletion test/enzyme/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...)
if !is_buildkite
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
TestSuite.test_enzyme_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
TestSuite.test_enzyme_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
4 changes: 2 additions & 2 deletions test/enzyme/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(1234)
if !is_buildkite
#TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
TestSuite.test_enzyme_eigh(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
TestSuite.test_enzyme_eigh(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
18 changes: 14 additions & 4 deletions test/testsuite/enzyme/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,48 @@ end
"""
test_enzyme_eig_full(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rule for `eig_full` and its in-place variant.
Test the Enzyme foward- and reverse-mode AD rule for `eig_full` and its in-place variant.
"""
function test_enzyme_eig_full(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "eig_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "eig_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = make_eig_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(eig_full, A)
DV, ΔDV = ad_eig_full_setup(A)
test_reverse(eig_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm)
test_reverse(call_and_zero!, RT, (eig_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm)
if eltype(T) <: Real && T <: Diagonal
A = make_eig_matrix(T, sz)
test_forward(eig_full, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (eig_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end
end

"""
test_enzyme_eig_vals(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rule for `eig_vals` and its in-place variant.
Test the Enzyme forward- and reverse-mode AD rule for `eig_vals` and its in-place variant.
"""
function test_enzyme_eig_vals(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "eig_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "eig_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = make_eig_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(eig_vals, A)
D, ΔD = ad_eig_vals_setup(A)
test_reverse(eig_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm)
test_reverse(call_and_zero!, RT, (eig_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm)
if eltype(T) <: Real
A = make_eig_matrix(T, sz)
test_forward(eig_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (eig_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end
end

Expand Down
16 changes: 12 additions & 4 deletions test/testsuite/enzyme/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,46 @@ end
"""
test_enzyme_eigh_full(T, sz; rng, atol, rtol)
Test the Enzyme reverse-mode AD rule for `eigh_full` and its in-place variant.
Test the Enzyme forward- and reverse-mode AD rule for `eigh_full` and its in-place variant.
"""
function test_enzyme_eigh_full(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "eigh_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "eigh_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = make_eigh_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(eigh_full, A)
DV, ΔDV = ad_eigh_full_setup(A)
test_reverse(eigh_wrapper, RT, (eigh_full, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm)
test_reverse(eigh!_wrapper, RT, (eigh_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔDV, fdm)
if eltype(T) <: Real
A = make_eigh_matrix(T, sz)
test_forward(eigh_wrapper, RT, (eigh_full, Const), (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(eigh!_wrapper, RT, (eigh_full!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end
end

"""
test_enzyme_eigh_vals(T, sz; rng, atol, rtol)
Test the Enzyme reverse-mode AD rule for `eigh_vals` and its in-place variant.
Test the Enzyme forward- and reverse-mode AD rule for `eigh_vals` and its in-place variant.
"""
function test_enzyme_eigh_vals(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "eigh_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "eigh_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = make_eigh_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(eigh_vals, A)
D, ΔD = ad_eigh_vals_setup(A)
test_reverse(eigh_wrapper, RT, (eigh_vals, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm)
test_reverse(eigh!_wrapper, RT, (eigh_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD, fdm)
A = make_eigh_matrix(T, sz)
test_forward(eigh_wrapper, RT, (eigh_vals, Const), (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(eigh!_wrapper, RT, (eigh_vals!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end

Expand Down
Loading
Loading