Skip to content

Commit

Permalink
restructure factorisations and matrix algebra
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 17, 2023
1 parent f833a4b commit f014902
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 129 deletions.
2 changes: 1 addition & 1 deletion src/auxiliary/deprecate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Base: transpose
@deprecate transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false) transpose(t, (p1, p2); copy=copy)
@deprecate braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false) braid(t, (p1, p2), levels; copy=copy)

import LinearAlgebra: svd, svd!

Base.@deprecate(svd(t::AbstractTensorMap, leftind::IndexTuple, rightind::IndexTuple;
trunc::TruncationScheme=notrunc(), p::Real=2, alg::SVDAlg=SDD()),
Expand All @@ -16,7 +17,6 @@ Base.@deprecate(svd!(t::AbstractTensorMap;
trunc::TruncationScheme=notrunc(), p::Real=2, alg::SVDAlg=SDD()),
tsvd(t; trunc=trunc, p=p, alg=alg))


# TODO: deprecate

tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...) = tsvd(t, (p₁, p₂); kwargs...)
Expand Down
98 changes: 58 additions & 40 deletions src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,10 @@
# custom wrappers for BLAS and LAPACK routines, together with some custom definitions
using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException,
DimensionMismatch, SingularException, PosDefException, chkstride1,
checksquare,
triu!

# Simple reference to getting and setting BLAS threads
#------------------------------------------------------
set_num_blas_threads(n::Integer) = LinearAlgebra.BLAS.set_num_threads(n)
get_num_blas_threads(n::Integer) = LinearAlgebra.BLAS.get_num_threads(n)

# TODO: define for CuMatrix if we support this
function _one!(A::DenseMatrix)
Threads.@threads for j in 1:size(A, 2)
@simd for i in 1:size(A, 1)
@inbounds A[i, j] = i == j
end
end
return A
end

# MATRIX factorizations
#-----------------------
# Factorization algorithms
#--------------------------
abstract type FactorizationAlgorithm end
abstract type OrthogonalFactorizationAlgorithm <: FactorizationAlgorithm end

Expand All @@ -38,12 +24,12 @@ struct RQ <: OrthogonalFactorizationAlgorithm
end
struct RQpos <: OrthogonalFactorizationAlgorithm
end
struct SDD <: OrthogonalFactorizationAlgorithm # lapack's default divide and conquer algorithm
end
struct SVD <: OrthogonalFactorizationAlgorithm
end
struct Polar <: OrthogonalFactorizationAlgorithm
end
struct SDD <: OrthogonalFactorizationAlgorithm # lapack's default divide and conquer algorithm
end

Base.adjoint(::QRpos) = LQpos()
Base.adjoint(::QR) = LQ()
Expand All @@ -57,10 +43,33 @@ Base.adjoint(::RQ) = QL()

Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg

_safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
_safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
const OFA = OrthogonalFactorizationAlgorithm
const SVDAlg = Union{SVD,SDD}

# Matrix algebra: entrypoint for calling matrix methods from within tensor implementations
#------------------------------------------------------------------------------------------
module MatrixAlgebra

using LinearAlgebra
using LinearAlgebra: BlasFloat, checksquare

function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
using ..TensorKit: OrthogonalFactorizationAlgorithm,
QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar

# TODO: define for CuMatrix if we support this
function one!(A::DenseMatrix)
Threads.@threads for j in 1:size(A, 2)
@simd for i in 1:size(A, 1)
@inbounds A[i, j] = i == j
end
end
return A
end

safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))

function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
m, n = size(A)
k = min(m, n)
Expand All @@ -76,21 +85,21 @@ function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::R

if isa(alg, QRpos)
@inbounds for j in 1:k
s = _safesign(R[j, j])
s = safesign(R[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
@inbounds for j in size(R, 2):-1:1
for i in 1:min(k, j)
R[i, j] = R[i, j] * conj(_safesign(R[i, i]))
R[i, j] = R[i, j] * conj(safesign(R[i, i]))
end
end
end
return Q, R
end

function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::Real)
function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::Real)
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
m, n = size(A)
@assert m >= n
Expand All @@ -100,7 +109,7 @@ function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::R
@inbounds for j in 1:nhalf, i in 1:m
A[i, j], A[i, n + 1 - j] = A[i, n + 1 - j], A[i, j]
end
Q, R = _leftorth!(A, isa(alg, QL) ? QR() : QRpos(), atol)
Q, R = leftorth!(A, isa(alg, QL) ? QR() : QRpos(), atol)

