Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,19 +47,21 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.146"
EnzymeTestUtils = "0.2.7"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.7"
MatrixAlgebraKit = "0.6.8"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5, 0.6"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"
16 changes: 16 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")

end
262 changes: 262 additions & 0 deletions ext/TensorKitEnzymeExt/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Shared
# ------
# Can Enzyme do this itself? Apparently not...
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation,
β::Annotation,
) where {RT}
cacheC = !isa(β, Const) && copy(C.val)
cacheA = !isa(B, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
cacheB = !isa(A, Const) && EnzymeRules.overwritten(config)[4] ? copy(B.val) : nothing
AB = if !isa(α, Const)
AB = A.val * B.val
add!(C.val, AB, α.val, β.val)
AB
else
mul!(C.val, A.val, B.val, α.val, β.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (cacheC, cacheA, cacheB, AB)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation{<:Number},
β::Annotation{<:Number},
) where {RT}
if RT <: Const
Δα = isa(α, Const) ? nothing : zero(α.val)
Δβ = isa(β, Const) ? nothing : zero(β.val)
return (nothing, nothing, nothing, Δα, Δβ)
end
cacheC, cacheA, cacheB, AB = cache
Cval = something(cacheC, C.val)
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)

!isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val))
!isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val))
Δαr = pullback_dα(α, C, AB)
Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)

return (nothing, nothing, nothing, Δαr, Δβr)
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation{<:Number},
β::Annotation{<:Number},
) where {RT}
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if !isa(C, Const)
scale!(C.dval, β.val)
!isa(β, Const) && add!(C.dval, C.val, β.dval)
!isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval)
!isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val)
!isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val)
end
mul!(C.val, A.val, B.val, α.val, β.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
ret = func.val(A.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
cache = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
dret::Active,
cache,
A::Annotation{<:AbstractTensorMap},
)
Aval = something(cache, A.val)
Δtrace = dret.val
if !isa(A, Const)
for (_, b) in blocks(A.dval)
TensorKit.diagview(b) .+= Δtrace
end
end
return (nothing,)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
::Type{<:Const},
cache,
A::Annotation{<:AbstractTensorMap},
)
return (nothing,)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
::Type{RT},
func::Const{typeof(tr)},
A::Annotation{<:AbstractTensorMap},
) where {RT}
y = EnzymeRules.needs_primal(config) ? tr(A.val) : nothing
Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
tr(A.dval)
elseif EnzymeRules.needs_shadow(config)
zero(eltype(A.dval))
else
nothing
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(y, Δy)
elseif EnzymeRules.needs_primal(config)
return y
elseif EnzymeRules.needs_shadow(config)
return Δy
else
return nothing
end
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
) where {RT}
p.val == 2 || error("currently only implemented for p = 2")
ret = func.val(A.val, p.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
cache = (ret, cacheA)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
dret::Active,
cache,
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
)
n, cacheA = cache
Δn = dret.val
p.val == 2 || error("currently only implemented for p = 2")
Aval = something(cacheA, A.val)
if !isa(A, Const)
x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
add!(A.dval, A.val, x)
end
return (nothing, nothing)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
::Type{<:Const},
cache,
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
)
return (nothing, nothing)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(norm)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
) where {RT}
y = norm(A.val, p.val)
Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
real(dot(A.val, A.dval)) * pinv(y)
elseif EnzymeRules.needs_shadow(config)
zero(eltype(A.dval))
else
nothing
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(y, Δy)
elseif EnzymeRules.needs_primal(config)
return y
elseif EnzymeRules.needs_shadow(config)
return Δy
else
return nothing
end
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
ret = inv(A.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
cache = (ret, shadow)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
) where {RT}
Ainv, ΔAinv = cache
!isa(A, Const) && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One())
return (nothing,)
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
Ainv = inv(A.val)
ΔAinv = !isa(A, Const) ? scale!(Ainv * A.dval * Ainv, -1) : make_zero(Ainv)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(Ainv, ΔAinv)
elseif EnzymeRules.needs_primal(config)
return Ainv
elseif EnzymeRules.needs_shadow(config)
return ΔAinv
else
return nothing
end
end
Loading
Loading