Skip to content

Commit

Permalink
use diagonaltensormap (#190)
Browse files Browse the repository at this point in the history
* use diagonaltensormap

* streamline type and copy
  • Loading branch information
Jutho authored Dec 19, 2024
1 parent ef42572 commit a8aa774
Showing 1 changed file with 136 additions and 70 deletions.
206 changes: 136 additions & 70 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Tensor factorization
#----------------------
function factorisation_scalartype(t::AbstractTensorMap)
T = scalartype(t)
return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
end
factorisation_scalartype(f, t) = factorisation_scalartype(t)

function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple)
return permute!(similar(t, T, permute(space(t), p)), t, p)
end
function copy_oftype(t::AbstractTensorMap, T::Type{<:Number})
return copy!(similar(t, T), t)
end

"""
tsvd(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD())
Expand Down Expand Up @@ -36,13 +49,14 @@ algorithm that computes the decomposition (`_gesvd` or `_gesdd`).
Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`
is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p)
return tsvd!(tcopy; kwargs...)
end

LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
function LinearAlgebra.svdvals(t::AbstractTensorMap)
tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t))
return LinearAlgebra.svdvals!(tcopy)
end

"""
Expand All @@ -67,8 +81,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
`leftorth(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function leftorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return leftorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p)
return leftorth!(tcopy; kwargs...)
end

"""
Expand All @@ -95,8 +110,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
`rightorth(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function rightorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return rightorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p)
return rightorth!(tcopy; kwargs...)
end

