Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rrules for Fermionic symmetries #126

Merged
merged 23 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 165 additions & 11 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

ext = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
TensorOperations.TensorOperationsChainRulesCoreExt
end
const _conj = ext._conj
const trivtuple = ext.trivtuple

# Utility
# -------

_conj(conjA::Symbol) = conjA == :C ? :N : :C
trivtuple(N) = ntuple(identity, N)

function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
Expand Down Expand Up @@ -111,19 +121,17 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
dA_ = @thunk begin
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C)
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
return projectA(dA)
end
dB_ = @thunk begin
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(A)))
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)))
dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N)
return projectB(dB)
end
return NoTangent(), dA_, dB_
Expand Down Expand Up @@ -653,4 +661,150 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))
end

function ChainRulesCore.rrule(::typeof(TO.tensorcontract!),
C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number,
backend::Backend...) where {S}
C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...)

projectA = ProjectTo(A)
projectB = ProjectTo(B)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipC = invperm(linearize(pC))
pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))),
TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB))))

dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
_dA = zerovector(A,
promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)))
tB = twist(B,
TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])))
_dA = tensorcontract!(_dA, ipA,
ΔC, pΔC, conjΔC,
tB, reverse(pB), conjB′,
conjA == :C ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dB = @thunk begin
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
_dB = zerovector(B,
promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)))
tA = twist(A,
TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])))
_dB = tensorcontract!(_dB, ipB,
tA, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
conjB == :C ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
dα = @thunk begin
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB)
return projectα(inner(AB, ΔC))
end
dβ = @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end
return C′, pullback
end

function ChainRulesCore.rrule(::typeof(TO.tensoradd!),
C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, conjA::Symbol,
α::Number, β::Number, backend::Backend...) where {S}
C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm(linearize(pC))
_dA = zerovector(A, promote_add(ΔC, α))
_dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(),
backend...)
return projectA(_dA)
end
dα = @thunk begin
# TODO: this is an inner product implemented as a contraction
# for non-symmetric tensors this might be more efficient like this,
# but for symmetric tensors an intermediate object will anyways be created
# and then it might be more efficient to use an addition and inner product
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
_dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
_conj(conjA), tΔC,
(trivtuple(TO.numind(pC)),
()), :N, One(), backend...))
return projectα(_dα)
end
dβ = @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend...
end

return C′, pullback
end

function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S},
pC::Index2Tuple, A::AbstractTensorMap{S},
pA::Index2Tuple, conjA::Symbol, α::Number, β::Number,
backend::Backend...) where {S}
C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...))
E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
_dA = zerovector(A, promote_scale(ΔC, α))
_dA = tensorproduct!(_dA, (ipC, ()), ΔC,
(trivtuple(TO.numind(pC)), ()), conjA, E,
((), trivtuple(TO.numind(pA))), conjA,
conjA == :N ? conj(α) : α, Zero(), backend...)
return projectA(_dA)
end
dα = @thunk begin
# TODO: this result might be easier to compute as:
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
At = tensortrace(pC, A, pA, conjA)
return projectα(inner(At, ΔC))
end
dβ = @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, pullback
end

end
4 changes: 2 additions & 2 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@

# Show
#------
function Base.summary(t::AdjointTensorMap)
return print("AdjointTensorMap(", codomain(t), " ← ", domain(t), ")")
function Base.summary(io::IO, t::AdjointTensorMap)
return print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")")

Check warning on line 91 in src/tensors/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/adjoint.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
end
function Base.show(io::IO, t::AdjointTensorMap{S}) where {S<:IndexSpace}
if get(io, :compact, false)
Expand Down
26 changes: 20 additions & 6 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@

# Twist
"""
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false)
-> t
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t
twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t

Apply a twist to the `i`th index of `t`, storing the result in `t`.
If `inv=true`, use the inverse twist.
Expand All @@ -248,17 +248,31 @@
end
return t
end
function twist!(t::AbstractTensorMap, is; inv::Bool=false)
if !all(in(allind(t)), is)
msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices."
throw(ArgumentError(msg))

Check warning on line 254 in src/tensors/indexmanipulations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/indexmanipulations.jl#L253-L254

Added lines #L253 - L254 were not covered by tests
end
(BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t
N₁ = numout(t)
for (f₁, f₂) in fusiontrees(t)
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is)
inv && (θ = θ')
rmul!(t[f₁, f₂], θ)
end
return t
end

"""
twist(t::AbstractTensorMap, i::Int; inv::Bool=false)
-> t
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst
twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst

Apply a twist to the `i`th index of `t` and return the result as a new tensor.
Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to mention that the doc string needed an update to reflect the is argument here.

If `inv=true`, use the inverse twist.

See [`twist!`](@ref) for storing the result in place.
"""
twist(t::AbstractTensorMap, i::Int; inv::Bool=false) = twist!(copy(t), i; inv=inv)
twist(t::AbstractTensorMap, i; inv::Bool=false) = twist!(copy(t), i; inv)

# Fusing and splitting
# TODO: add functionality for easy fusing and splitting of tensor indices
Expand Down
4 changes: 2 additions & 2 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,8 @@ end

# Show
#------
function Base.summary(t::TensorMap)
return print("TensorMap(", space(t), ")")
function Base.summary(io::IO, t::TensorMap)
return print(io, "TensorMap(", space(t), ")")
end
function Base.show(io::IO, t::TensorMap{S}) where {S<:IndexSpace}
if get(io, :compact, false)
Expand Down
8 changes: 1 addition & 7 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,7 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
end
A′ = permute(A, (oindA, cindA); copy=copyA)
B′ = permute(B, (cindB, oindB))
if BraidingStyle(sectortype(S)) isa Fermionic
for i in domainind(A′)
if !isdual(space(A′, i))
A′ = twist!(A′, i)
end
end
end
A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′)))
ipC = TupleTools.invperm((p₁..., p₂...))
oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁))
oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂))
Expand Down
Loading
Loading