From c1b5109d94e2dc075f2f620d0991384d97fb3fda Mon Sep 17 00:00:00 2001 From: Jutho Date: Sat, 25 Nov 2023 00:52:53 +0100 Subject: [PATCH] refactor eig implementation --- src/auxiliary/linalg.jl | 39 ++++++++++++++++++++++++++++++++++- src/tensors/factorizations.jl | 20 +++++++----------- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 971da973..ea04d17d 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -51,7 +51,7 @@ const SVDAlg = Union{SVD,SDD} module MatrixAlgebra using LinearAlgebra -using LinearAlgebra: BlasFloat, checksquare +using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare using ..TensorKit: OrthogonalFactorizationAlgorithm, QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar @@ -261,6 +261,43 @@ function svd!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}) return U, S, V end +function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where T<:BlasReal + n = checksquare(A) + n == 0 && return zeros(Complex{T}, 0), zeros(Complex{T}, 0, 0) + + A, DR, DI, VL, VR, _ = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N', A) + D = complex.(DR, DI) + V = zeros(Complex{T}, n, n) + j = 1 + while j <= n + if DI[j] == 0 + V[:,j] = view(VR, :, j) + else + for i = 1:n + V[i,j] = VR[i,j] + im*VR[i,j+1] + V[i,j+1] = VR[i,j] - im*VR[i,j+1] + end + j += 1 + end + j += 1 + end + return D, V +end + +function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where T<:BlasComplex + n = checksquare(A) + n == 0 && return zeros(T, 0), zeros(T, 0, 0) + D, V = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N', A)[[2,4]] + return D, V +end + +function eigh!(A::StridedMatrix{T}) where T<:BlasFloat + n = checksquare(A) + n == 0 && return zeros(real(T), 0), zeros(T, 0, 0) + D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0) + return D, V +end + ## Old stuff and experiments # using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException, diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index f2f4a76d..03bc20a7 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -158,7 +158,7 @@ decomposition is meaningless and cannot satisfy permute(t, (leftind, rightind)) * V = V * D ``` -Accepts the same keyword arguments `scale`, `permute` and `sortby` as `eigen` of dense +Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense matrices. See the corresponding documentation for more information. See also `eig` and `eigh` @@ -185,7 +185,8 @@ decomposition is meaningless and cannot satisfy permute(t, (leftind, rightind)) * V = V * D ``` -Accepts the same keyword arguments `scale`, `permute` and `sortby` as `eigen` of dense matrices. See the corresponding documentation for more information. +Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense +matrices. See the corresponding documentation for more information. See also `eigen` and `eigh`. """ @@ -506,7 +507,7 @@ end #--------------------------# LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t) -function eigh!(t::TensorMap; kwargs...) +function eigh!(t::TensorMap) InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eigh!) domain(t) == codomain(t) || throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) @@ -518,7 +519,7 @@ function eigh!(t::TensorMap; kwargs...) Vdata = SectorDict{I,A}() dims = SectorDict{I,Int}() for (c, b) in blocks(t) - values, vectors = eigen!(Hermitian(b); kwargs...) + values, vectors = MatrixAlgebra.eigh!(b) d = length(values) Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values)) Vdata[c] = vectors @@ -533,7 +534,6 @@ function eigh!(t::TensorMap; kwargs...) end function eig!(t::TensorMap; kwargs...) - InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:eig!) domain(t) == codomain(t) || throw(SpaceMismatch("`eig!` requires domain and codomain to be the same")) S = spacetype(t) @@ -544,14 +544,10 @@ function eig!(t::TensorMap; kwargs...) Vdata = SectorDict{I,Ac}() dims = SectorDict{I,Int}() for (c, b) in blocks(t) - values, vectors = eigen!(b; kwargs...) + values, vectors = MatrixAlgebra.eig!(b; kwargs...) d = length(values) - Ddata[c] = copyto!(similar(values, T, (d, d)), Diagonal(values)) - if eltype(vectors) == T - Vdata[c] = vectors - else - Vdata[c] = copyto!(similar(vectors, T), vectors) - end + Ddata[c] = copy!(similar(values, T, (d, d)), Diagonal(values)) + Vdata[c] = vectors dims[c] = d end if length(domain(t)) == 1