Skip to content

Commit

Permalink
include eig and eigh rrules
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 30, 2023
1 parent c1b5109 commit 6fb9d67
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 136 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ jobs:
- windows-latest
arch:
- x64
- x86
exclude:
- os: macOS-latest
arch: x86
# - x86
# exclude:
# - os: macOS-latest
# arch: x86
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
278 changes: 200 additions & 78 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTe
return dot(a, b), dot_pullback
end

function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p)
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
p == 2 || error("currently only implemented for p = 2")
n = norm(a, p)
norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent()
Expand Down Expand Up @@ -204,11 +204,114 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
end
return NoTangent(), Δt
end
tsvd!_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
function tsvd!_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
return NoTangent(), ZeroTangent()
end

return (U′, Σ′, V′, ϵ), tsvd!_pullback
end

# Reverse-mode rule for `TensorKit.eig!`.
# The primal values are obtained from `eig(t; kwargs...)`. The pullback builds a
# tensor `similar` to `t` and fills each of its blocks through `eig_pullback!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
    D, V = eig(t; kwargs...)

    function eig!_pullback((ΔD, ΔV))
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Dblock = block(D, sector)
            Vblock = block(V, sector)
            ΔDblock = block(ΔD, sector)
            ΔVblock = block(ΔV, sector)
            # the matrix-level pullback takes the eigenvalues as a vector (the diagonal)
            Dvals = view(Dblock, diagind(Dblock))
            ΔDvals = if ΔDblock isa AbstractZero
                ΔDblock
            else
                view(ΔDblock, diagind(ΔDblock))
            end
            eig_pullback!(Δblock, Dvals, Vblock, ΔDvals, ΔVblock)
        end
        return NoTangent(), Δt
    end
    # both cotangents are `ZeroTangent`: short-circuit without touching any block
    eig!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = (NoTangent(), ZeroTangent())

    return (D, V), eig!_pullback
end

# Reverse-mode rule for `TensorKit.eigh!`.
# The primal values are obtained from `eigh(t; kwargs...)`. The pullback builds a
# tensor `similar` to `t` and fills each of its blocks through `eigh_pullback!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...)
    D, V = eigh(t; kwargs...)

    function eigh!_pullback((ΔD, ΔV))
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Dblock = block(D, sector)
            Vblock = block(V, sector)
            ΔDblock = block(ΔD, sector)
            ΔVblock = block(ΔV, sector)
            # the matrix-level pullback takes the eigenvalues as a vector (the diagonal)
            Dvals = view(Dblock, diagind(Dblock))
            ΔDvals = if ΔDblock isa AbstractZero
                ΔDblock
            else
                view(ΔDblock, diagind(ΔDblock))
            end
            eigh_pullback!(Δblock, Dvals, Vblock, ΔDvals, ΔVblock)
        end
        return NoTangent(), Δt
    end
    # both cotangents are `ZeroTangent`: short-circuit without touching any block
    eigh!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = (NoTangent(), ZeroTangent())

    return (D, V), eigh!_pullback
end

# Reverse-mode rule for `leftorth!`, restricted to `alg=QR()` and `alg=QRpos()`.
# The primal values are obtained from `leftorth(t; alg)`; the pullback fills a
# tensor `similar` to `t` block by block through `qr_pullback!`.
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
    if !(alg isa TensorKit.QR || alg isa TensorKit.QRpos)
        error("only `alg=QR()` and `alg=QRpos()` are supported")
    end
    Q, R = leftorth(t; alg)

    function leftorth!_pullback((ΔQ, ΔR))
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Qblock, Rblock = block(Q, sector), block(R, sector)
            ΔQblock, ΔRblock = block(ΔQ, sector), block(ΔR, sector)
            qr_pullback!(Δblock, Qblock, Rblock, ΔQblock, ΔRblock)
        end
        return NoTangent(), Δt
    end
    # both cotangents are `ZeroTangent`: nothing to propagate
    function leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent()
    end

    return (Q, R), leftorth!_pullback
end

# Reverse-mode rule for `rightorth!`, restricted to `alg=LQ()` and `alg=LQpos()`.
# The primal values are obtained from `rightorth(t; alg)`; the pullback fills a
# tensor `similar` to `t` block by block through `lq_pullback!`.
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
    if !(alg isa TensorKit.LQ || alg isa TensorKit.LQpos)
        error("only `alg=LQ()` and `alg=LQpos()` are supported")
    end
    L, Q = rightorth(t; alg)

    function rightorth!_pullback((ΔL, ΔQ))
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Lblock, Qblock = block(L, sector), block(Q, sector)
            ΔLblock, ΔQblock = block(ΔL, sector), block(ΔQ, sector)
            lq_pullback!(Δblock, Lblock, Qblock, ΔLblock, ΔQblock)
        end
        return NoTangent(), Δt
    end
    # both cotangents are `ZeroTangent`: nothing to propagate
    function rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent()
    end

    return (L, Q), rightorth!_pullback
end

