Skip to content

Commit

Permalink
Avoid allocation in InnerProduct checks
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 1, 2023
1 parent 83be470 commit 699c407
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 45 deletions.
4 changes: 4 additions & 0 deletions src/spaces/vectorspaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ Return the type of inner product for vector spaces, which can be either
InnerProductStyle(V::VectorSpace) = InnerProductStyle(typeof(V))
InnerProductStyle(::Type{<:VectorSpace}) = NoInnerProduct()

@noinline function throw_invalid_innerproduct(fname)
throw(ArgumentError("$fname requires Euclidean inner product"))
end

dual(V::VectorSpace) = dual(InnerProductStyle(V), V)
dual(::EuclideanProduct, V::VectorSpace) = conj(V)

Expand Down
41 changes: 14 additions & 27 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,38 +262,30 @@ LinearAlgebra.isposdef(t::AbstractTensorMap) = isposdef!(copy(t))
# only correct if Euclidean inner product
#------------------------------------------------------------------------------------------
function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos())
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("leftorth! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:leftorth!)
return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg')))
end

function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos())
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("rightorth! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:rightorth!)
return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg')))
end

function leftnull!(t::AdjointTensorMap; alg::OFA=QR(),
kwargs...)
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("leftnull! only defined for Euclidean inner product spaces"))
function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...)
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:leftnull!)
return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...))
end

function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(),
kwargs...)
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("rightnull! only defined for Euclidean inner product spaces"))
function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...)
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:rightnull!)
return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...))
end

function tsvd!(t::AdjointTensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("tsvd! only defined for Euclidean inner product spaces"))

InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
return adjoint(vt), adjoint(s), adjoint(u), err
end
Expand All @@ -303,8 +295,7 @@ function leftorth!(t::TensorMap;
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
eps(real(float(one(scalartype(t))))) * iszero(atol))
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("leftorth! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:leftorth!)
if !iszero(rtol)
atol = max(atol, rtol * norm(t))
end
Expand Down Expand Up @@ -340,8 +331,7 @@ function leftnull!(t::TensorMap;
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
eps(real(float(one(scalartype(t))))) * iszero(atol))
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("leftnull! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:leftnull!)
if !iszero(rtol)
atol = max(atol, rtol * norm(t))
end
Expand All @@ -365,8 +355,7 @@ function rightorth!(t::TensorMap;
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
eps(real(float(one(scalartype(t))))) * iszero(atol))
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("rightorth! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:rightorth!)
if !iszero(rtol)
atol = max(atol, rtol * norm(t))
end
Expand Down Expand Up @@ -402,8 +391,7 @@ function rightnull!(t::TensorMap;
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
eps(real(float(one(scalartype(t))))) * iszero(atol))
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("rightnull! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:rightnull!)
if !iszero(rtol)
atol = max(atol, rtol * norm(t))
end
Expand All @@ -426,8 +414,7 @@ function tsvd!(t::TensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("tsvd! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)
S = spacetype(t)
I = sectortype(t)
A = storagetype(t)
Expand Down Expand Up @@ -500,8 +487,7 @@ end
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)

function eigh!(t::TensorMap; kwargs...)
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("eigh! only defined for Euclidean inner product spaces"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eigh!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))
S = spacetype(t)
Expand All @@ -527,6 +513,7 @@ function eigh!(t::TensorMap; kwargs...)
end

function eig!(t::TensorMap; kwargs...)
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eig!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))
S = spacetype(t)
Expand Down
25 changes: 9 additions & 16 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,19 @@ for a specific isomorphism, but the current choice is such that
`unitary(cod, dom) == inv(unitary(dom, cod)) = adjoint(unitary(dom, cod))`.
"""
function unitary(cod::TensorSpace{S}, dom::TensorSpace{S}) where {S}
InnerProductStyle(S) === EuclideanProduct() ||
throw(ArgumentError("unitary requires inner product spaces"))
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism(cod, dom)
end
function unitary(P::TensorMapSpace{S}) where {S}
InnerProductStyle(S) === EuclideanProduct() ||
throw(ArgumentError("unitary requires inner product spaces"))
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism(P)
end
function unitary(A::Type{<:DenseMatrix}, P::TensorMapSpace{S}) where {S}
InnerProductStyle(S) === EuclideanProduct() ||
throw(ArgumentError("unitary requires inner product spaces"))
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism(A, P)
end
function unitary(A::Type{<:DenseMatrix}, cod::TensorSpace{S}, dom::TensorSpace{S}) where {S}
InnerProductStyle(S) === EuclideanProduct() ||
throw(ArgumentError("unitary requires inner product spaces"))
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:unitary)
return isomorphism(A, cod, dom)
end

Expand All @@ -139,8 +135,7 @@ end
function isometry(::Type{A},
cod::ProductSpace{S},
dom::ProductSpace{S}) where {A<:DenseMatrix,S<:ElementarySpace}
InnerProductStyle(S) === EuclideanProduct() ||
throw(ArgumentError("isometries require Euclidean inner product"))
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:isometry)
dom cod ||
throw(SpaceMismatch("codomain $cod and domain $dom do not allow for an isometric mapping"))
t = TensorMap(s -> A(undef, s), cod, dom)
Expand All @@ -167,10 +162,9 @@ function Base.fill!(t::AbstractTensorMap, value::Number)
end
return t
end
function LinearAlgebra.adjoint!(tdst::AbstractTensorMap,
tsrc::AbstractTensorMap)
spacetype(tdst) === spacetype(tsrc) && InnerProductStyle(tdst) === EuclideanProduct() ||
throw(ArgumentError("adjoint! requires Euclidean inner product spacetype"))
function LinearAlgebra.adjoint!(tdst::AbstractTensorMap{S},
tsrc::AbstractTensorMap{S}) where {S}
InnerProductStyle(tdst) === EuclideanProduct() || throw_invalid_innerproduct(:adjoint!)
space(tdst) == adjoint(space(tsrc)) ||
throw(SpaceMismatch("$(space(tdst)) ≠ adjoint($(space(tsrc)))"))
for c in blocksectors(tdst)
Expand Down Expand Up @@ -207,8 +201,7 @@ end
LinearAlgebra.dot(t1::AbstractTensorMap, t2::AbstractTensorMap) = inner(t1, t2)

function LinearAlgebra.norm(t::AbstractTensorMap, p::Real=2)
InnerProductStyle(t) === EuclideanProduct() ||
throw(ArgumentError("norm requires Euclidean inner product"))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:norm)
return _norm(blocks(t), p, float(zero(real(scalartype(t)))))
end
function _norm(blockiter, p::Real, init::Real)
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ end
#-------
function VectorInterface.inner(tx::AbstractTensorMap, ty::AbstractTensorMap)
space(tx) == space(ty) || throw(SpaceMismatch("$(space(tx))$(space(ty))"))
InnerProductStyle(tx) === EuclideanProduct() ||
throw(ArgumentError("dot requires Euclidean inner product"))
InnerProductStyle(tx) === EuclideanProduct() || throw_invalid_innerproduct(:inner)
T = VectorInterface.promote_inner(tx, ty)
s = zero(T)
for c in blocksectors(tx)
Expand Down

0 comments on commit 699c407

Please sign in to comment.