From ba6e5581aa8669c0fe5a12d069ccc2664979c149 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 1 Jan 2024 16:55:22 +0100 Subject: [PATCH] define `argmax(f, domain)` for julia < v1.7 --- src/auxiliary/linalg.jl | 16 ++++++++++++---- test/ad.jl | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index bc57d4c6..815f9900 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -56,6 +56,14 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare using ..TensorKit: OrthogonalFactorizationAlgorithm, QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar +# only defined in >v1.7 +@static if VERSION < v"1.7-" + _rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) + _argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] +else + _argmax(f, domain) = argmax(f, domain) +end + # TODO: define for CuMatrix if we support this function one!(A::DenseMatrix) Threads.@threads for j in 1:size(A, 2) @@ -273,12 +281,12 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where { while j <= n if DI[j] == 0 vr = view(VR, :, j) - s = conj(sign(argmax(abs, vr))) + s = conj(sign(_argmax(abs, vr))) V[:, j] .= s .* vr else vr = view(VR, :, j) vi = view(VR, :, j + 1) - s = conj(sign(argmax(abs, vr))) # vectors coming from lapack have already real absmax component + s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component V[:, j] .= s .* (vr .+ im .* vi) V[:, j + 1] .= s .* (vr .- im .* vi) j += 1 @@ -296,7 +304,7 @@ function eig!(A::StridedMatrix{T}; permute::Bool=true, A)[[2, 4]] for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V @@ -308,7 +316,7 @@ function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat} D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0) for j in 1:n v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmax(abs, v))) v .*= s end return D, V diff --git a/test/ad.jl b/test/ad.jl index 3b82c0b4..2fec7f93 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -312,7 +312,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) - c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) + c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c))) ΔU = TensorMap(randn, scalartype(U), space(U)) ΔS = TensorMap(randn, scalartype(S), space(S))