Skip to content

Commit

Permalink
define argmax(f, domain) for julia < v1.7
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jan 1, 2024
1 parent 3413026 commit a028f2f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ function _kron(A, B)
end
return C
end

# only defined in >v1.7
@static if VERSION < v"1.7-"
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)
_argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]
else
_argmax = Base.argmax
end
8 changes: 4 additions & 4 deletions src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where {
while j <= n
if DI[j] == 0
vr = view(VR, :, j)
s = conj(sign(argmax(abs, vr)))
s = conj(sign(_argmax(abs, vr)))
V[:, j] .= s .* vr
else
vr = view(VR, :, j)
vi = view(VR, :, j + 1)
s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component
s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component
V[:, j] .= s .* (vr .+ im .* vi)
V[:, j + 1] .= s .* (vr .- im .* vi)
j += 1
Expand All @@ -296,7 +296,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true,
A)[[2, 4]]
for j in 1:n
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmax(abs, v)))
v .*= s
end
return D, V
Expand All @@ -308,7 +308,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat}
D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
for j in 1:n
v = view(V, :, j)
s = conj(sign(argmax(abs, v)))
s = conj(sign(_argmax(abs, v)))
v .*= s
end
return D, V
Expand Down
2 changes: 1 addition & 1 deletion test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))

c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
c, = _argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
Expand Down

0 comments on commit a028f2f

Please sign in to comment.