#swap columns in Q
@inbounds for j in 1:nhalf, i in 1:m
Expand All @@ -119,7 +128,7 @@ function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::R
return Q, R
end

function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A)
if isa(alg, Union{SVD,SDD})
n = count(s -> s .> atol, S)
Expand All @@ -139,7 +148,7 @@ function _leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, at
end
end

function _leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real)
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
m, n = size(A)
m >= n || throw(ArgumentError("no null space if less rows than columns"))
Expand All @@ -153,15 +162,15 @@ function _leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::R
return N = LAPACK.gemqrt!('L', 'N', A, T, N)
end

function _leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
size(A, 2) == 0 && return _one!(similar(A, (size(A, 1), size(A, 1))))
function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
size(A, 2) == 0 && return one!(similar(A, (size(A, 1), size(A, 1))))
U, S, V = alg isa SVD ? LAPACK.gesvd!('A', 'N', A) : LAPACK.gesdd!('A', A)
indstart = count(>(atol), S) + 1
return U[:, indstart:end]
end

function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos},
atol::Real)
function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos},
atol::Real)
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
# TODO: geqrfp seems a bit slower than geqrt in the intermediate region around
# matrix size 100, which is the interesting region. => Investigate and fix
Expand All @@ -177,7 +186,7 @@ function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos
@inbounds for j in 1:mhalf, i in 1:n
At[i, j], At[i, m + 1 - j] = At[i, m + 1 - j], At[i, j]
end
Qt, Rt = _leftorth!(At, isa(alg, RQ) ? QR() : QRpos(), atol)
Qt, Rt = leftorth!(At, isa(alg, RQ) ? QR() : QRpos(), atol)

@inbounds for j in 1:mhalf, i in 1:n
Qt[i, j], Qt[i, m + 1 - j] = Qt[i, m + 1 - j], Qt[i, j]
Expand All @@ -195,7 +204,7 @@ function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos
R = transpose!(similar(A, (m, m)), Rt) # TODO: efficient in place
return R, Q
else
Qt, Lt = _leftorth!(At, alg', atol)
Qt, Lt = leftorth!(At, alg', atol)
if m > n
L = transpose!(A, Lt)
Q = transpose!(similar(A, (n, n)), Qt) # TODO: efficient in place
Expand All @@ -207,7 +216,7 @@ function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos
end
end

function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real)
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A)
if isa(alg, Union{SVD,SDD})
n = count(s -> s .> atol, S)
Expand All @@ -226,7 +235,7 @@ function _rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, a
end
end

function _rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos}, atol::Real)
function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos}, atol::Real)
iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg"))
m, n = size(A)
k = min(m, n)
Expand All @@ -240,18 +249,25 @@ function _rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos}, atol::
return N = LAPACK.gemqrt!('R', eltype(At) <: Real ? 'T' : 'C', At, T, N)
end

function _rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
size(A, 1) == 0 && return _one!(similar(A, (size(A, 2), size(A, 2))))
function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real)
size(A, 1) == 0 && return one!(similar(A, (size(A, 2), size(A, 2))))
U, S, V = alg isa SVD ? LAPACK.gesvd!('N', 'A', A) : LAPACK.gesdd!('A', A)
indstart = count(>(atol), S) + 1
return V[indstart:end, :]
end

function _svd!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD})
function svd!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD})
U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A)
return U, S, V
end

## Old stuff and experiments

# using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException,
# DimensionMismatch, SingularException, PosDefException, chkstride1,
# checksquare,
# triu!

# TODO: override Julia's eig interface
# using LinearAlgebra.BLAS: @blasfunc, libblas, BlasReal, BlasComplex
# using LinearAlgebra.LAPACK: liblapack, chklapackerror
Expand Down Expand Up @@ -332,3 +348,5 @@ end
# end
# end
# end

end
Loading

0 comments on commit f014902

Please sign in to comment.