From 6a0fa694b95a41aca97393049e4e24eacfdcfd64 Mon Sep 17 00:00:00 2001 From: Lukas <37111893+lkdvos@users.noreply.github.com> Date: Fri, 29 Sep 2023 17:33:13 +0200 Subject: [PATCH] Automatic Differentiation (#82) * Setup ChainRules package extension * Port over some methods from TensorKitAD * Updates for leftorth and rightorth rrules * Formatting * Fix missing using PackageExtensionCompat * Add some missing `rrule`s. * little bit of cleanup * repartition * Remove overloaded `rrule`s in favor of TensorOperations update * Update VectorInterface 0.4 * Various AD bugfixes * some AD tests amd updates * Update AD rules, clean up tests * Include qdim in vectors * Formatter * remove `thunk` in trace * Formatter * Add rrule for efficient copy constructor * Fix non-existent argument name * Add ProjectTo functionality * Add tensoroperations tests * Changes for TensorOperations 4.0.6 tensorscalar now has a `rrule` * remove obsolete test --------- Co-authored-by: leburgel --- Project.toml | 14 +- ext/TensorKitChainRulesCoreExt.jl | 340 ++++++++++++++++++++++++++++++ src/TensorKit.jl | 8 + src/tensors/braidingtensor.jl | 14 +- src/tensors/tensoroperations.jl | 10 +- test/ad.jl | 253 ++++++++++++++++++++++ test/runtests.jl | 1 + 7 files changed, 627 insertions(+), 13 deletions(-) create mode 100644 ext/TensorKitChainRulesCoreExt.jl create mode 100644 test/ad.jl diff --git a/Project.toml b/Project.toml index 1850e74c..e4c2cffc 100644 --- a/Project.toml +++ b/Project.toml @@ -7,24 +7,34 @@ version = "0.11.2" HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +TensorKitChainRulesCoreExt = "ChainRulesCore" + [compat] HalfIntegers = "1" LRUCache = "1.0.2" Strided = "2" -TensorOperations = "4.0.5" +TensorOperations = "4.0.6" TupleTools = "1.1" VectorInterface = "0.4" WignerSymbols = "1,2" julia = "1.6" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -34,4 +44,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" [targets] -test = ["Combinatorics", "HalfIntegers", "LinearAlgebra", "Random", "TensorOperations", "Test", "TestExtras", "WignerSymbols"] +test = ["Combinatorics", "HalfIntegers", "LinearAlgebra", "Random", "TensorOperations", "Test", "TestExtras", "WignerSymbols", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"] diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl new file mode 100644 index 00000000..fac03e91 --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -0,0 +1,340 @@ +module TensorKitChainRulesCoreExt + +using TensorOperations +using TensorKit +using ChainRulesCore +using LinearAlgebra +using TupleTools + +# Utility +# 
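# Usage sketch: an illustration of how these rules become reachable, not part of
# the extension's own functionality. It assumes Julia >= 1.9, where the
# [weakdeps]/[extensions] entries in Project.toml activate this module as soon as
# ChainRulesCore is loaded next to TensorKit (on Julia 1.6-1.8 the
# PackageExtensionCompat hook added to src/TensorKit.jl plays that role). The
# example below uses the rule for `tr` defined further down in this file.
using TensorKit, ChainRulesCore, LinearAlgebra
A = TensorMap(randn, Float64, ℂ^2 ⊗ ℂ^2 ← ℂ^2 ⊗ ℂ^2)
trA, tr_pullback = ChainRulesCore.rrule(tr, A)
_, ∂A = tr_pullback(1.0)   # ∂A ≈ id(domain(A)), exactly as encoded in the rule for `tr`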
------- + +_conj(conjA::Symbol) = conjA == :C ? :N : :C +trivtuple(N) = ntuple(identity, N) + +function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return p[1:N₁], p[(N₁ + 1):end] +end +_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) +function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple,Index2Tuple}, + ::AbstractTensorMap{<:Any,N₁}) where {N₁} + return _repartition(p, N₁) +end + +TensorKit.block(t::ZeroTangent, c::Sector) = t + +# Constructors +# ------------ + +@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) +@non_differentiable TensorKit.isomorphism(args...) +@non_differentiable TensorKit.isometry(args...) +@non_differentiable TensorKit.unitary(args...) + +function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...) + function TensorMap_pullback(Δt) + ∂d = convert(Array, Δt) + return NoTangent(), ∂d, fill(NoTangent(), length(args))... + end + return TensorMap(d, args...), TensorMap_pullback +end + +function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap) + A = convert(T, t) + function convert_pullback(ΔA) + ∂t = TensorMap(ΔA, codomain(t), domain(t)) + return NoTangent(), NoTangent(), ∂t + end + return A, convert_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) + copy_pullback(Δt) = NoTangent(), Δt + return copy(t), copy_pullback +end + +ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() +(::ProjectTo{T})(x::AbstractTensorMap) where {T<:AbstractTensorMap} = convert(T, x) + +# Base Linear Algebra +# ------------------- + +function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) + plus_pullback(Δc) = NoTangent(), Δc, Δc + return a + b, plus_pullback +end + +function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) + minus_pullback(Δc) = NoTangent(), Δc, -Δc + return a - b, minus_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple) + function permute_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent() + end + return permute(tsrc, p), permute_pullback +end + +# LinearAlgebra +# ------------- + +function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) + tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) + return tr(A), tr_pullback +end + +function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) + adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) + return adjoint(A), adjoint_pullback +end + +function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) + dot_pullback(Δd) = NoTangent(), 
@thunk(b * Δd'), @thunk(a * Δd) + return dot(a, b), dot_pullback +end + +function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p) + p == 2 || error("currently only implemented for p = 2") + n = norm(a, p) + norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent() + return n, norm_pullback +end + +# Factorizations +# -------------- + +function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...) + U, S, V, ϵ = tsvd(t; kwargs...) + + function tsvd!_pullback((ΔU, ΔS, ΔV, Δϵ)) + ∂t = similar(t) + for (c, b) in blocks(∂t) + copyto!(b, + svd_rev(block(U, c), block(S, c), block(V, c), + block(ΔU, c), block(ΔS, c), block(ΔV, c))) + end + + return NoTangent(), ∂t + end + + return (U, S, V, ϵ), tsvd!_pullback +end + +""" + svd_rev(U, S, V, ΔU, ΔS, ΔV; tol=eps(real(scalartype(Σ)))^(4 / 5)) + +Implements the following back propagation formula for the SVD: + +```math +ΔA = UΔSV' + U(J + J')SV' + US(K + K')V' + \\frac{1}{2}US^{-1}(L' - L)V'\\ +J = F ∘ (U'ΔU), \\qquad K = F ∘ (V'ΔV), \\qquad L = I ∘ (V'ΔV)\\ +F_{i ≠ j} = \\frac{1}{s_j^2 - s_i^2}\\ +F_{ii} = 0 +``` + +# References + +Wan, Zhou-Quan, and Shi-Xin Zhang. 2019. “Automatic Differentiation for Complex Valued SVD.” https://doi.org/10.48550/ARXIV.1909.02659. +""" +function svd_rev(U::AbstractMatrix, S::AbstractMatrix, V::AbstractMatrix, ΔU, ΔS, ΔV; + atol::Real=0, + rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4)) + # project out gauge invariance dependence? + # ΔU * U + ΔV * V' = 0 + + tol = atol > 0 ? atol : rtol * S[1, 1] + F = _invert_S²(S, tol) + S⁻¹ = pinv(S; atol=tol) + + term = Diagonal(diag(ΔS)) + + J = F .* (U' * ΔU) + term += (J + J') * S + VΔV = (V * ΔV') + K = F .* VΔV + term += S * (K + K') + + if scalartype(U) <: Complex && !(ΔV isa ZeroTangent) && !(ΔU isa ZeroTangent) + L = LinearAlgebra.Diagonal(diag(VΔV)) + term += 0.5 * S⁻¹ * (L' - L) + end + + ΔA = U * term * V + + if size(U, 1) != size(V, 2) + UUd = U * U' + VdV = V' * V + ΔA += (one(UUd) - UUd) * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * (one(VdV) - VdV) + end + + return ΔA +end + +function _invert_S²(S::AbstractMatrix{T}, tol::Real) where {T<:Real} + F = similar(S) + @inbounds for i in axes(F, 1), j in axes(F, 2) + F[i, j] = if i == j + zero(T) + else + sᵢ, sⱼ = S[i, i], S[j, j] + 1 / (abs(sⱼ - sᵢ) < tol ? tol : sⱼ^2 - sᵢ^2) + end + end + return F +end + +function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) + alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only QR and QRpos supported") + Q, R = leftorth(t; alg) + leftorth!_pullback((ΔQ, ΔR)) = NoTangent(), qr_pullback!(similar(t), t, Q, R, ΔQ, ΔR) + leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = ZeroTangent() + return (Q, R), leftorth!_pullback +end + +function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) + alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only LQ and LQpos supported") + L, Q = rightorth(t; alg) + rightorth!_pullback((ΔL, ΔQ)) = NoTangent(), lq_pullback!(similar(t), t, L, Q, ΔL, ΔQ) + rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = ZeroTangent() + return (L, Q), rightorth!_pullback +end + +""" + copyltu!(A::AbstractMatrix) + +Copy the conjugated lower triangular part of `A` to the upper triangular part. 
+""" +function copyltu!(A::AbstractMatrix) + m, n = size(A) + for i in axes(A, 1) + A[i, i] = real(A[i, i]) + @inbounds for j in (i + 1):n + A[i, j] = conj(A[j, i]) + end + end + return A +end + +function qr_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S}, + Q::AbstractTensorMap{S}, R::AbstractTensorMap{S}, ΔQ, ΔR) where {S} + for (c, b) in blocks(ΔA) + qr_pullback!(b, block(t, c), block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c)) + end + return ΔA +end + +function qr_pullback!(ΔA, A, Q::M, R::M, ΔQ, ΔR) where {M<:AbstractMatrix} + m = qr_rank(R) + n = size(R, 2) + + if n == m # full rank + q = view(Q, :, 1:m) + Δq = view(ΔQ, :, 1:m) + r = view(R, 1:m, :) + Δr = view(ΔR, 1:m, :) + ΔA = qr_pullback_fullrank!(ΔA, q, r, Δq, Δr) + else + q = view(Q, :, 1:m) + Δq = view(ΔQ, :, 1:m) + view(A, :, (m + 1):n) * view(ΔR, :, (m + 1):n)' + r = view(R, 1:m, 1:m) + Δr = view(ΔR, 1:m, 1:m) + + qr_pullback_fullrank!(view(ΔA, :, 1:m), q, r, Δq, Δr) + ΔA[:, (m + 1):n] = q * view(ΔR, :, (m + 1):n) + end + + return ΔA +end + +function qr_pullback_fullrank!(ΔA, Q, R, ΔQ, ΔR) + b = ΔQ + Q * copyltu!(R * ΔR' - ΔQ' * Q) + return adjoint!(ΔA, LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))) +end + +function lq_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S}, + L::AbstractTensorMap{S}, Q::AbstractTensorMap{S}, ΔL, ΔQ) where {S} + for (c, b) in blocks(ΔA) + lq_pullback!(b, block(t, c), block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c)) + end + return ΔA +end + +function lq_pullback!(ΔA, A, L::M, Q::M, ΔL, ΔQ) where {M<:AbstractMatrix} + m = qr_rank(L) + n = size(L, 1) + + if n == m # full rank + l = view(L, :, 1:m) + Δl = view(ΔL, :, 1:m) + q = view(Q, 1:m, :) + Δq = view(ΔQ, 1:m, :) + ΔA = lq_pullback_fullrank!(ΔA, l, q, Δl, Δq) + else + l = view(L, 1:m, 1:m) + Δl = view(ΔL, 1:m, 1:m) + q = view(Q, 1:m, :) + Δq = view(ΔQ, 1:m, :) + view(ΔL, (m + 1):n, 1:m)' * view(A, (m + 1):n, :) + + lq_pullback_fullrank!(view(ΔA, 1:m, :), l, q, Δl, Δq) + ΔA[(m + 1):n, :] = view(ΔL, (m + 1):n, :) * q + end + + return ΔA +end + +function lq_pullback_fullrank!(ΔA, L, Q, ΔL, ΔQ) + mul!(ΔA, copyltu!(L' * ΔL - ΔQ * Q'), Q) + axpy!(true, ΔQ, ΔA) + return LinearAlgebra.LAPACK.trtrs!('L', 'C', 'N', L, ΔA) +end + +function qr_rank(r::AbstractMatrix) + Base.require_one_based_indexing(r) + m, n = size(r) + r₀ = r[1, 1] + i = findfirst(x -> abs(x / r₀) < 1e-12, diag(r)) + return isnothing(i) ? 
min(m, n) : i - 1 +end + +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) + out = convert(Dict, t) + function convert_pullback(c) + if haskey(c, :data) # :data is the only thing for which this dual makes sense + dual = copy(out) + dual[:data] = c[:data] + return (NoTangent(), NoTangent(), convert(TensorMap, dual)) + else + # instead of zero(t) you can also return ZeroTangent(), which is type unstable + return (NoTangent(), NoTangent(), zero(t)) + end + end + return out, convert_pullback +end +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, + t::Dict{Symbol,Any}) + return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) +end + +end diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 0a0bdbeb..1d26db06 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -116,6 +116,8 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr, Diagonal, Hermitian import Base.Meta +using PackageExtensionCompat + # Auxiliary files #----------------- include("auxiliary/auxiliary.jl") @@ -202,4 +204,10 @@ include("planar/planaroperations.jl") # deprecations: to be removed in version 1.0 or sooner include("auxiliary/deprecate.jl") +# Extensions +# ---------- +function __init__() + @require_extensions +end + end diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 10a29a07..fb21289c 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -193,19 +193,19 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, codB, domB = codomainind(B), domainind(B) oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) - + if space(B, cindB[1]) != space(A, cindA[1])' || space(B, cindB[2]) != space(A, cindA[2])' throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($p1, $p2)")) end - + if BraidingStyle(sectortype(B)) isa Bosonic return add_permute!(C, B, (reverse(cindB), oindB), α, β, backend...) end τ_levels = A.adjoint ? (1, 2, 2, 1) : (2, 1, 1, 2) scale!(C, β) - + inv_braid = τ_levels[cindA[1]] > τ_levels[cindA[2]] for (f₁, f₂) in fusiontrees(B) local newtrees @@ -221,7 +221,8 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end end for ((f₁′, f₂′), coeff) in newtrees - TO.tensoradd!(C[f₁′, f₂′], (reverse(cindB), oindB), B[f₁, f₂], :N, α * coeff, One(), backend...) + TO.tensoradd!(C[f₁′, f₂′], (reverse(cindB), oindB), B[f₁, f₂], :N, α * coeff, + One(), backend...) end end return C @@ -251,7 +252,7 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, scale!(C, β) τ_levels = B.adjoint ? (1, 2, 2, 1) : (2, 1, 1, 2) inv_braid = τ_levels[cindB[1]] > τ_levels[cindB[2]] - + for (f₁, f₂) in fusiontrees(A) local newtrees for ((f₁′, f₂′), coeff′) in transpose(f₁, f₂, oindA, cindA) @@ -266,7 +267,8 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end end for ((f₁′, f₂′), coeff) in newtrees - TO.tensoradd!(C[f₁′, f₂′], (oindA, reverse(cindA)), A[f₁, f₂], :N, α * coeff, One(), backend...) + TO.tensoradd!(C[f₁′, f₂′], (oindA, reverse(cindA)), A[f₁, f₂], :N, α * coeff, + One(), backend...) 
end end return C diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index b7fa6cb7..9f8ed5a9 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -26,10 +26,10 @@ function _canonicalize(p::Index2Tuple{N₁,N₂}, ::AbstractTensorMap{<:IndexSpace,N₁,N₂}) where {N₁,N₂} return p end -function _canonicalize(p::Index2Tuple, t::AbstractTensorMap) - p′ = linearize(p) - p₁ = TupleTools.getindices(p′, codomainind(t)) - p₂ = TupleTools.getindices(p′, domainind(t)) +_canonicalize(p::Index2Tuple, t::AbstractTensorMap) = _canonicalize(linearize(p), t) +function _canonicalize(p::IndexTuple, t::AbstractTensorMap) + p₁ = TupleTools.getindices(p, codomainind(t)) + p₂ = TupleTools.getindices(p, domainind(t)) return (p₁, p₂) end @@ -42,7 +42,7 @@ function TO.tensoradd!(C::AbstractTensorMap{S}, pC::Index2Tuple, pC′ = _canonicalize(pC, C) elseif conjA == :C A′ = adjoint(A) - pC′ = adjointtensorindices(A, _canonicalize(pA, C)) + pC′ = adjointtensorindices(A, _canonicalize(pC, C)) else throw(ArgumentError("unknown conjugation flag $conjA")) end diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 00000000..3f87afca --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,253 @@ +using ChainRulesCore +using ChainRulesTestUtils +using Random +using FiniteDifferences +using LinearAlgebra + +## Test utility +# ------------- +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap) + return TensorMap(randn, scalartype(x), space(x)) +end +function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, + expected::AbstractTensorMap, msg=""; kwargs...) + for (c, b) in blocks(actual) + ChainRulesTestUtils.@test_msg msg isapprox(block(expected, c), b; kwargs...) + end +end +function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap} + vec, from_vec = to_vec(t.data) + return vec, x -> T(from_vec(x), codomain(t), domain(t)) +end +function FiniteDifferences.to_vec(t::AbstractTensorMap) + vec = mapreduce(vcat, blocks(t)) do (c, b) + if scalartype(t) <: Real + return reshape(b, :) .* sqrt(dim(c)) + else + v = reshape(b, :) .* sqrt(dim(c)) + return vcat(real(v), imag(v)) + end + end + + function from_vec(x) + t′ = similar(t) + T = scalartype(t) + ctr = 0 + for (c, b) in blocks(t′) + n = length(b) + if T <: Real + copyto!(b, reshape(x[(ctr + 1):(ctr + n)], size(b)) ./ sqrt(dim(c))) + else + v = x[(ctr + 1):(ctr + 2n)] + copyto!(b, + complex.(x[(ctr + 1):(ctr + n)], x[(ctr + n + 1):(ctr + 2n)]) ./ + sqrt(dim(c))) + end + ctr += T <: Real ? n : 2n + end + return t′ + end + + return vec, from_vec +end +FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t)) + +function _randomize!(a::TensorMap) + for b in values(blocks(a)) + copyto!(b, randn(size(b))) + end + return a +end + +# Float32 and finite differences don't mix well +precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2 +precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8 + +# rrules for functions that destroy inputs +# ---------------------------------------- +function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...) + return ChainRulesCore.rrule(tsvd!, args...) +end +function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...) + return ChainRulesCore.rrule(leftorth!, args...; kwargs...) +end +function ChainRulesCore.rrule(::typeof(TensorKit.rightorth), args...; kwargs...) + return ChainRulesCore.rrule(rightorth!, args...; kwargs...) +end + +# complex-valued svd? 
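# Why the complex case is excluded from the finite-difference SVD tests below: the
# SVD carries a gauge freedom, U -> U*Λ and V -> Λ'*V for any diagonal unitary Λ
# leave U*S*V unchanged, so finite differences cannot reproduce the particular
# gauge that `tsvd` returns. A small block-level illustration (Λ is an arbitrary
# phase matrix introduced here for the sketch; it relies on the packages already
# imported at the top of this file):
U, S, V, ϵ = tsvd(TensorMap(randn, ComplexF64, ℂ^2 ← ℂ^2))
u, s, v = block(U, Trivial()), block(S, Trivial()), block(V, Trivial())
Λ = Diagonal(exp.(im .* randn(2)))
@assert (u * Λ) * s * (Λ' * v) ≈ u * s * v   # the phases cancel against the real diagonal s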
+# ------------------- + +# function _gaugefix!(U, V) +# s = LinearAlgebra.Diagonal(TensorKit._safesign.(diag(U))) +# rmul!(U, s) +# lmul!(s', V) +# return U, V +# end + +# function _tsvd(t::AbstractTensorMap) +# U, S, V, ϵ = tsvd(t) +# for (c, b) in blocks(U) +# _gaugefix!(b, block(V, c)) +# end +# return U, S, V, ϵ +# end + +# svd_rev = Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt).svd_rev + +# function ChainRulesCore.rrule(::typeof(_tsvd), t::AbstractTensorMap) +# U, S, V, ϵ = _tsvd(t) +# function _tsvd_pullback((ΔU, ΔS, ΔV, Δϵ)) +# ∂t = similar(t) +# for (c, b) in blocks(∂t) +# copyto!(b, +# svd_rev(block(U, c), block(S, c), block(V, c), +# block(ΔU, c), block(ΔS, c), block(ΔV, c))) +# end +# return NoTangent(), ∂t +# end +# return (U, S, V, ϵ), _tsvd_pullback +# end + +# Tests +# ----- + +ChainRulesTestUtils.test_method_tables() + +Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + (ℂ[Z2Irrep](0 => 1, 1 => 1), + ℂ[Z2Irrep](0 => 1, 1 => 2)', + ℂ[Z2Irrep](0 => 3, 1 => 2)', + ℂ[Z2Irrep](0 => 2, 1 => 3), + ℂ[Z2Irrep](0 => 2, 1 => 2)), + (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), + ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 2)'), + (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), + ℂ[SU2Irrep](0 => 2, 1 => 1), + ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', + ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), + ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)')) + +@testset "Automatic Differentiation ($(eltype(V)))" verbose = true for V in Vlist + @testset "Basic Linear Algebra ($T)" for T in (Float64, ComplexF64) + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, space(A)) + + test_rrule(+, A, B) + test_rrule(-, A, B) + + α = randn(T) + test_rrule(*, α, A) + test_rrule(*, A, α) + + C = TensorMap(randn, T, domain(A), codomain(A)) + test_rrule(*, A, C) + + test_rrule(permute, A, ((1, 3, 2), (5, 4))) + end + + @testset "Linear Algebra part II ($T)" for T in (Float64, ComplexF64) + for i in 1:3 + E = TensorMap(randn, T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) + test_rrule(LinearAlgebra.tr, E) + end + + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + test_rrule(LinearAlgebra.adjoint, A) + test_rrule(LinearAlgebra.norm, A, 2) + end + + @testset "TensorOperations ($T)" for T in (Float64, ComplexF64) + atol = precision(T) + rtol = precision(T) + + @testset "tensortrace!" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[1] ⊗ V[5]) + pC = ((3, 5), (2,)) + pA = ((1,), (4,)) + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false)) + test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol) + + C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false)) + test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol) + end + + @testset "tensoradd!" begin + p = ((1, 3, 2), (5, 4)) + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) + α = randn(T) + β = randn(T) + test_rrule(tensoradd!, C, p, A, :N, α, β; atol, rtol) + + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false)) + test_rrule(tensoradd!, C, p, A, :C, α, β; atol, rtol) + end + + @testset "tensorcontract!" 
begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, V[3] ⊗ V[1]' ← V[2]) + pC = ((3, 2), (4, 1)) + pA = ((2, 4, 5), (1, 3)) + pB = ((2, 1), (3,)) + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, + B, pB, :N, false)) + test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β; atol, rtol) + + A2 = TensorMap(randn, T, V[1]' ⊗ V[2]' ← V[3]' ⊗ V[4]' ⊗ V[5]') + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, + B, pB, :N, false)) + test_rrule(tensorcontract!, C, pC, A2, pA, :C, B, pB, :N, α, β; atol, rtol) + + B2 = TensorMap(randn, T, V[3]' ⊗ V[1] ← V[2]') + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, + B2, pB, :C, false)) + test_rrule(tensorcontract!, C, pC, A, pA, :N, B2, pB, :C, α, β; atol, rtol) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, + B2, pB, :C, false)) + test_rrule(tensorcontract!, C, pC, A2, pA, :C, B2, pB, :C, α, β; atol, rtol) + end + + @testset "tensorscalar" begin + A = Tensor(randn, T, ProductSpace{typeof(V[1]),0}()) + test_rrule(tensorscalar, A) + end + end + + @testset "Factorizations ($T)" for T in (Float64, ComplexF64) + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, space(A)') + C = TensorMap(randn, T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + atol = 1e-6 + + for alg in (TensorKit.QR(), TensorKit.QRpos()) + test_rrule(leftorth, A; fkwargs=(; alg=alg), atol) + test_rrule(leftorth, B; fkwargs=(; alg=alg), atol) + test_rrule(leftorth, C; fkwargs=(; alg=alg), atol) + end + + for alg in (TensorKit.LQ(), TensorKit.LQpos()) + test_rrule(rightorth, A; fkwargs=(; alg=alg), atol) + test_rrule(rightorth, B; fkwargs=(; alg=alg), atol) + test_rrule(rightorth, C; fkwargs=(; alg=alg), atol) + end + + # Complex-valued SVD tests are incompatible with finite differencing, + # because U and V are not unique. + if T <: Real + test_rrule(tsvd, A; atol) + test_rrule(tsvd, B; atol) + test_rrule(tsvd, C; atol) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4738f065..92ceb7ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ include("fusiontrees.jl") include("spaces.jl") include("tensors.jl") include("planar.jl") +include("ad.jl") Tf = time() printstyled("Finished all tests in ", string(round((Tf - Ti) / 60; sigdigits=3)),