From fcad6a99c7bf4870d4d4a92dddcd55bc81569ecb Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 27 Jun 2024 15:35:24 +0200 Subject: [PATCH] Refactor ChainRulesCoreExt into separate files --- .../TensorKitChainRulesCoreExt.jl | 28 ++ .../constructors.jl | 47 +++ .../factorizations.jl} | 353 +----------------- ext/TensorKitChainRulesCoreExt/linalg.jl | 92 +++++ .../tensoroperations.jl | 145 +++++++ ext/TensorKitChainRulesCoreExt/utility.jl | 29 ++ 6 files changed, 343 insertions(+), 351 deletions(-) create mode 100644 ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl create mode 100644 ext/TensorKitChainRulesCoreExt/constructors.jl rename ext/{TensorKitChainRulesCoreExt.jl => TensorKitChainRulesCoreExt/factorizations.jl} (54%) create mode 100644 ext/TensorKitChainRulesCoreExt/linalg.jl create mode 100644 ext/TensorKitChainRulesCoreExt/tensoroperations.jl create mode 100644 ext/TensorKitChainRulesCoreExt/utility.jl diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl new file mode 100644 index 00000000..da202178 --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -0,0 +1,28 @@ +module TensorKitChainRulesCoreExt + +using TensorOperations +using VectorInterface +using TensorKit +using ChainRulesCore +using LinearAlgebra +using TupleTools + +import TensorOperations as TO +using TensorOperations: Backend, promote_contract +using VectorInterface: promote_scale, promote_add + +ext = @static if isdefined(Base, :get_extension) + Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) +else + TensorOperations.TensorOperationsChainRulesCoreExt +end +const _conj = ext._conj +const trivtuple = ext.trivtuple + +include("utility.jl") +include("constructors.jl") +include("linalg.jl") +include("tensoroperations.jl") +include("factorizations.jl") + +end diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl new file mode 100644 index 00000000..149c54c2 --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -0,0 +1,47 @@ +@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) +@non_differentiable TensorKit.id(args...) +@non_differentiable TensorKit.isomorphism(args...) +@non_differentiable TensorKit.isometry(args...) +@non_differentiable TensorKit.unitary(args...) + +function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...) + function TensorMap_pullback(Δt) + ∂d = convert(Array, Δt) + return NoTangent(), ∂d, fill(NoTangent(), length(args))... + end + return TensorMap(d, args...), TensorMap_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) + copy_pullback(Δt) = NoTangent(), Δt + return copy(t), copy_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array}, + t::AbstractTensorMap) + A = convert(T, t) + function convert_pullback(ΔA) + ∂t = TensorMap(ΔA, codomain(t), domain(t)) + return NoTangent(), NoTangent(), ∂t + end + return A, convert_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) + out = convert(Dict, t) + function convert_pullback(c) + if haskey(c, :data) # :data is the only thing for which this dual makes sense + dual = copy(out) + dual[:data] = c[:data] + return (NoTangent(), NoTangent(), convert(TensorMap, dual)) + else + # instead of zero(t) you can also return ZeroTangent(), which is type unstable + return (NoTangent(), NoTangent(), zero(t)) + end + end + return out, convert_pullback +end +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, + t::Dict{Symbol,Any}) + return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) +end diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl similarity index 54% rename from ext/TensorKitChainRulesCoreExt.jl rename to ext/TensorKitChainRulesCoreExt/factorizations.jl index 16b51df6..b254b9a8 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,185 +1,5 @@ -module TensorKitChainRulesCoreExt - -using TensorOperations -using VectorInterface -using TensorKit -using ChainRulesCore -using LinearAlgebra -using TupleTools - -import TensorOperations as TO -using TensorOperations: Backend, promote_contract -using VectorInterface: promote_scale, promote_add - -ext = @static if isdefined(Base, :get_extension) - Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) -else - TensorOperations.TensorOperationsChainRulesCoreExt -end -const _conj = ext._conj -const trivtuple = ext.trivtuple - -# Utility -# ------- - -function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return p[1:N₁], p[(N₁ + 1):end] -end -_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) -function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple,Index2Tuple}, - ::AbstractTensorMap{<:Any,N₁}) where {N₁} - return _repartition(p, N₁) -end - -TensorKit.block(t::ZeroTangent, c::Sector) = t - -# Constructors -# ------------ - -@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) -@non_differentiable TensorKit.id(args...) -@non_differentiable TensorKit.isomorphism(args...) -@non_differentiable TensorKit.isometry(args...) -@non_differentiable TensorKit.unitary(args...) - -function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...) - function TensorMap_pullback(Δt) - ∂d = convert(Array, Δt) - return NoTangent(), ∂d, fill(NoTangent(), length(args))... - end - return TensorMap(d, args...), TensorMap_pullback -end - -function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap) - A = convert(T, t) - function convert_pullback(ΔA) - ∂t = TensorMap(ΔA, codomain(t), domain(t)) - return NoTangent(), NoTangent(), ∂t - end - return A, convert_pullback -end - -function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) - copy_pullback(Δt) = NoTangent(), Δt - return copy(t), copy_pullback -end - -ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() -function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{S,N1,N2}, - T2<:AbstractTensorMap{S,N1,N2}} - T1 === T2 && return x - y = similar(x, scalartype(T1)) - for (c, b) in blocks(y) - p = ProjectTo(b) - b .= p(block(x, c)) - end - return y -end - -# Base Linear Algebra -# ------------------- - -function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) - plus_pullback(Δc) = NoTangent(), Δc, Δc - return a + b, plus_pullback -end - -ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap) = -a, Δc -> (NoTangent(), -Δc) -function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) - minus_pullback(Δc) = NoTangent(), Δc, -Δc - return a - b, minus_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) - times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) - times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) - times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap) - C = A ⊗ B - projectA = ProjectTo(A) - projectB = ProjectTo(B) - function otimes_pullback(ΔC_) - # TODO: this rule is probably better written in terms of inner products, - # using planarcontract and adjoint tensormaps would remove the twists. - ΔC = unthunk(ΔC_) - pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), - ((codomainind(B) .+ numout(A))..., - (domainind(B) .+ (numin(A) + numout(A)))...)) - dA_ = @thunk begin - ipA = (codomainind(A), domainind(A)) - pB = (allind(B), ()) - dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) - dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C) - return projectA(dA) - end - dB_ = @thunk begin - ipB = (codomainind(B), domainind(B)) - pA = ((), allind(A)) - dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) - dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N) - return projectB(dB) - end - return NoTangent(), dA_, dB_ - end - return C, otimes_pullback -end - -function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; - copy::Bool=false) - function permute_pullback(Δtdst) - invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) - return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() - end - return permute(tsrc, p; copy=true), permute_pullback -end - -# LinearAlgebra -# ------------- - -function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) - return tr(A), tr_pullback -end - -function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) - adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) - return adjoint(A), adjoint_pullback -end - -function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) - dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd) - return dot(a, b), dot_pullback -end - -function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) - p == 2 || error("currently only implemented for p = 2") - n = norm(a, p) - function norm_pullback(Δn) - return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() - end - return n, norm_pullback -end - -# Factorizations -# -------------- +# Factorizations rules +# -------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(), p::Real=2, @@ -669,172 +489,3 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ldiv!(LowerTriangular(L11)', ΔA1) return ΔA end - -# Convert rrules -#---------------- -function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) - out = convert(Dict, t) - function convert_pullback(c) - if haskey(c, :data) # :data is the only thing for which this dual makes sense - dual = copy(out) - dual[:data] = c[:data] - return (NoTangent(), NoTangent(), convert(TensorMap, dual)) - else - # instead of zero(t) you can also return ZeroTangent(), which is type unstable - return (NoTangent(), NoTangent(), zero(t)) - end - end - return out, convert_pullback -end -function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, - t::Dict{Symbol,Any}) - return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) -end - -function ChainRulesCore.rrule(::typeof(TO.tensorcontract!), - C::AbstractTensorMap{S}, pC::Index2Tuple, - A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, - B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol, - α::Number, β::Number, - backend::Backend...) where {S} - C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...) - - projectA = ProjectTo(A) - projectB = ProjectTo(B) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - ipC = invperm(linearize(pC)) - pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))), - TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) - - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipA = (invperm(linearize(pA)), ()) - conjΔC = conjA == :C ? :C : :N - conjB′ = conjA == :C ? conjB : _conj(conjB) - _dA = zerovector(A, - promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) - tB = twist(B, - TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), - filter(x -> isdual(space(B, x)), pB[2]))) - _dA = tensorcontract!(_dA, ipA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, - conjA == :C ? α : conj(α), Zero(), backend...) - return projectA(_dA) - end - dB = @thunk begin - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjB == :C ? :C : :N - conjA′ = conjB == :C ? conjA : _conj(conjA) - _dB = zerovector(B, - promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) - tA = twist(A, - TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), - filter(x -> !isdual(space(A, x)), pA[2]))) - _dB = tensorcontract!(_dB, ipB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - conjB == :C ? α : conj(α), Zero(), backend...) - return projectB(_dB) - end - dα = @thunk begin - # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB - AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB) - return projectα(inner(AB, ΔC)) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), - dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ, - dbackend... - end - return C′, pullback -end - -function ChainRulesCore.rrule(::typeof(TO.tensoradd!), - C::AbstractTensorMap{S}, pC::Index2Tuple, - A::AbstractTensorMap{S}, conjA::Symbol, - α::Number, β::Number, backend::Backend...) where {S} - C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...) - - projectA = ProjectTo(A) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipC = invperm(linearize(pC)) - _dA = zerovector(A, promote_add(ΔC, α)) - _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(), - backend...) - return projectA(_dA) - end - dα = @thunk begin - # TODO: this is an inner product implemented as a contraction - # for non-symmetric tensors this might be more efficient like this, - # but for symmetric tensors an intermediate object will anyways be created - # and then it might be more efficient to use an addition and inner product - tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) - _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)), - _conj(conjA), tΔC, - (trivtuple(TO.numind(pC)), - ()), :N, One(), backend...)) - return projectα(_dα) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend... - end - - return C′, pullback -end - -function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, - pC::Index2Tuple, A::AbstractTensorMap{S}, - pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, - backend::Backend...) where {S} - C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...) - - projectA = ProjectTo(A) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...)) - E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA)) - twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - _dA = zerovector(A, promote_scale(ΔC, α)) - _dA = tensorproduct!(_dA, (ipC, ()), ΔC, - (trivtuple(TO.numind(pC)), ()), conjA, E, - ((), trivtuple(TO.numind(pA))), conjA, - conjA == :N ? conj(α) : α, Zero(), backend...) - return projectA(_dA) - end - dα = @thunk begin - # TODO: this result might be easier to compute as: - # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α - At = tensortrace(pC, A, pA, conjA) - return projectα(inner(At, ΔC)) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ, - dbackend... - end - - return C′, pullback -end - -end diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl new file mode 100644 index 00000000..a7bd64fa --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -0,0 +1,92 @@ +# Linear Algebra chainrules +# ------------------------- +function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) + plus_pullback(Δc) = NoTangent(), Δc, Δc + return a + b, plus_pullback +end + +ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap) = -a, Δc -> (NoTangent(), -Δc) +function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) + minus_pullback(Δc) = NoTangent(), Δc, -Δc + return a - b, minus_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap) + C = A ⊗ B + projectA = ProjectTo(A) + projectB = ProjectTo(B) + function otimes_pullback(ΔC_) + # TODO: this rule is probably better written in terms of inner products, + # using planarcontract and adjoint tensormaps would remove the twists. + ΔC = unthunk(ΔC_) + pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), + ((codomainind(B) .+ numout(A))..., + (domainind(B) .+ (numin(A) + numout(A)))...)) + dA_ = @thunk begin + ipA = (codomainind(A), domainind(A)) + pB = (allind(B), ()) + dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) + tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) + dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C) + return projectA(dA) + end + dB_ = @thunk begin + ipB = (codomainind(B), domainind(B)) + pA = ((), allind(A)) + dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) + tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) + dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N) + return projectB(dB) + end + return NoTangent(), dA_, dB_ + end + return C, otimes_pullback +end + +function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; + copy::Bool=false) + function permute_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() + end + return permute(tsrc, p; copy=true), permute_pullback +end + +function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) + tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) + return tr(A), tr_pullback +end + +function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) + adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) + return adjoint(A), adjoint_pullback +end + +function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) + dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd) + return dot(a, b), dot_pullback +end + +function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) + p == 2 || error("currently only implemented for p = 2") + n = norm(a, p) + function norm_pullback(Δn) + return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() + end + return n, norm_pullback +end diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl new file mode 100644 index 00000000..72c25a5a --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -0,0 +1,145 @@ +function ChainRulesCore.rrule(::typeof(TO.tensorcontract!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, + B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol, + α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...) + + projectA = ProjectTo(A) + projectB = ProjectTo(B) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + ipC = invperm(linearize(pC)) + pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))), + TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) + + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipA = (invperm(linearize(pA)), ()) + conjΔC = conjA == :C ? :C : :N + conjB′ = conjA == :C ? conjB : _conj(conjB) + _dA = zerovector(A, + promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + tB = twist(B, + TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]))) + _dA = tensorcontract!(_dA, ipA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + conjA == :C ? α : conj(α), Zero(), backend...) + return projectA(_dA) + end + dB = @thunk begin + ipB = (invperm(linearize(pB)), ()) + conjΔC = conjB == :C ? :C : :N + conjA′ = conjB == :C ? conjA : _conj(conjA) + _dB = zerovector(B, + promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + tA = twist(A, + TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]))) + _dB = tensorcontract!(_dB, ipB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + conjB == :C ? α : conj(α), Zero(), backend...) + return projectB(_dB) + end + dα = @thunk begin + # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB + AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB) + return projectα(inner(AB, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), + dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(TO.tensoradd!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, conjA::Symbol, + α::Number, β::Number, backend::Backend...) where {S} + C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm(linearize(pC)) + _dA = zerovector(A, promote_add(ΔC, α)) + _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(), + backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this is an inner product implemented as a contraction + # for non-symmetric tensors this might be more efficient like this, + # but for symmetric tensors an intermediate object will anyways be created + # and then it might be more efficient to use an addition and inner product + tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)), + _conj(conjA), tΔC, + (trivtuple(TO.numind(pC)), + ()), :N, One(), backend...)) + return projectα(_dα) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend... + end + + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, + pC::Index2Tuple, A::AbstractTensorMap{S}, + pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...)) + E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + _dA = zerovector(A, promote_scale(ΔC, α)) + _dA = tensorproduct!(_dA, (ipC, ()), ΔC, + (trivtuple(TO.numind(pC)), ()), conjA, E, + ((), trivtuple(TO.numind(pA))), conjA, + conjA == :N ? conj(α) : α, Zero(), backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = tensortrace(pC, A, pA, conjA) + return projectα(inner(At, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + + return C′, pullback +end diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl new file mode 100644 index 00000000..170ed09f --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -0,0 +1,29 @@ +# Utility +# ------- +function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return p[1:N₁], p[(N₁ + 1):end] +end +_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) +function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple,Index2Tuple}, + ::AbstractTensorMap{<:Any,N₁}) where {N₁} + return _repartition(p, N₁) +end + +TensorKit.block(t::ZeroTangent, c::Sector) = t + +ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() +function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{S,N1,N2}, + T2<:AbstractTensorMap{S,N1,N2}} + T1 === T2 && return x + y = similar(x, scalartype(T1)) + for (c, b) in blocks(y) + p = ProjectTo(b) + b .= p(block(x, c)) + end + return y +end