diff --git a/src/algorithms/contractions/bondenv/als_solve.jl b/src/algorithms/contractions/bondenv/als_solve.jl index af083f618..c2ee27222 100644 --- a/src/algorithms/contractions/bondenv/als_solve.jl +++ b/src/algorithms/contractions/bondenv/als_solve.jl @@ -2,139 +2,168 @@ In the following, the names `Ra`, `Sa` etc comes from the fast full update article Physical Review B 92, 035142 (2015) =# - """ -$(SIGNATURES) - -Construct the tensor +Contract the virtual legs between ``` - ┌-----------------------------------┐ - | ┌----┐ | - └---| |- DX0 Db0 - b -- DY0 -┘ - | | ↓ - |benv| db - | | ↓ - ┌---| |- DX1 Db1 - b† - DY1 -┐ - | └----┘ | - └-----------------------------------┘ + -- DX --a-- D --b-- DY -- + ↓ ↓ + da db ``` """ -function _tensor_Ra(benv::BondEnv, b::MPSTensor) - return @autoopt @tensor Ra[DX1 Db1; DX0 Db0] := ( - benv[DX1 DY1; DX0 DY0] * b[Db0 db; DY0] * conj(b[Db1 db; DY1]) - ) +function _combine_ket(a::MPSTensor, b::AbstractTensorMap{T, S, 1, 2}) where {T, S} + return @tensor ket[DX DY; da db] := a[DX da; D] * b[D; db DY] +end +function _combine_ket(a::MPSTensor, b::MPSTensor) + return @tensor ket[DX DY; da db] := a[DX da; D] * b[D db; DY] end -""" -$(SIGNATURES) +function _combine_ket_for_svd(a::MPSTensor, b::MPSTensor) + return @tensor ket[DX da; db DY] := a[DX da; D] * b[D db; DY] +end -Construct the tensor +""" +Construct the norm with bra bond tensors removed ``` - ┌-----------------------------------┐ - | ┌----┐ | - └---| |- DX0 -- (a2 b2) -- DY0 --┘ - | | ↓ ↓ - |benv| da db - | | ↓ - ┌---| |- DX1 Db1 -- b† - DY1 --┐ - | └----┘ | - └-----------------------------------┘ + ┌benv-------┐ + ├---a---b---┤ + | ↓ ↓ | + ├-- --┤ + └-----------┘ ``` """ -function _tensor_Sa( - benv::BondEnv, b::MPSTensor, a2b2::AbstractTensorMap{T, S, 2, 2} - ) where {T <: Number, S <: ElementarySpace} - return @autoopt @tensor Sa[DX1 da; Db1] := ( - benv[DX1 DY1; DX0 DY0] * conj(b[Db1 db; DY1]) * a2b2[DX0 DY0; da db] - ) +function _benv_ket(benv::BondEnv, ket::AbstractTensorMap{T, S, 2, 2}) where {T, S} + return benv * twistdual(ket, 1:2) end """ -$(SIGNATURES) + _als_tensor_R(benv::BondEnv, xs::Vector{<:MPSTensor}, i::Int) -Construct the tensor +Construct the bond environment around the `i`th bond tensor +in two-site ALS optimization. ``` - ┌-----------------------------------┐ - | ┌----┐ | - └---| |- DX0 - a -- Da0 DY0 -┘ - | | ↓ - |benv| da - | | ↓ - ┌---| |- DX1 - a† - Da1 DY1 -┐ - | └----┘ | - └-----------------------------------┘ + i = 1 i = 2 + ┌benv-------┐ ┌benv-------┐ + ├-- --b---┤ ├---a-- --┤ + | ↓ | | ↓ | + ├-- --b̄---┤ ├---ā-- --┤ + └-----------┘ └-----------┘ ``` """ -function _tensor_Rb(benv::BondEnv, a::MPSTensor) - return @autoopt @tensor Rb[Da1 DY1; Da0 DY0] := ( - benv[DX1 DY1; DX0 DY0] * a[DX0 da; Da0] * conj(a[DX1 da; Da1]) - ) +function _als_tensor_R(benv::BondEnv, xs::Vector{<:MPSTensor}, i::Int) + @assert 1 <= i <= 2 + return if i == 1 + _als_tensor_Ra(benv, xs[2]) + else + _als_tensor_Rb(benv, xs[1]) + end end -""" -$(SIGNATURES) +function _als_tensor_Ra(benv::BondEnv, b::MPSTensor) + return @tensor Ra[DX1 D1; DX0 D0] := + benv[DX1 DY1; DX0 DY0] * b[D0 db; DY0] * conj(b[D1 db; DY1]) +end +function _als_tensor_Rb(benv::BondEnv, a::MPSTensor) + return @tensor Rb[D1 DY1; D0 DY0] := + benv[DX1 DY1; DX0 DY0] * a[DX0 da; D0] * conj(a[DX1 da; D1]) +end -Construct the tensor +""" +Calculate the 2-site norm ``` - ┌-----------------------------------┐ - | ┌----┐ | - └---| |- DX0 -- (a2 b2) -- DY0 --┘ - | | ↓ ↓ - |benv| da db - | | ↓ - ┌---| |- DX1 -- a† - Da1 DY1 --┐ - | └----┘ | - └-----------------------------------┘ + ┌benv-------┐ + ├---a---b---┤ + | ↓ ↓ | + ├---ā---b̄---┤ + └-----------┘ ``` +using pre-calcuated partial contraction results. """ -function _tensor_Sb( - benv::BondEnv, a::MPSTensor, a2b2::AbstractTensorMap{T, S, 2, 2} - ) where {T <: Number, S <: ElementarySpace} - return @autoopt @tensor Sb[Da1 db; DY1] := ( - benv[DX1 DY1; DX0 DY0] * conj(a[DX1 da; Da1]) * a2b2[DX0 DY0; da db] - ) +function _als_norm( + ket::AbstractTensorMap{T, S, 2, 2}, benv_ket::AbstractTensorMap{T, S, 2, 2} + ) where {T, S} + return @tensor benv_ket[DX1 DY1; da db] * conj(ket[DX1 DY1; da db]) +end +function _als_norm(a::MPSTensor, Ra::BondEnv) + return @tensor Ra[DX1 D1; DX0 D0] * a[DX0 da; D0] * conj(a[DX1 da; D1]) end """ -$(SIGNATURES) + _als_tensor_S( + benv_ket::AbstractTensorMap{T, S, 2, 2}, + xs::Vector{<:MPSTensor}, i::Int + ) where {T <: Number, S <: ElementarySpace} -Calculate the inner product +Construct the overlap but with one of the bra bond tensor removed. ``` - ┌--------------------------------┐ - | ┌----┐ | - └---| |- DX0 - (a2 b2) - DY0 -┘ - | | ↓ ↓ - |benv| da db - | | ↓ ↓ - ┌---| |- DX1 - (a1 b1)†- DY1 -┐ - | └----┘ | - └--------------------------------┘ + i = 1 i = 2 + ┌benv-------┐ ┌benv-------┐ + ├---a₂==b₂--┤ ├---a₂==b₂--┤ + | ↓ ↓ | | ↓ ↓ | + ├-- --b̄---┤ ├---ā-- --┤ + └-----------┘ └-----------┘ ``` +The ket part is provided by the partial contraction `benv_ket`. """ -function inner_prod( - benv::BondEnv, a1b1::AbstractTensorMap{T, S, 2, 2}, a2b2::AbstractTensorMap{T, S, 2, 2} +function _als_tensor_S( + benv_ket::AbstractTensorMap{T, S, 2, 2}, + xs::Vector{<:MPSTensor}, i::Int ) where {T <: Number, S <: ElementarySpace} - return @autoopt @tensor benv[DX1 DY1; DX0 DY0] * - conj(a1b1[DX1 DY1; da db]) * a2b2[DX0 DY0; da db] + @assert 1 <= i <= 2 + return if i == 1 + _als_tensor_Sa(benv_ket, xs[2]) + else + _als_tensor_Sb(benv_ket, xs[1]) + end +end + +function _als_tensor_Sa( + benv_ket::AbstractTensorMap{T, S, 2, 2}, b::MPSTensor + ) where {T <: Number, S <: ElementarySpace} + return @tensor Sa[DX1 da; D1] := + benv_ket[DX1 DY1; da db] * conj(b[D1 db; DY1]) +end +function _als_tensor_Sb( + benv_ket::AbstractTensorMap{T, S, 2, 2}, a::MPSTensor + ) where {T <: Number, S <: ElementarySpace} + return @tensor Sb[D1 db; DY1] := + benv_ket[DX1 DY1; da db] * conj(a[DX1 da; D1]) end """ -$(SIGNATURES) +Calculate the inner product (overlap) +``` + ┌benv-------┐ + ├---a₂--b₂--┤ + | ↓ ↓ | + ├---ā---b̄---┤ + └-----------┘ +``` +using pre-calculated partial contraction results. +""" +function _als_overlap(a::MPSTensor, Sa::MPSTensor) + # applies to b, Sb as well + # @tensor Sb[D1 db; DY1] * conj(b[D1 db; DY1]) + return @tensor Sa[DX1 da; D1] * conj(a[DX1 da; D1]) +end -Contract the axis between `a` and `b` tensors +""" +Calculate the 2-site ALS inner product ⟨a₁,b₁|a₂,b₂⟩ ``` - -- DX - a - D - b - DY -- - ↓ ↓ - da db + ┌benv-------┐ + ├---a₂--b₂--┤ + | ↓ ↓ | + ├---ā₁--b̄₁--┤ + └-----------┘ ``` +where `|bra⟩ = |a₁,b₁⟩` and `|ket⟩ = |a₂,b₂⟩`, +with virtual leg between a, b contracted. """ -function _combine_ab( - a::MPSTensor, b::AbstractTensorMap{T, S, 1, 2} +function inner_prod( + benv::BondEnv, bra::AbstractTensorMap{T, S, 2, 2}, + ket::AbstractTensorMap{T, S, 2, 2} ) where {T <: Number, S <: ElementarySpace} - return @tensor ab[DX DY; da db] := a[DX da; D] * b[D; db DY] -end -function _combine_ab(a::MPSTensor, b::MPSTensor) - return @tensor ab[DX DY; da db] := a[DX da; D] * b[D db; DY] + return @autoopt @tensor benv[DX1 DY1; DX0 DY0] * + conj(bra[DX1 DY1; da db]) * ket[DX0 DY0; da db] end """ @@ -161,10 +190,30 @@ function cost_function_als(benv, ψ1, ψ2) return cost, fid end +# applies to Rb, Sb, b as well +# b22 is the pre-calculated untruncated norm +function cost_function_als(Ra::BondEnv, Sa::MPSTensor, a::MPSTensor, b22::Real) + b11 = real(_als_norm(a, Ra)) + b12 = _als_overlap(a, Sa) + cost = b11 + b22 - 2 * real(b12) + fid = abs2(b12) / abs(b11 * b22) + return cost, fid +end + """ $(SIGNATURES) Solve the equations `Rx x = Sx` with initial guess `x0`. + +In ALS over `a`, `b`, if we fix `b`, the cost function can +be expressed in the `Ra`, `Sa` tensors as +``` + f(a†,a) = a† Ra a - a† Sa - Sa† a + const +``` +Therefore `f` is minimized when +``` + ∂f/∂ā = Ra a - Sa = 0 +``` """ function _solve_als( Rx::AbstractTensorMap{T, S, N, N}, diff --git a/src/algorithms/truncation/bond_truncation.jl b/src/algorithms/truncation/bond_truncation.jl index 57fd68cc7..70e5af80a 100644 --- a/src/algorithms/truncation/bond_truncation.jl +++ b/src/algorithms/truncation/bond_truncation.jl @@ -18,8 +18,8 @@ The truncation algorithm can be constructed from the following keyword arguments * `tol::Float64=1e-9` : ALS converges when the relative change in bond SVD spectrum between two iterations is smaller than `tol`. * `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`. """ -@kwdef struct ALSTruncation - trunc::TruncationStrategy +@kwdef struct ALSTruncation{T <: TruncationStrategy} + trunc::T maxiter::Int = 50 tol::Float64 = 1.0e-9 check_interval::Int = 0 @@ -34,6 +34,20 @@ function _als_message( ) * @sprintf(" cost = %.3e, Δcost/cost0 = %.3e, |Δs| = %.4e.", cost, Δcost, Δs) end +""" +Initialize truncated bond tensors for 2-site ALS +""" +function _als_init_truncate( + ket2::AbstractTensorMap{T, S, 2, 2}, trunc::TruncationStrategy + ) where {T, S} + a, s0, b = svd_trunc!(permute(ket2, ((1, 3), (4, 2)); copy = true); trunc) + a, b = absorb_s(a, s0, b) + # put b in MPS axis order + b = permute(b, ((1, 2), (3,))) + xs = [a, b] + return xs, s0 +end + """ bond_truncate(a::AbstractTensorMap{T,S,2,1}, b::AbstractTensorMap{T,S,1,2}, benv::BondEnv{T,S}, alg) -> U, S, V, info @@ -69,40 +83,39 @@ function bond_truncate( need_flip = isdual(space(b, 1)) time00 = time() verbose = (alg.check_interval > 0) - a2b2 = _combine_ab(a, b) - # initialize truncated a, b - perm_ab = ((1, 3), (4, 2)) - a, s0, b = svd_trunc(permute(a2b2, perm_ab); trunc = alg.trunc) - a, b = absorb_s(a, s0, b) - # put b in MPS axis order - b = permute(b, ((1, 2), (3,))) - ab = _combine_ab(a, b) + + # untruncated things + ket2 = _combine_ket(a, b) + benv_ket2 = _benv_ket(benv, ket2) + b22 = real(_als_norm(ket2, benv_ket2)) + + # initialize truncated bond tensors and bond weight + xs, s0 = _als_init_truncate(ket2, alg.trunc) + + # initialize ALS cache + Rs = [_als_tensor_R(benv, xs, i) for i in 1:2] + Ss = [_als_tensor_S(benv_ket2, xs, i) for i in 1:2] + # cost function will be normalized by initial value - cost00, fid = cost_function_als(benv, ab, a2b2) + cost00, fid = cost_function_als(Rs[1], Ss[1], xs[1], b22) cost0, fid0, Δcost, Δfid, Δs = cost00, fid, NaN, NaN, NaN verbose && @info "ALS init" * _als_message(0, cost0, fid, Δcost, Δfid, Δs, 0.0) + for iter in 1:(alg.maxiter) time0 = time() - #= - Fixing `b`, the cost function can be expressed in the R, S tensors as - ``` - f(a†,a) = a† Ra a - a† Sa - Sa† a + const - ``` - `f` is minimized when - ∂f/∂ā = Ra a - Sa = 0 - =# - Ra = _tensor_Ra(benv, b) - Sa = _tensor_Sa(benv, b, a2b2) - a, info_a = _solve_als(Ra, Sa, a) - # Fixing `a`, solve for `b` from `Rb b = Sb` - Rb = _tensor_Rb(benv, a) - Sb = _tensor_Sb(benv, a, a2b2) - b, info_b = _solve_als(Rb, Sb, b) - @debug "Bond truncation info" info_a info_b - ab = _combine_ab(a, b) - cost, fid = cost_function_als(benv, ab, a2b2) + for (i, (Rx, Sx, x)) in enumerate(zip(Rs, Ss, xs)) + # TODO: option to use pinv + xs[i], info_x = _solve_als(Rx, Sx, x) + @debug "Bond truncation info $(i):" info_x + # update R, S for the next site + i_next = _next(i, 2) + Rs[i_next] = _als_tensor_R(benv, xs, i_next) + Ss[i_next] = _als_tensor_S(benv_ket2, xs, i_next) + end + # cost function and local fidelity + cost, fid = cost_function_als(Rs[1], Ss[1], xs[1], b22) # TODO: replace with truncated svdvals (without calculating u, vh) - _, s, _ = svd_trunc!(permute(ab, perm_ab); trunc = alg.trunc) + _, s, _ = svd_trunc!(_combine_ket_for_svd(xs...); trunc = alg.trunc) # fidelity, cost and normalized bond-s change s_nrm = norm(s0, Inf) Δs = _singular_value_distance(s, s0) / s_nrm @@ -129,7 +142,7 @@ function bond_truncate( end converge && break end - a, s, b = svd_trunc!(permute(_combine_ab(a, b), perm_ab); trunc = alg.trunc) + a, s, b = svd_trunc!(_combine_ket_for_svd(xs...); trunc = alg.trunc) a, b = absorb_s(a, s, b) if need_flip a, s, b = flip(a, numind(a)), _fliptwist_s(s), flip(b, 1) diff --git a/src/algorithms/truncation/fullenv_truncation.jl b/src/algorithms/truncation/fullenv_truncation.jl index fd9c685bd..4ff6bac2b 100644 --- a/src/algorithms/truncation/fullenv_truncation.jl +++ b/src/algorithms/truncation/fullenv_truncation.jl @@ -23,8 +23,8 @@ The truncation algorithm can be constructed from the following keyword arguments * [Glen Evenbly, Phys. Rev. B 98, 085155 (2018)](@cite evenbly_gauge_2018). """ -@kwdef struct FullEnvTruncation - trunc::TruncationStrategy +@kwdef struct FullEnvTruncation{T <: TruncationStrategy} + trunc::T maxiter::Int = 50 tol::Float64 = 1.0e-9 trunc_init::Bool = true @@ -75,13 +75,13 @@ function _fet_message( end """ - fullenv_truncate(benv::BondEnv{T,S}, b0::AbstractTensorMap{T,S,1,1}, alg::FullEnvTruncation) -> U, S, V, info + fullenv_truncate(b0, benv::BondEnv, alg::FullEnvTruncation) -> U, S, V, info Perform full environment truncation algorithm from [Phys. Rev. B 98, 085155 (2018)](@cite evenbly_gauge_2018) on `benv`. -Given a fixed state `|b0⟩` with bond matrix `b0` -and the corresponding positive-definite bond environment `benv`, +Given a fixed state `|b0⟩` with bond matrix `b0` and the +corresponding positive-definite bond environment `benv`, find the state `|b⟩` with truncated bond matrix `b = u s v†` that maximizes the fidelity (not normalized by `⟨b0|b0⟩`) ``` @@ -215,11 +215,12 @@ function fullenv_truncate( b1 = similar(b0) s0 = deepcopy(s) Δfid, Δs, fid, fid0 = NaN, NaN, 0.0, 0.0 + @tensor benv_b0[-1 -2] := benv[-1 -2; 3 4] * b0[3; 4] for iter in 1:(alg.maxiter) time0 = time() # update `← r - = ← s ← v† -` @tensor r[-1 -2] := s[-1; 1] * vh[1; -2] - @tensor p[-1 -2] := conj(u[1; -1]) * benv[1 -2; 3 4] * b0[3; 4] + @tensor p[-1 -2] := conj(u[1; -1]) * benv_b0[1 -2] @tensor B[-1 -2; -3 -4] := conj(u[1; -1]) * benv[1 -2; 3 -4] * u[3; -3] _linearmap_twist!(p) _linearmap_twist!(B) @@ -228,7 +229,7 @@ function fullenv_truncate( u, s, vh = svd_trunc(b1; trunc = alg.trunc) # update `- l ← = - u ← s ←` @tensor l[-1 -2] := u[-1; 1] * s[1; -2] - @tensor p[-1 -2] := conj(vh[-2; 2]) * benv[-1 2; 3 4] * b0[3; 4] + @tensor p[-1 -2] := conj(vh[-2; 2]) * benv_b0[-1 2] @tensor B[-1 -2; -3 -4] := conj(vh[-2; 2]) * benv[-1 2; -3 4] * vh[-4; 4] _linearmap_twist!(p) _linearmap_twist!(B) diff --git a/test/bondenv/bond_truncate.jl b/test/bondenv/bond_truncate.jl index 25f56c200..d145f1ad4 100644 --- a/test/bondenv/bond_truncate.jl +++ b/test/bondenv/bond_truncate.jl @@ -5,40 +5,49 @@ using TensorKit using PEPSKit using LinearAlgebra using KrylovKit -using PEPSKit: cost_function_als +using PEPSKit: bond_truncate, cost_function_als +using PEPSKit: _combine_ket, _combine_ket_for_svd Random.seed!(0) maxiter = 600 -check_interval = 20 -trunc = truncerror(; atol = 1.0e-10) & truncrank(8) -Vext = Vect[FermionParity](0 => 100, 1 => 100) -Vint = Vect[FermionParity](0 => 6, 1 => 6) -Vphy = Vect[FermionParity](0 => 1, 1 => 2) -perm_ab = ((1, 3), (4, 2)) -for Vbondl in (Vint, Vint'), Vbondr in (Vint, Vint') - Vbond = Vbondl ⊗ Vbondr +check_interval = 30 +elt = ComplexF64 +# simulating the situation of applying a 2-site gate +# to a bond with virtual dimension D, physical dimension d. +d, D = 2, 4 +trunc = truncerror(; atol = 1.0e-10) & truncrank(D) +Vphy = Vect[FermionParity](0 => div(d, 2), 1 => div(d, 2)) +Vqro = Vect[FermionParity](0 => div(d * D, 2), 1 => div(d * D, 2)) +# virtual dimension of gate MPO is d^2 +Vint = Vect[FermionParity](0 => div(d^2 * D, 2), 1 => div(d^2 * D, 2)) +for Vl in (Vqro, Vqro'), Vr in (Vqro, Vqro') # random positive-definite environment - Z = randn(Float64, Vext ← Vbond) + Vbond = Vl ⊗ Vr + Dext = dim(Vbond) + Vext = Vect[FermionParity](0 => div(Dext, 2) + 1, 1 => div(Dext, 2) + 1) + Z = randn(elt, Vext ← Vbond) + normalize!(Z, Inf) benv = Z' * Z - normalize!(benv, Inf) - # untruncated bond tensor - a2b2 = randn(Float64, Vbondl ⊗ Vbondr ← Vphy' ⊗ Vphy') - a2, s, b2 = svd_compact(permute(a2b2, perm_ab)) - a2, b2 = PEPSKit.absorb_s(a2, s, b2) + @info "Dimension of benv = $(Dext)" + # untruncated bond tensors + a2 = randn(elt, Vl ⊗ Vphy ← Vint) + b2 = randn(elt, Vint ← Vphy' ⊗ Vr') # bond tensor (truncated SVD initialization) - a0, s, b0 = svd_trunc(permute(a2b2, perm_ab); trunc = trunc) + a2b2 = _combine_ket(a2, b2) + a0, s, b0 = svd_trunc(permute(a2b2, ((1, 3), (4, 2))); trunc = trunc) a0, b0 = PEPSKit.absorb_s(a0, s, b0) - fid0 = cost_function_als(benv, PEPSKit._combine_ab(a0, b0), a2b2)[2] + fid0 = cost_function_als(benv, _combine_ket(a0, b0), a2b2)[2] @info "Fidelity of simple SVD truncation = $fid0.\n" ss = Dict{String, DiagonalTensorMap}() + # FET is slower when d is large for (label, alg) in ( ("ALS", ALSTruncation(; trunc, maxiter, check_interval)), ("FET", FullEnvTruncation(; trunc, maxiter, check_interval, trunc_init = false)), ) - a1, ss[label], b1, info = PEPSKit.bond_truncate(a2, b2, benv, alg) + a1, ss[label], b1, info = bond_truncate(a2, b2, benv, alg) @info "$label improved fidelity = $(info.fid)." # display(ss[label]) - @test info.fid ≈ cost_function_als(benv, PEPSKit._combine_ab(a1, b1), a2b2)[2] + @test info.fid ≈ cost_function_als(benv, _combine_ket(a1, b1), a2b2)[2] @test info.fid > fid0 end @test isapprox(ss["ALS"], ss["FET"], atol = 1.0e-3)