Skip to content

Commit

Permalink
define argmax(f, domain) for julia < v1.7
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jan 1, 2024
1 parent 3413026 commit ba6e558
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare
using ..TensorKit: OrthogonalFactorizationAlgorithm,
QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar

# only defined in >v1.7
@static if VERSION < v"1.7-"
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)
_argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]
else
_argmax(f, domain) = argmax(f, domain)
end

# TODO: define for CuMatrix if we support this
function one!(A::DenseMatrix)
Threads.@threads for j in 1:size(A, 2)
Expand Down Expand Up @@ -273,12 +281,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where {
while j <= n
if DI[j] == 0
vr = view(VR, :, j)
s = conj(sign(argmax(abs, vr)))
s = conj(sign(_argmax(abs, vr)))
V[:, j] .= s .* vr
else
vr = view(VR, :, j)
vi = view(VR, :, j + 1)
s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component
s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component
V[:, j] .= s .* (vr .+ im .* vi)
V[:, j + 1] .= s .* (vr .- im .* vi)
j += 1
Expand All @@ -296,7 +304,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true,
A)[[2, 4]]
for j in 1:n
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmax(abs, v)))
v .*= s
end
return D, V
Expand All @@ -308,7 +316,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat}
D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
for j in 1:n
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmax(abs, v)))
v .*= s
end
return D, V
Expand Down
2 changes: 1 addition & 1 deletion test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))

c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
Expand Down

0 comments on commit ba6e558

Please sign in to comment.