-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automatic Differentiation #82
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
5f2c5bd
Setup ChainRules package extension
lkdvos 5657eee
Port over some methods from TensorKitAD
lkdvos 7ad26b3
Updates for leftorth and rightorth rrules
lkdvos 83454a4
Formatting
lkdvos ad7b341
Fix missing using PackageExtensionCompat
lkdvos 04a56dd
Add some missing `rrule`s.
leburgel 4cb4a73
little bit of cleanup
lkdvos 401a1d1
repartition
leburgel 868d1df
Remove overloaded `rrule`s in favor of TensorOperations update
leburgel f2b6ff0
Update VectorInterface 0.4
lkdvos 25c06e5
Various AD bugfixes
lkdvos 1c1d870
some AD tests amd updates
lkdvos a917787
Update AD rules, clean up tests
lkdvos 2eadc1e
Merge branch 'master' into ad
lkdvos 7cddde3
Include qdim in vectors
lkdvos ecfe4c4
Formatter
lkdvos 89ba53f
Merge branch 'master' into ad
lkdvos de99774
remove `thunk` in trace
lkdvos cfe42df
Merge branch 'master' into ad
lkdvos 1fe4955
Formatter
lkdvos 3139ae7
Add rrule for efficient copy constructor
lkdvos 93345d6
Fix non-existent argument name
lkdvos 28f13d0
Add ProjectTo functionality
lkdvos 67a4453
Add tensoroperations tests
lkdvos 31b6523
Changes for TensorOperations 4.0.6
lkdvos 59832c3
remove obsolete test
lkdvos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,340 @@ | ||
module TensorKitChainRulesCoreExt | ||
|
||
using TensorOperations | ||
using TensorKit | ||
using ChainRulesCore | ||
using LinearAlgebra | ||
using TupleTools | ||
|
||
# Utility | ||
# ------- | ||
|
||
_conj(conjA::Symbol) = conjA == :C ? :N : :C | ||
trivtuple(N) = ntuple(identity, N) | ||
|
||
function _repartition(p::IndexTuple, N₁::Int) | ||
length(p) >= N₁ || | ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) | ||
return p[1:N₁], p[(N₁ + 1):end] | ||
end | ||
_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) | ||
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} | ||
return _repartition(p, N₁) | ||
end | ||
function _repartition(p::Union{IndexTuple,Index2Tuple}, | ||
::AbstractTensorMap{<:Any,N₁}) where {N₁} | ||
return _repartition(p, N₁) | ||
end | ||
|
||
TensorKit.block(t::ZeroTangent, c::Sector) = t | ||
|
||
# Constructors | ||
# ------------ | ||
|
||
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) | ||
lkdvos marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@non_differentiable TensorKit.isomorphism(args...) | ||
@non_differentiable TensorKit.isometry(args...) | ||
@non_differentiable TensorKit.unitary(args...) | ||
|
||
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...) | ||
function TensorMap_pullback(Δt) | ||
∂d = convert(Array, Δt) | ||
return NoTangent(), ∂d, fill(NoTangent(), length(args))... | ||
end | ||
return TensorMap(d, args...), TensorMap_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap) | ||
A = convert(T, t) | ||
function convert_pullback(ΔA) | ||
∂t = TensorMap(ΔA, codomain(t), domain(t)) | ||
return NoTangent(), NoTangent(), ∂t | ||
end | ||
return A, convert_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) | ||
copy_pullback(Δt) = NoTangent(), Δt | ||
return copy(t), copy_pullback | ||
end | ||
|
||
ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() | ||
(::ProjectTo{T})(x::AbstractTensorMap) where {T<:AbstractTensorMap} = convert(T, x) | ||
|
||
# Base Linear Algebra | ||
# ------------------- | ||
|
||
function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) | ||
plus_pullback(Δc) = NoTangent(), Δc, Δc | ||
return a + b, plus_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) | ||
minus_pullback(Δc) = NoTangent(), Δc, -Δc | ||
return a - b, minus_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) | ||
times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) | ||
return a * b, times_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) | ||
times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) | ||
return a * b, times_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) | ||
times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) | ||
return a * b, times_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple) | ||
function permute_pullback(Δtdst) | ||
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) | ||
return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent() | ||
end | ||
return permute(tsrc, p), permute_pullback | ||
end | ||
|
||
# LinearAlgebra | ||
# ------------- | ||
|
||
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) | ||
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) | ||
return tr(A), tr_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) | ||
adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) | ||
return adjoint(A), adjoint_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) | ||
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd) | ||
return dot(a, b), dot_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p) | ||
p == 2 || error("currently only implemented for p = 2") | ||
n = norm(a, p) | ||
norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent() | ||
return n, norm_pullback | ||
end | ||
|
||
# Factorizations | ||
# -------------- | ||
|
||
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...) | ||
U, S, V, ϵ = tsvd(t; kwargs...) | ||
|
||
function tsvd!_pullback((ΔU, ΔS, ΔV, Δϵ)) | ||
∂t = similar(t) | ||
for (c, b) in blocks(∂t) | ||
copyto!(b, | ||
svd_rev(block(U, c), block(S, c), block(V, c), | ||
block(ΔU, c), block(ΔS, c), block(ΔV, c))) | ||
end | ||
|
||
return NoTangent(), ∂t | ||
end | ||
|
||
return (U, S, V, ϵ), tsvd!_pullback | ||
end | ||
|
||
""" | ||
svd_rev(U, S, V, ΔU, ΔS, ΔV; tol=eps(real(scalartype(Σ)))^(4 / 5)) | ||
|
||
Implements the following back propagation formula for the SVD: | ||
|
||
```math | ||
ΔA = UΔSV' + U(J + J')SV' + US(K + K')V' + \\frac{1}{2}US^{-1}(L' - L)V'\\ | ||
J = F ∘ (U'ΔU), \\qquad K = F ∘ (V'ΔV), \\qquad L = I ∘ (V'ΔV)\\ | ||
F_{i ≠ j} = \\frac{1}{s_j^2 - s_i^2}\\ | ||
F_{ii} = 0 | ||
``` | ||
|
||
# References | ||
|
||
Wan, Zhou-Quan, and Shi-Xin Zhang. 2019. “Automatic Differentiation for Complex Valued SVD.” https://doi.org/10.48550/ARXIV.1909.02659. | ||
""" | ||
function svd_rev(U::AbstractMatrix, S::AbstractMatrix, V::AbstractMatrix, ΔU, ΔS, ΔV; | ||
atol::Real=0, | ||
rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4)) | ||
# project out gauge invariance dependence? | ||
# ΔU * U + ΔV * V' = 0 | ||
|
||
tol = atol > 0 ? atol : rtol * S[1, 1] | ||
F = _invert_S²(S, tol) | ||
S⁻¹ = pinv(S; atol=tol) | ||
|
||
term = Diagonal(diag(ΔS)) | ||
|
||
J = F .* (U' * ΔU) | ||
term += (J + J') * S | ||
VΔV = (V * ΔV') | ||
K = F .* VΔV | ||
term += S * (K + K') | ||
|
||
if scalartype(U) <: Complex && !(ΔV isa ZeroTangent) && !(ΔU isa ZeroTangent) | ||
L = LinearAlgebra.Diagonal(diag(VΔV)) | ||
term += 0.5 * S⁻¹ * (L' - L) | ||
end | ||
|
||
ΔA = U * term * V | ||
|
||
if size(U, 1) != size(V, 2) | ||
UUd = U * U' | ||
VdV = V' * V | ||
ΔA += (one(UUd) - UUd) * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * (one(VdV) - VdV) | ||
end | ||
|
||
return ΔA | ||
end | ||
|
||
function _invert_S²(S::AbstractMatrix{T}, tol::Real) where {T<:Real} | ||
F = similar(S) | ||
@inbounds for i in axes(F, 1), j in axes(F, 2) | ||
F[i, j] = if i == j | ||
zero(T) | ||
else | ||
sᵢ, sⱼ = S[i, i], S[j, j] | ||
1 / (abs(sⱼ - sᵢ) < tol ? tol : sⱼ^2 - sᵢ^2) | ||
end | ||
end | ||
return F | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) | ||
alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only QR and QRpos supported") | ||
Q, R = leftorth(t; alg) | ||
leftorth!_pullback((ΔQ, ΔR)) = NoTangent(), qr_pullback!(similar(t), t, Q, R, ΔQ, ΔR) | ||
leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = ZeroTangent() | ||
return (Q, R), leftorth!_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) | ||
alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only LQ and LQpos supported") | ||
L, Q = rightorth(t; alg) | ||
rightorth!_pullback((ΔL, ΔQ)) = NoTangent(), lq_pullback!(similar(t), t, L, Q, ΔL, ΔQ) | ||
rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = ZeroTangent() | ||
return (L, Q), rightorth!_pullback | ||
end | ||
|
||
""" | ||
copyltu!(A::AbstractMatrix) | ||
|
||
Copy the conjugated lower triangular part of `A` to the upper triangular part. | ||
""" | ||
function copyltu!(A::AbstractMatrix) | ||
m, n = size(A) | ||
for i in axes(A, 1) | ||
A[i, i] = real(A[i, i]) | ||
@inbounds for j in (i + 1):n | ||
A[i, j] = conj(A[j, i]) | ||
end | ||
end | ||
return A | ||
end | ||
|
||
function qr_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S}, | ||
Q::AbstractTensorMap{S}, R::AbstractTensorMap{S}, ΔQ, ΔR) where {S} | ||
for (c, b) in blocks(ΔA) | ||
qr_pullback!(b, block(t, c), block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c)) | ||
end | ||
return ΔA | ||
end | ||
|
||
function qr_pullback!(ΔA, A, Q::M, R::M, ΔQ, ΔR) where {M<:AbstractMatrix} | ||
m = qr_rank(R) | ||
n = size(R, 2) | ||
|
||
if n == m # full rank | ||
q = view(Q, :, 1:m) | ||
Δq = view(ΔQ, :, 1:m) | ||
r = view(R, 1:m, :) | ||
Δr = view(ΔR, 1:m, :) | ||
ΔA = qr_pullback_fullrank!(ΔA, q, r, Δq, Δr) | ||
else | ||
q = view(Q, :, 1:m) | ||
Δq = view(ΔQ, :, 1:m) + view(A, :, (m + 1):n) * view(ΔR, :, (m + 1):n)' | ||
r = view(R, 1:m, 1:m) | ||
Δr = view(ΔR, 1:m, 1:m) | ||
|
||
qr_pullback_fullrank!(view(ΔA, :, 1:m), q, r, Δq, Δr) | ||
ΔA[:, (m + 1):n] = q * view(ΔR, :, (m + 1):n) | ||
end | ||
|
||
return ΔA | ||
end | ||
|
||
function qr_pullback_fullrank!(ΔA, Q, R, ΔQ, ΔR) | ||
b = ΔQ + Q * copyltu!(R * ΔR' - ΔQ' * Q) | ||
return adjoint!(ΔA, LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))) | ||
end | ||
|
||
function lq_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S}, | ||
L::AbstractTensorMap{S}, Q::AbstractTensorMap{S}, ΔL, ΔQ) where {S} | ||
for (c, b) in blocks(ΔA) | ||
lq_pullback!(b, block(t, c), block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c)) | ||
end | ||
return ΔA | ||
end | ||
|
||
function lq_pullback!(ΔA, A, L::M, Q::M, ΔL, ΔQ) where {M<:AbstractMatrix} | ||
m = qr_rank(L) | ||
n = size(L, 1) | ||
|
||
if n == m # full rank | ||
l = view(L, :, 1:m) | ||
Δl = view(ΔL, :, 1:m) | ||
q = view(Q, 1:m, :) | ||
Δq = view(ΔQ, 1:m, :) | ||
ΔA = lq_pullback_fullrank!(ΔA, l, q, Δl, Δq) | ||
else | ||
l = view(L, 1:m, 1:m) | ||
Δl = view(ΔL, 1:m, 1:m) | ||
q = view(Q, 1:m, :) | ||
Δq = view(ΔQ, 1:m, :) + view(ΔL, (m + 1):n, 1:m)' * view(A, (m + 1):n, :) | ||
|
||
lq_pullback_fullrank!(view(ΔA, 1:m, :), l, q, Δl, Δq) | ||
ΔA[(m + 1):n, :] = view(ΔL, (m + 1):n, :) * q | ||
end | ||
|
||
return ΔA | ||
end | ||
|
||
function lq_pullback_fullrank!(ΔA, L, Q, ΔL, ΔQ) | ||
mul!(ΔA, copyltu!(L' * ΔL - ΔQ * Q'), Q) | ||
axpy!(true, ΔQ, ΔA) | ||
return LinearAlgebra.LAPACK.trtrs!('L', 'C', 'N', L, ΔA) | ||
end | ||
|
||
function qr_rank(r::AbstractMatrix) | ||
Base.require_one_based_indexing(r) | ||
m, n = size(r) | ||
r₀ = r[1, 1] | ||
i = findfirst(x -> abs(x / r₀) < 1e-12, diag(r)) | ||
return isnothing(i) ? min(m, n) : i - 1 | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) | ||
out = convert(Dict, t) | ||
function convert_pullback(c) | ||
if haskey(c, :data) # :data is the only thing for which this dual makes sense | ||
dual = copy(out) | ||
dual[:data] = c[:data] | ||
return (NoTangent(), NoTangent(), convert(TensorMap, dual)) | ||
else | ||
# instead of zero(t) you can also return ZeroTangent(), which is type unstable | ||
return (NoTangent(), NoTangent(), zero(t)) | ||
end | ||
end | ||
return out, convert_pullback | ||
end | ||
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, | ||
t::Dict{Symbol,Any}) | ||
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) | ||
end | ||
|
||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems a bit suspicious? Why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly to avoid having to manually deal with the
ZeroTangent
type. For example, a QR decomposition where the final result does not depend on R, would generate aZeroTangent
for dR, which is just an abstract representation that behaves as the zero vector in any (co)vectorspace. As some of therrules
are implemented "blockwise", this would either require manually checking if a tangent is aZeroTangent
, or, which is what I chose to do, rely on the hope that the compiler would recognize that the blockwise operation results in ZeroTangents anyways, and thus automatically takes care of this