Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for eigvals, svdvals and diag, diagm. #130

Merged
merged 13 commits into from
Jun 26, 2024
29 changes: 29 additions & 0 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
return (U′, Σ′, V′, ϵ), tsvd!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
    # Forward pass: compute the full SVD once so the pullback can rebuild the
    # diagonal cotangent ΔS in the correct spaces.
    U, Σ, Vᴴ = tsvd(t)
    σ = diag(Σ)
    project = ProjectTo(t)

    function svdvals_pullback(Δσ_thunked)
        Δσ = unthunk(Δσ_thunked)
        # Lift the diagonal cotangent back into tensor form, then transport it
        # through the singular vectors onto the input.
        ΔΣ = diagm(codomain(Σ), domain(Σ), Δσ)
        ∂t = project(U * ΔΣ * Vᴴ)
        return NoTangent(), ∂t
    end

    return σ, svdvals_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
D, V = eig(t; kwargs...)

Expand Down Expand Up @@ -264,6 +278,21 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
return (D, V), eigh!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
                              sortby=nothing, kwargs...)
    # `@assert` may be elided at higher optimization levels; validate the
    # unsupported keyword explicitly so misuse always errors.
    sortby === nothing ||
        throw(ArgumentError("only `sortby=nothing` is supported"))
    # Reuse the `eig!` rule: the eigenvalues are the diagonal of D, and the
    # eigenvector output receives a zero cotangent.
    (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
    d = diag(D)
    project_t = ProjectTo(t)
    function eigvals_pullback(Δd′)
        Δd = unthunk(Δd′)
        # Rebuild a diagonal tensor cotangent in the spaces of D.
        ΔD = diagm(codomain(D), domain(D), Δd)
        # Index [2] selects the input cotangent from (NoTangent, ∂t).
        return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
    end

    return d, eigvals_pullback
end

function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
error("only `alg=QR()` and `alg=QRpos()` are supported")
Expand Down
23 changes: 23 additions & 0 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

# Out-of-place variant: copy first so the destructive `svdvals!` leaves `t` intact.
LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
# Singular values are computed block-per-block; the result is a `SectorDict`
# mapping each sector `c` to the vector of singular values of its block.
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
    return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
end

# TODO: decide if we want to keep these specializations:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In line with the plans to stop the special-casing of TrivialTensorMap, I would support removing these specialisations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I definitely can, but in order to keep this consistent with diag and diagm, this would then also return a dictionary object for TrivialTensorMap. I don't have a big opinion on this, as I usually write code for generic symmetries and have to deal with these cases separately anyways, but I just want to mention that this is a bit connected here.

# Specialization for the trivial sector: operate on the raw data directly and
# return a plain vector instead of a `SectorDict` (see PR discussion on whether
# this special-casing should be kept).
function LinearAlgebra.svdvals!(t::TrivialTensorMap)
    return LinearAlgebra.svdvals!(t.data)
end

"""
leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R
Expand Down Expand Up @@ -168,6 +178,19 @@ function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

# Out-of-place variant: work on a copy so `t` is left untouched.
function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
    return LinearAlgebra.eigvals!(copy(t); kwargs...)
end
# Eigenvalues block-per-block, always promoted to complex so the value type of
# the returned `SectorDict` does not depend on whether a block's spectrum
# happens to be real.
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
    blockvals(b) = complex(LinearAlgebra.eigvals!(b; kwargs...))
    return SectorDict(c => blockvals(b) for (c, b) in blocks(t))
end

# TODO: decide if we want to keep these specializations:
# Trivial-sector specialization: return a plain (complex) vector instead of a
# `SectorDict`, consistent with the `svdvals!` specialization above.
function LinearAlgebra.eigvals!(t::TrivialTensorMap; kwargs...)
    return complex(LinearAlgebra.eigvals!(t.data; kwargs...))
end

"""
eig(t::AbstractTensor, (leftind, rightind)::Index2Tuple; kwargs...) -> D, V

Expand Down
24 changes: 24 additions & 0 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ function isometry(::Type{A},
return t
end

# Diagonal tensors
# ----------------
# TODO: consider adding a specialised DiagonalTensorMap type
function LinearAlgebra.diag(t::AbstractTensorMap)
    # Trivial sector: a single block, so return its diagonal as a plain vector.
    sectortype(t) === Trivial && return LinearAlgebra.diag(block(t, Trivial()))
    # Otherwise collect one diagonal per sector in a `SectorDict`.
    return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t))
end

function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::AbstractVector)
    sp = codom ← dom
    # A plain vector of diagonal entries only makes sense without symmetry;
    # `@assert` can be compiled out, so raise an explicit error instead.
    sectortype(sp) === Trivial ||
        throw(ArgumentError("a plain `AbstractVector` requires trivial `sectortype`; " *
                            "use a `SectorDict` of diagonals instead"))
    return TensorMap(LinearAlgebra.diagm(dim(codom), dim(dom), v), sp)
end
function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict)
    # Build each sector's (possibly rectangular) diagonal block, then assemble
    # the tensor map over codom ← dom.
    blockdata = SectorDict(c => LinearAlgebra.diagm(blockdim(codom, c),
                                                    blockdim(dom, c), entries)
                           for (c, entries) in v)
    return TensorMap(blockdata, codom ← dom)
end

LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t)))

# In-place methods
#------------------
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
Expand Down
38 changes: 38 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