# Corresponding matrix factorisations: implemented as mutating methods
# ---------------------------------------------------------------------
# helper routines
# Guarded reciprocal: returns `zero(a)` when `abs(a)` is strictly below `tol`,
# and `inv(a)` otherwise.
function safe_inv(a, tol)
    if abs(a) < tol
        return zero(a)
    end
    return inv(a)
end

"""
    lowertriangularind(A::AbstractMatrix)

Return the linear (column-major) indices of the strictly lower-triangular entries of
`A`, i.e. all positions `(i, j)` with `i > j`, ordered column by column.

Works for square, wide (`m < n`) and tall (`m > n`) matrices; the result is empty when
there are no entries below the diagonal.
"""
function lowertriangularind(A::AbstractMatrix)
    m, n = size(A)
    # only the first `min(m, n)` columns can hold entries below the diagonal;
    # column `j` contributes `m - j` of them
    k = min(m, n)
    I = Vector{Int}(undef, k * m - div(k * (k + 1), 2))
    offset = 0
    for j in 1:k
        r = (j + 1):m
        I[offset .+ (1:length(r))] = (j - 1) * m .+ r
        offset += length(r)
    end
    return I
end

"""
    uppertriangularind(A::AbstractMatrix)

Return the linear (column-major) indices of the strictly upper-triangular entries of
`A`, i.e. all positions `(i, j)` with `i < j`, ordered row by row.

Works for square, wide (`m < n`) and tall (`m > n`) matrices; the result is empty when
there are no entries above the diagonal.
"""
function uppertriangularind(A::AbstractMatrix)
    m, n = size(A)
    # only the first `min(m, n)` rows can hold entries above the diagonal;
    # row `i` contributes `n - i` of them
    k = min(m, n)
    I = Vector{Int}(undef, k * n - div(k * (k + 1), 2))
    offset = 0
    for i in 1:k
        r = (i + 1):n
        I[offset .+ (1:length(r))] = i .+ m .* (r .- 1)
        offset += length(r)
    end
    return I
end

# SVD_pullback: pullback implementation for general (possibly truncated) SVD
#
# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
Expand All @@ -223,10 +326,10 @@ end
# Other implementation considerations for GPU compatibility:
# no scalar indexing, lots of broadcasting and views
#
# Guarded reciprocal: returns `zero(a)` when `abs(a)` is strictly below `tol`,
# and `inv(a)` otherwise.
function safe_inv(a, tol)
    if abs(a) < tol
        return zero(a)
    end
    return inv(a)
end
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector, Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4))
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(eltype(S))^(3 / 4))

# Basic size checks and determination
m, n = size(U, 1), size(Vd, 2)
Expand All @@ -245,15 +348,10 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
end
end
if !(ΔS isa AbstractZero)
if ΔS isa AbstractMatrix
ΔSr = real(diag(ΔS))
else # ΔS isa AbstractVector
ΔSr = real(ΔS)
end
if p == -1
p = length(ΔSr)
p = length(ΔS)
else
p == length(ΔSr) || throw(DimensionMismatch())
p == length(ΔS) || throw(DimensionMismatch())
end
end
Up = view(U, :, 1:p)
Expand Down Expand Up @@ -300,7 +398,8 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
(aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
if !(ΔS isa ZeroTangent)
UdΔAV[diagind(UdΔAV)] .+= ΔSr
UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
# in principle, ΔS is real, but maybe not if coming from an anyonic tensor
end
mul!(ΔA, Up, UdΔAV * Vp')

Expand Down Expand Up @@ -347,41 +446,87 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
return ΔA
end

function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only `alg=QR()` and `alg=QRpos()` are supported")
Q, R = leftorth(t; alg)
function leftorth!_pullback((ΔQ, ΔR))
Δt = similar(t)
for (c, b) in blocks(Δt)
qr_pullback!(b, block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c))
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))

# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())

# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)

if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV

mask = abs.(transpose(D) .- D) .< tol
gaugepart = view(VdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"

VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol))

if !(ΔD isa AbstractZero)
view(VdΔV, diagind(VdΔV)) .+= ΔD
end
PΔV = V' \ VdΔV
if eltype(ΔA) <: Real
ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory
ΔA .= real.(ΔAc)
else
mul!(ΔA, PΔV, V')
end
else
PΔV = V' \ Diagonal(ΔD)
if eltype(ΔA) <: Real
ΔAc = PΔV * V'
ΔA .= real.(ΔAc)
else
mul!(ΔA, PΔV, V')
end
return NoTangent(), Δt
end
leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
return (Q, R), leftorth!_pullback
return ΔA
end

function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only `alg=LQ()` and `alg=LQpos()` are supported")
L, Q = rightorth(t; alg)
function rightorth!_pullback((ΔL, ΔQ))
Δt = similar(t)
for (c, b) in blocks(Δt)
lq_pullback!(b, block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c))
function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))

# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())

# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)

if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