"""
Expand All @@ -121,8 +137,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
`leftnull(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function leftnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return leftnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p)
return leftnull!(tcopy; kwargs...)
end

"""
Expand All @@ -149,8 +166,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
`rightnull(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function rightnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return rightnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p)
return rightnull!(tcopy; kwargs...)
end

"""
Expand All @@ -172,17 +190,14 @@ matrices. See the corresponding documentation for more information.
See also `eig` and `eigh`
"""
function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
kwargs...)
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p)
return eigen!(tcopy; kwargs...)
end

function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
return LinearAlgebra.eigvals!(copy(t); kwargs...)
end
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
tcopy = copy_oftype(t, factorisation_scalartype(eigen, t))
return LinearAlgebra.eigvals!(tcopy; kwargs...)
end

"""
Expand All @@ -207,8 +222,9 @@ matrices. See the corresponding documentation for more information.
See also `eigen` and `eigh`.
"""
function eig(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return eig!(permute(t, (p₁, p₂); copy=true); kwargs...)
function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p)
return eig!(tcopy; kwargs...)
end

"""
Expand All @@ -231,8 +247,9 @@ permute(t, (leftind, rightind)) * V = V * D
See also `eigen` and `eig`.
"""
function eigh(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return eigh!(permute(t, (p₁, p₂); copy=true))
function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigh, t), p)
return eigh!(tcopy; kwargs...)
end

"""
Expand All @@ -247,31 +264,54 @@ which `isposdef!` is called should have equal domain and codomain, as otherwise
meaningless.
"""
function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return isposdef!(permute(t, (p₁, p₂); copy=true))
tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p)
return isposdef!(tcopy)
end

tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
function tsvd(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return tsvd!(tcopy; kwargs...)
end
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
return leftorth!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return leftorth!(tcopy; alg=alg, kwargs...)
end
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
return rightorth!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return rightorth!(tcopy; alg=alg, kwargs...)
end
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
return leftnull!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return leftnull!(tcopy; alg=alg, kwargs...)
end
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
return rightnull!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return rightnull!(tcopy; alg=alg, kwargs...)
end
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eigen!(tcopy; kwargs...)
end
function eig(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eig!(tcopy; kwargs...)
end
function eigh(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eigh!(tcopy; kwargs...)
end
function LinearAlgebra.isposdef(t::AbstractTensorMap)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return isposdef!(tcopy)
end
LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) = eigen!(copy(t); kwargs...)
eig(t::AbstractTensorMap; kwargs...) = eig!(copy(t); kwargs...)
eigh(t::AbstractTensorMap; kwargs...) = eigh!(copy(t); kwargs...)
LinearAlgebra.isposdef(t::AbstractTensorMap) = isposdef!(copy(t))

# Orthogonal factorizations (mutation for recycling memory):
# only possible if scalar type is floating point
# only correct if Euclidean inner product
#------------------------------------------------------------------------------------------
function leftorth!(t::TensorMap;
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}

function leftorth!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -321,7 +361,7 @@ function leftorth!(t::TensorMap;
return Q, R
end

function leftnull!(t::TensorMap;
function leftnull!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{QR,QRpos,SVD,SDD}=QRpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -360,7 +400,7 @@ function leftnull!(t::TensorMap;
return N
end

function rightorth!(t::TensorMap;
function rightorth!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar}=LQpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -410,7 +450,7 @@ function rightorth!(t::TensorMap;
return L, Q
end

function rightnull!(t::TensorMap;
function rightnull!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{LQ,LQpos,SVD,SDD}=LQpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -476,7 +516,13 @@ end
#------------------------------#
# Singular value decomposition #
#------------------------------#
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat})
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
end
LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t))

function tsvd!(t::TensorMap{<:RealOrComplexFloat};
trunc=NoTruncation(), p::Real=2, alg=SDD())
return _tsvd!(t, alg, trunc, p)
end
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
Expand All @@ -485,7 +531,8 @@ function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
end

# implementation dispatches on algorithm
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD},
trunc::TruncationScheme, p::Real=2)
# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
Expand Down Expand Up @@ -518,13 +565,17 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
return SVDdata, dims
end

function _create_svdtensors(t, SVDdata, dims)
function _create_svdtensors(t::TensorMap{<:RealOrComplexFloat}, SVDdata, dims)
T = scalartype(t)
S = spacetype(t)
W = S(dims)
T = float(scalartype(t))
U = similar(t, T, codomain(t) W)
Σ = similar(t, real(T), W W)
V⁺ = similar(t, T, W domain(t))

Tr = real(T)
A = similarstoragetype(t, Tr)
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)

U = similar(t, codomain(t) W)
V⁺ = similar(t, W domain(t))
for (c, (Uc, Σc, V⁺c)) in SVDdata
r = Base.OneTo(dims[c])
copy!(block(U, c), view(Uc, :, r))
Expand All @@ -534,38 +585,53 @@ function _create_svdtensors(t, SVDdata, dims)
return U, Σ, V⁺
end

function _empty_svdtensors(t)
function _empty_svdtensors(t::TensorMap{<:RealOrComplexFloat})
T = scalartype(t)
S = spacetype(t)
I = sectortype(t)
dims = SectorDict{I,Int}()
S = spacetype(t)
W = S(dims)

Tr = real(T)
A = similarstoragetype(t, Tr)
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)

U = similar(t, codomain(t) W)
Σ = similar(t, real(scalartype(t)), W W)
V⁺ = similar(t, W domain(t))
return U, Σ, V⁺
end

#--------------------------#
# Eigenvalue decomposition #
#--------------------------#
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)
function LinearAlgebra.eigen!(t::TensorMap{<:RealOrComplexFloat})
return ishermitian(t) ? eigh!(t) : eig!(t)
end

function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
end
function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...)
return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...)))
for (c, b) in blocks(t))
end

function eigh!(t::TensorMap)
function eigh!(t::TensorMap{<:RealOrComplexFloat})
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
if length(domain(t)) == 1
W = domain(t)[1]
else
S = spacetype(t)
W = S(dims)
end
T = float(scalartype(t))
V = similar(t, T, domain(t) W)
D = similar(t, real(T), W W)
W = S(dims)

Tr = real(T)
A = similarstoragetype(t, Tr)
D = DiagonalTensorMap{Tr,S,A}(undef, W)
V = similar(t, domain(t) W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eigh!(b)
copy!(block(D, c), Diagonal(values))
Expand All @@ -574,20 +640,20 @@ function eigh!(t::TensorMap)
return D, V
end

function eig!(t::TensorMap; kwargs...)
function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
if length(domain(t)) == 1
W = domain(t)[1]
else
S = spacetype(t)
W = S(dims)
end
T = complex(float(scalartype(t)))
V = similar(t, T, domain(t) W)
D = similar(t, T, W W)
W = S(dims)

Tc = complex(T)
A = similarstoragetype(t, Tc)
D = DiagonalTensorMap{Tc,S,A}(undef, W)
V = similar(t, Tc, domain(t) W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
copy!(block(D, c), Diagonal(values))
Expand Down

0 comments on commit a8aa774

Please sign in to comment.