From a028f2fd50e0f7eed63f2745d889f8677add47a3 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 1 Jan 2024 16:55:22 +0100 Subject: [PATCH] define `argmax(f, domain)` for julia < v1.7 --- src/auxiliary/auxiliary.jl | 8 ++++++++ src/auxiliary/linalg.jl | 8 ++++---- test/ad.jl | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/auxiliary/auxiliary.jl b/src/auxiliary/auxiliary.jl index f407f968..a00800a9 100644 --- a/src/auxiliary/auxiliary.jl +++ b/src/auxiliary/auxiliary.jl @@ -40,3 +40,11 @@ function _kron(A, B) end return C end + +# only defined in >v1.7 +@static if VERSION < v"1.7-" + _rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) + _argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] +else + _argmax = Base.argmax +end \ No newline at end of file diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index bc57d4c6..02cf503d 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -273,12 +273,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where { while j <= n if DI[j] == 0 vr = view(VR, :, j) - s = conj(sign(argmax(abs, vr))) + s = conj(sign(_argmax(abs, vr))) V[:, j] .= s .* vr else vr = view(VR, :, j) vi = view(VR, :, j + 1) - s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component + s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component V[:, j] .= s .* (vr .+ im .* vi) V[:, j + 1] .= s .* (vr .- im .* vi) j += 1 @@ -296,7 +296,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, A)[[2, 4]] for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V @@ -308,7 +308,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat} D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0) for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V diff --git a/test/ad.jl b/test/ad.jl index 3b82c0b4..878031e5 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -312,7 +312,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) - c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) + c, = _argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c))) ΔU = TensorMap(randn, scalartype(U), space(U)) ΔS = TensorMap(randn, scalartype(S), space(S))