Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions ext/TensorKitMooncakeExt/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,43 @@ function Mooncake.rrule!!(
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.TruncatedAlgorithm}
function Mooncake.rrule!!(
::CoDual{typeof(svd_trunc!)},
A_dA::CoDual{<:AbstractTensorMap},
USVᴴ_dUSVᴴ::CoDual,
alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
)
A, dA = arrayify(A_dA)
Ac = deepcopy(A)
alg = primal(alg_dalg)

USVᴴ = primal(USVᴴ_dUSVᴴ)
dUSVᴴ = tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)

USVᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)

USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc))))

function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
@warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
MatrixAlgebraKit.svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
# restore state
copy!(A, Ac)
copy!.(USVᴴ, USVᴴc)
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData(), NoRData()
end

return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
end
17 changes: 17 additions & 0 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,20 @@ function trace_permute_pullback_ΔA!(
)
return NoRData()
end

@is_primitive(
DefaultCtx,
Tuple{
typeof(TensorKit.scalar),
AbstractTensorMap,
}
)
function Mooncake.rrule!!(::CoDual{typeof(TensorKit.scalar)}, t_dt::CoDual{<:AbstractTensorMap})
t, dt = arrayify(t_dt)
val = scalar(t)
function scalar_pullback(Δval)
first(blocks(dt))[2][1] = Δval
return NoRData(), NoRData()
end
return Mooncake.zero_fcodual(val), scalar_pullback
end
6 changes: 6 additions & 0 deletions ext/TensorKitMooncakeExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any}

@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap, Int, Bool}

@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple}

@zero_derivative DefaultCtx Tuple{typeof(TensorKit.has_shared_permute), AbstractTensorMap, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple}
Expand Down
35 changes: 22 additions & 13 deletions test/mooncake/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!
using Mooncake
using Random

function call_and_zero!(f!, A, alg)
F′ = f!(A, alg)
MatrixAlgebraKit.zero!(A)
return F′
end

mode = Mooncake.ReverseMode
rng = Random.default_rng()
Expand All @@ -18,7 +23,6 @@ eltypes = (Float64, ComplexF64)
@timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)

@timedtestset "QR" begin
A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])

Expand All @@ -29,8 +33,7 @@ eltypes = (Float64, ComplexF64)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)

A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])')

Expand All @@ -41,34 +44,31 @@ eltypes = (Float64, ComplexF64)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
end

@timedtestset "LQ" begin
A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])

Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)

# qr_full/qr_null requires being careful with gauges
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)

A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')

Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)

# qr_full/qr_null requires being careful with gauges
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
end

@timedtestset "Eigenvalue decomposition" begin
Expand All @@ -88,7 +88,7 @@ eltypes = (Float64, ComplexF64)

@timedtestset "Singular value decomposition" begin
for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])'))
USVᴴ = svd_compact(t)
#=USVᴴ = svd_compact(t)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)
Expand All @@ -97,14 +97,23 @@ eltypes = (Float64, ComplexF64)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)

=#
V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
USVᴴtrunc = svd_trunc(t, alg)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
USVᴴ = svd_compact(t)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
USVᴴtrunc = svd_trunc(t, alg)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, call_and_zero!, svd_trunc!, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode, is_primitive = false)
end
end
end
Loading