mask = abs.(D' .- D) .< tol
gaugepart = view(aVdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"

aVdΔV .*= safe_inv.(D' .- D, tol)

if !(ΔD isa AbstractZero)
view(aVdΔV, diagind(aVdΔV)) .+= real.(ΔD)
# in principle, ΔD is real, but maybe not if coming from an anyonic tensor
end
return NoTangent(), Δt
# recycle VdΔV space
mul!(ΔA, mul!(VdΔV, V, aVdΔV), V')
else
mul!(ΔA, V * Diagonal(ΔD), V')
end
rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
return (L, Q), rightorth!_pullback
return ΔA
end

function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))

atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))
Rd = view(R, diagind(R))
p = let tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
findlast(x->abs(x)>=tol, Rd)
findlast(x -> abs(x) >= tol, Rd)
end
m, n = size(R)

Expand All @@ -407,9 +552,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
mul!(ΔA1, Q1, M, +1, 1)

if n > p
R12 = view(R, 1:p, (p+1):n)
ΔA2 = view(ΔA, :, (p+1):n)
ΔR12 = view(ΔR, 1:p, (p+1):n)
R12 = view(R, 1:p, (p + 1):n)
ΔA2 = view(ΔA, :, (p + 1):n)
ΔR12 = view(ΔR, 1:p, (p + 1):n)

if ΔR isa AbstractZero
ΔA2 .= zero(eltype(ΔA))
Expand All @@ -419,9 +564,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
end
end
if m > p && !(ΔQ isa AbstractZero) # case where R is not full rank
Q2 = view(Q, :, (p+1):m)
ΔQ2 = view(ΔQ, :, (p+1):m)
Q1dΔQ2 = Q1'*ΔQ2
Q2 = view(Q, :, (p + 1):m)
ΔQ2 = view(ΔQ, :, (p + 1):m)
Q1dΔQ2 = Q1' * ΔQ2
gaugepart = mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
mul!(ΔA1, Q2, Q1dΔQ2', -1, 1)
Expand All @@ -431,12 +576,11 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
end

function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))

atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))
Ld = view(L, diagind(L))
p = let tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
findlast(x->abs(x)>=tol, Ld)
findlast(x -> abs(x) >= tol, Ld)
end
m, n = size(L)

Expand All @@ -462,9 +606,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
mul!(ΔA1, M, Q1, +1, 1)

if m > p
L21 = view(L, (p+1):m, 1:p)
ΔA2 = view(ΔA, (p+1):m, :)
ΔL21 = view(ΔL, (p+1):m, 1:p)
L21 = view(L, (p + 1):m, 1:p)
ΔA2 = view(ΔA, (p + 1):m, :)
ΔL21 = view(ΔL, (p + 1):m, 1:p)

if ΔL isa AbstractZero
ΔA2 .= zero(eltype(ΔA))
Expand All @@ -474,9 +618,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
end
end
if n > p && !(ΔQ isa AbstractZero) # case where R is not full rank
Q2 = view(Q, (p+1):n, :)
ΔQ2 = view(ΔQ, (p+1):n, :)
ΔQ2Q1d = ΔQ2*Q1'
Q2 = view(Q, (p + 1):n, :)
ΔQ2 = view(ΔQ, (p + 1):n, :)
ΔQ2Q1d = ΔQ2 * Q1'
gaugepart = mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1)
Expand All @@ -485,30 +629,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
return ΔA
end

"""
    lowertriangularind(A::AbstractMatrix)

Return the linear (column-major) indices of the strictly lower-triangular entries of
`A`, i.e. all positions `(i, j)` with `i > j`, ordered column by column.

Works for square, wide (`m < n`) and tall (`m > n`) matrices; the result is empty when
there are no entries below the diagonal.
"""
function lowertriangularind(A::AbstractMatrix)
    m, n = size(A)
    # only the first `min(m, n)` columns can hold entries below the diagonal;
    # column `j` contributes `m - j` of them
    k = min(m, n)
    I = Vector{Int}(undef, k * m - div(k * (k + 1), 2))
    offset = 0
    for j in 1:k
        r = (j + 1):m
        I[offset .+ (1:length(r))] = (j - 1) * m .+ r
        offset += length(r)
    end
    return I
end
"""
    uppertriangularind(A::AbstractMatrix)

Return the linear (column-major) indices of the strictly upper-triangular entries of
`A`, i.e. all positions `(i, j)` with `i < j`, ordered row by row.

Works for square, wide (`m < n`) and tall (`m > n`) matrices; the result is empty when
there are no entries above the diagonal.
"""
function uppertriangularind(A::AbstractMatrix)
    m, n = size(A)
    # only the first `min(m, n)` rows can hold entries above the diagonal;
    # row `i` contributes `n - i` of them
    k = min(m, n)
    I = Vector{Int}(undef, k * n - div(k * (k + 1), 2))
    offset = 0
    for i in 1:k
        r = (i + 1):n
        I[offset .+ (1:length(r))] = i .+ m .* (r .- 1)
        offset += length(r)
    end
    return I
end


# Convert rrules
#----------------
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
out = convert(Dict, t)
function convert_pullback(c)
Expand Down
Loading

0 comments on commit 6fb9d67

Please sign in to comment.