Skip to content

Commit

Permalink
refactor eig implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 30, 2023
1 parent 496bc25 commit c1b5109
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 13 deletions.
39 changes: 38 additions & 1 deletion src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const SVDAlg = Union{SVD,SDD}
module MatrixAlgebra

using LinearAlgebra
using LinearAlgebra: BlasFloat, checksquare
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare

using ..TensorKit: OrthogonalFactorizationAlgorithm,
QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar
Expand Down Expand Up @@ -261,6 +261,43 @@ function svd!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD})
return U, S, V
end

function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where T<:BlasReal
n = checksquare(A)
n == 0 && return zeros(Complex{T}, 0), zeros(Complex{T}, 0, 0)

A, DR, DI, VL, VR, _ = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N', A)
D = complex.(DR, DI)
V = zeros(Complex{T}, n, n)
j = 1
while j <= n
if DI[j] == 0
V[:,j] = view(VR, :, j)
else
for i = 1:n
V[i,j] = VR[i,j] + im*VR[i,j+1]
V[i,j+1] = VR[i,j] - im*VR[i,j+1]
end
j += 1
end
j += 1
end
return D, V
end

function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where T<:BlasComplex
n = checksquare(A)
n == 0 && return zeros(T, 0), zeros(T, 0, 0)
D, V = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N', A)[[2,4]]
return D, V
end

function eigh!(A::StridedMatrix{T}) where T<:BlasFloat
n = checksquare(A)
n == 0 && return zeros(real(T), 0), zeros(T, 0, 0)
D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
return D, V
end

## Old stuff and experiments

# using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException,
Expand Down
20 changes: 8 additions & 12 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ decomposition is meaningless and cannot satisfy
permute(t, (leftind, rightind)) * V = V * D
```
Accepts the same keyword arguments `scale`, `permute` and `sortby` as `eigen` of dense
Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense
matrices. See the corresponding documentation for more information.
See also `eig` and `eigh`
Expand All @@ -185,7 +185,8 @@ decomposition is meaningless and cannot satisfy
permute(t, (leftind, rightind)) * V = V * D
```
Accepts the same keyword arguments `scale`, `permute` and `sortby` as `eigen` of dense matrices. See the corresponding documentation for more information.
Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense
matrices. See the corresponding documentation for more information.
See also `eigen` and `eigh`.
"""
Expand Down Expand Up @@ -506,7 +507,7 @@ end
#--------------------------#
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)

function eigh!(t::TensorMap; kwargs...)
function eigh!(t::TensorMap)
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eigh!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))
Expand All @@ -518,7 +519,7 @@ function eigh!(t::TensorMap; kwargs...)
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
for (c, b) in blocks(t)
values, vectors = eigen!(Hermitian(b); kwargs...)
values, vectors = MatrixAlgebra.eigh!(b)
d = length(values)
Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
Vdata[c] = vectors
Expand All @@ -533,7 +534,6 @@ function eigh!(t::TensorMap; kwargs...)
end

function eig!(t::TensorMap; kwargs...)
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eig!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))
S = spacetype(t)
Expand All @@ -544,14 +544,10 @@ function eig!(t::TensorMap; kwargs...)
Vdata = SectorDict{I,Ac}()
dims = SectorDict{I,Int}()
for (c, b) in blocks(t)
values, vectors = eigen!(b; kwargs...)
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
d = length(values)
Ddata[c] = copyto!(similar(values, T, (d, d)), Diagonal(values))
if eltype(vectors) == T
Vdata[c] = vectors
else
Vdata[c] = copyto!(similar(vectors, T), vectors)
end
Ddata[c] = copy!(similar(values, T, (d, d)), Diagonal(values))
Vdata[c] = vectors
dims[c] = d
end
if length(domain(t)) == 1
Expand Down

0 comments on commit c1b5109

Please sign in to comment.