# make sure that norms are computed correctly:
# Flatten a `SectorDict` of blocks into one real vector for finite differencing.
# Each block is weighted by sqrt(dim(c)) so that the Euclidean norm of the
# flat vector matches the tensor norm (quantum dimensions included).
function FiniteDifferences.to_vec(t::TensorKit.SectorDict)
    T = scalartype(valtype(t))
    vec = mapreduce(vcat, t; init=T[]) do (c, b)
        return reshape(b, :) .* sqrt(dim(c))
    end
    # FiniteDifferences requires a real vector; reinterpret complex data as
    # interleaved real/imaginary parts.
    vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))

    function from_vec(x_real)
        x = T <: Real ? x_real : reinterpret(T, x_real)
        # `ctr` walks through `x` block by block; the generator below relies on
        # iterating `t` in the same order as the mapreduce above, and on the
        # statement order (slice, undo weighting, advance counter).
        ctr = 0
        return TensorKit.SectorDict(c => (n = length(b);
                                          b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./
                                               sqrt(dim(c));
                                          ctr += n;
                                          b′)
                                    for (c, b) in t)
    end
    return vec_real, from_vec
end

function _randomize!(a::TensorMap)
for b in values(blocks(a))
copyto!(b, randn(size(b)))
Expand All @@ -68,12 +89,18 @@ end
# Test helper: route the out-of-place function through its in-place rrule.
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
    return ChainRulesCore.rrule(tsvd!, args...; kwargs...)
end
# Test helper: route `svdvals` through the rrule of `svdvals!`.
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals), args...; kwargs...)
    return ChainRulesCore.rrule(svdvals!, args...; kwargs...)
end
# Test helper: route `eig` through the rrule of `eig!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eig), args...; kwargs...)
    return ChainRulesCore.rrule(eig!, args...; kwargs...)
end
# Test helper: route `eigh` through the rrule of `eigh!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eigh), args...; kwargs...)
    return ChainRulesCore.rrule(eigh!, args...; kwargs...)
end
# Test helper: route `eigvals` through the rrule of `eigvals!`.
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals), args...; kwargs...)
    return ChainRulesCore.rrule(eigvals!, args...; kwargs...)
end
# Test helper: route `leftorth` through the rrule of `leftorth!`.
function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
    return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
end
Expand Down Expand Up @@ -355,5 +382,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncdim(2 * dim(c))))
end

# Finite-difference check of the `eigvals` rrule; the cotangent is complex to
# match the promoted eigenvalue eltype. NOTE(review): `C` and `atol` come from
# the enclosing testset scope.
let D = LinearAlgebra.eigvals(C)
    ΔD = diag(TensorMap(randn, complex(scalartype(C)), space(C)))
    test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD,
               fkwargs=(; sortby=nothing))
end

# Finite-difference check of the `svdvals` rrule; singular values are real, so
# the cotangent uses the real scalar type.
let S = LinearAlgebra.svdvals(C)
    ΔS = diag(TensorMap(randn, real(scalartype(C)), space(C)))
    test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS)
end
end
end
31 changes: 30 additions & 1 deletion test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ for V in spacelist
@test Base.promote_typeof(t, tc) == typeof(tc)
@test Base.promote_typeof(tc, t) == typeof(tc + t)
end
@timedtestset "diag/diagm" begin
    # Round-trip: extracting the diagonal and rebuilding with diagm must yield
    # a diagonal tensor with the same diagonal entries.
    W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5
    t = TensorMap(randn, ComplexF64, W)
    v = LinearAlgebra.diag(t)
    t′ = LinearAlgebra.diagm(codomain(t), domain(t), v)
    @test LinearAlgebra.isdiag(t′)
    @test LinearAlgebra.diag(t′) == v
end
@timedtestset "Permutations: test via inner product invariance" begin
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
t = Tensor(rand, ComplexF64, W)
Expand Down Expand Up @@ -408,7 +416,18 @@ for V in spacelist
@test UdU ≈ one(UdU)
VVd = V * V'
@test VVd ≈ one(VVd)
@test U * S * V ≈ permute(t, ((3, 4, 2), (1, 5)))
t2 = permute(t, ((3, 4, 2), (1, 5)))
@test U * S * V ≈ t2

# `svdvals` must agree with the diagonal of the S factor from `tsvd`.
# NOTE(review): `t`, `U`, `S`, `V` come from the enclosing testset scope.
s = LinearAlgebra.svdvals(t2)
s′ = LinearAlgebra.diag(S)
# With symmetry both sides are SectorDicts: compare sector by sector;
# in the trivial case both are plain vectors.
if s isa TensorKit.SectorDict
    for (c, b) in s
        @test b ≈ s′[c]
    end
else
    @test s ≈ s′
end
end
end
@testset "empty tensor" begin
Expand Down Expand Up @@ -458,6 +477,16 @@ for V in spacelist
t2 = permute(t, ((1, 3), (2, 4)))
@test t2 * V ≈ V * D

# `eigvals` must agree with the diagonal of the D factor from `eig`.
# NOTE(review): `t`, `D`, `V` come from the enclosing testset scope.
d = LinearAlgebra.eigvals(t2; sortby=nothing)
d′ = LinearAlgebra.diag(D)
# With symmetry both sides are SectorDicts: compare sector by sector;
# in the trivial case both are plain vectors.
if d isa TensorKit.SectorDict
    for (c, b) in d
        @test b ≈ d′[c]
    end
else
    @test d ≈ d′
end

# Somehow moving these test before the previous one gives rise to errors
# with T=Float32 on x86 platforms. Is this an OpenBLAS issue?
VdV = V' * V
Expand Down