Add support for eigvals, svdvals and diag, diagm. #130

Merged · 13 commits · Jun 26, 2024
29 changes: 29 additions & 0 deletions ext/TensorKitChainRulesCoreExt.jl
@@ -224,6 +224,20 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
return (U′, Σ′, V′, ϵ), tsvd!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
U, S, V = tsvd(t)
s = diag(S)
project_t = ProjectTo(t)

function svdvals_pullback(Δs′)
Δs = unthunk(Δs′)
ΔS = diagm(codomain(S), domain(S), Δs)
return NoTangent(), project_t(U * ΔS * V)
end

return s, svdvals_pullback
end
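
With this rule in place, the cotangent of the singular values is promoted to a diagonal tensor and pulled back through U and V. The following is a minimal sketch of invoking the pullback by hand; the space and the all-ones cotangent are assumed purely for illustration, and TensorKit, LinearAlgebra and ChainRulesCore are taken to be loaded so that this extension is active.

using TensorKit, LinearAlgebra, ChainRulesCore

t = TensorMap(randn, ComplexF64, ℂ^3 ← ℂ^3)
s, pb = ChainRulesCore.rrule(LinearAlgebra.svdvals!, copy(t))
# cotangent with the same SectorDict structure as the primal output
Δs = TensorKit.SectorDict(c => ones(real(scalartype(t)), length(b)) for (c, b) in s)
_, Δt = pb(Δs)   # Δt is project_t(U * diagm(codomain(S), domain(S), Δs) * V)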

function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
D, V = eig(t; kwargs...)

@@ -266,6 +280,21 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
return (D, V), eigh!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
sortby=nothing, kwargs...)
@assert sortby === nothing "only `sortby=nothing` is supported"
(D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
d = diag(D)
project_t = ProjectTo(t)
function eigvals_pullback(Δd′)
Δd = unthunk(Δd′)
ΔD = diagm(codomain(D), domain(D), Δd)
return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
end

return d, eigvals_pullback
end
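
The eigenvalue cotangent is handled analogously: it is placed on a diagonal and sent through the pullback of `eig!` with a zero cotangent for the eigenvectors. A hedged sketch of a manual call, under the same assumptions as above:

using TensorKit, LinearAlgebra, ChainRulesCore

t = TensorMap(randn, ComplexF64, ℂ^2 ← ℂ^2)
d, pb = ChainRulesCore.rrule(LinearAlgebra.eigvals!, copy(t))
Δd = TensorKit.SectorDict(c => ones(ComplexF64, length(b)) for (c, b) in d)
_, Δt = pb(Δd)   # same as pulling (diagm(codomain(D), domain(D), Δd), ZeroTangent()) through eig!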

function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
error("only `alg=QR()` and `alg=QRpos()` are supported")
13 changes: 13 additions & 0 deletions src/tensors/factorizations.jl
@@ -36,6 +36,11 @@ function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
end
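
As a usage sketch (the graded space below is only an assumed example), `svdvals` returns a `SectorDict` mapping each sector to the singular values of the corresponding block, matching the diagonal of `S` from `tsvd`:

using TensorKit, LinearAlgebra

V = SU2Space(0 => 2, 1//2 => 1)
t = TensorMap(randn, ComplexF64, V ← V)
s = svdvals(t)                      # SectorDict: sector => singular values of that block
S = tsvd(t)[2]
all(c -> s[c] ≈ LinearAlgebra.diag(S)[c], keys(s))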

"""
leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R
@@ -168,6 +173,14 @@ function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
return LinearAlgebra.eigvals!(copy(t); kwargs...)
end
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
end
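
Similarly, `eigvals` returns a `SectorDict` of per-block eigenvalues (always complex, matching `eig`); with `sortby=nothing` the ordering agrees with the diagonal of `D` from `eig`, as the tests below exercise. A small sketch with an assumed space:

using TensorKit, LinearAlgebra

t = TensorMap(randn, ComplexF64, ℂ^3 ← ℂ^3)
d = eigvals(t; sortby=nothing)      # SectorDict: sector => complex eigenvalues of that block
D, W = eig(t)
all(c -> d[c] ≈ LinearAlgebra.diag(D)[c], keys(d))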

"""
eig(t::AbstractTensor, (leftind, rightind)::Index2Tuple; kwargs...) -> D, V

13 changes: 13 additions & 0 deletions src/tensors/linalg.jl
@@ -153,6 +153,19 @@ function isometry(::Type{A},
return t
end

# Diagonal tensors
# ----------------
# TODO: consider adding a specialised DiagonalTensorMap type
function LinearAlgebra.diag(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t))
end
function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict)
return TensorMap(SectorDict(c => LinearAlgebra.diagm(blockdim(codom, c),
blockdim(dom, c), b)
for (c, b) in v), codom ← dom)
end
LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t)))
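
These follow the same blockwise convention: `diag` collects the main diagonal of every block into a `SectorDict`, and `diagm` rebuilds a tensor with those diagonals over the given spaces, so the pair round-trips. A brief sketch (the spaces are assumed for illustration):

using TensorKit, LinearAlgebra

t = TensorMap(randn, ComplexF64, ℂ^2 ⊗ ℂ^3 ← ℂ^3)
d = diag(t)                          # SectorDict: sector => diagonal of that block
D = diagm(codomain(t), domain(t), d) # rectangular blocks with d[c] on the main diagonal
isdiag(D) && diag(D) == d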

# In-place methods
#------------------
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
38 changes: 38 additions & 0 deletions test/ad.jl
@@ -46,6 +46,27 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

# make sure that norms are computed correctly:
function FiniteDifferences.to_vec(t::TensorKit.SectorDict)
T = scalartype(valtype(t))
vec = mapreduce(vcat, t; init=T[]) do (c, b)
return reshape(b, :) .* sqrt(dim(c))
end
vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))

function from_vec(x_real)
x = T <: Real ? x_real : reinterpret(T, x_real)
ctr = 0
return TensorKit.SectorDict(c => (n = length(b);
b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./
sqrt(dim(c));
ctr += n;
b′)
for (c, b) in t)
end
return vec_real, from_vec
end
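
The √dim(c) factors make the Euclidean norm of this vectorisation agree with TensorKit's sector-weighted norm, so the finite-difference inner products are consistent with the pullbacks above. A small consistency sketch, assuming the `to_vec` method just defined is in scope and using an assumed graded space:

using TensorKit, LinearAlgebra, FiniteDifferences

V = SU2Space(0 => 2, 1//2 => 1)
t = TensorMap(randn, ComplexF64, V ← V)
S = tsvd(t)[2]
v, _ = FiniteDifferences.to_vec(LinearAlgebra.diag(S))
norm(v) ≈ norm(S)   # ‖vec‖² = Σ_c dim(c) ‖diagonal of block c‖² = ‖S‖²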

function _randomize!(a::TensorMap)
for b in values(blocks(a))
copyto!(b, randn(size(b)))
@@ -68,12 +89,18 @@ end
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
return ChainRulesCore.rrule(tsvd!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals), args...; kwargs...)
return ChainRulesCore.rrule(svdvals!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.eig), args...; kwargs...)
return ChainRulesCore.rrule(eig!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.eigh), args...; kwargs...)
return ChainRulesCore.rrule(eigh!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals), args...; kwargs...)
return ChainRulesCore.rrule(eigvals!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
end
@@ -355,5 +382,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncdim(2 * dim(c))))
end

let D = LinearAlgebra.eigvals(C)
ΔD = diag(TensorMap(randn, complex(scalartype(C)), space(C)))
test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD,
fkwargs=(; sortby=nothing))
end

let S = LinearAlgebra.svdvals(C)
ΔS = diag(TensorMap(randn, real(scalartype(C)), space(C)))
test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS)
end
end
end
23 changes: 22 additions & 1 deletion test/tensors.jl
@@ -188,6 +188,14 @@ for V in spacelist
@test Base.promote_typeof(t, tc) == typeof(tc)
@test Base.promote_typeof(tc, t) == typeof(tc + t)
end
@timedtestset "diag/diagm" begin
W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5
t = TensorMap(randn, ComplexF64, W)
d = LinearAlgebra.diag(t)
D = LinearAlgebra.diagm(codomain(t), domain(t), d)
@test LinearAlgebra.isdiag(D)
@test LinearAlgebra.diag(D) == d
end
@timedtestset "Permutations: test via inner product invariance" begin
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
t = Tensor(rand, ComplexF64, W)
@@ -408,7 +416,14 @@ for V in spacelist
@test UdU ≈ one(UdU)
VVd = V * V'
@test VVd ≈ one(VVd)
@test U * S * V ≈ permute(t, ((3, 4, 2), (1, 5)))
t2 = permute(t, ((3, 4, 2), (1, 5)))
@test U * S * V ≈ t2

s = LinearAlgebra.svdvals(t2)
s′ = LinearAlgebra.diag(S)
for (c, b) in s
@test b ≈ s′[c]
end
end
end
@testset "empty tensor" begin
@@ -458,6 +473,12 @@ for V in spacelist
t2 = permute(t, ((1, 3), (2, 4)))
@test t2 * V ≈ V * D

d = LinearAlgebra.eigvals(t2; sortby=nothing)
d′ = LinearAlgebra.diag(D)
for (c, b) in d
@test b ≈ d′[c]
end

# Somehow moving these tests before the previous one gives rise to errors
# with T=Float32 on x86 platforms. Is this an OpenBLAS issue?
VdV = V' * V