From f33e3c9fd044301472063a580994f5300a3a1520 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 8 May 2024 08:34:29 +0200 Subject: [PATCH] Improve warnings in AD rules --- ext/TensorKitChainRulesCoreExt.jl | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index b36976fd..8c8ed7e6 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -385,14 +385,17 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector # check whether cotangents arise from gauge-invariance objective function mask = abs.(Sp' .- Sp) .< tol - gaugepart = view(aUΔU, mask) + view(aVΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + gaugepart = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) + gaugepart < tol || + @warn "`svd` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)" if p > r rprange = (r + 1):p - norm(view(aUΔU, rprange, rprange), Inf) < tol || - @warn "cotangents sensitive to gauge choice" - norm(view(aVΔV, rprange, rprange), Inf) < tol || - @warn "cotangents sensitive to gauge choice" + aUΔUpart_norm = norm(view(aUΔU, rprange, rprange), Inf) + aUΔUpart_norm < tol || + @warn "`svd` cotangents sensitive to gauge choice: (|aUΔU| = $aUΔUpart_norm)" + aVΔVpart_norm = norm(view(aVΔV, rprange, rprange), Inf) + aVΔVpart_norm < tol || + @warn "`svd` cotangents sensitive to gauge choice: (|aVΔV| = $aVΔVpart_norm)" end UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+ @@ -461,8 +464,9 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix VdΔV = V' * ΔV mask = abs.(transpose(D) .- D) .< tol - gaugepart = view(VdΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + gaugepart = norm(view(VdΔV, mask), Inf) + gaugepart < tol || + @warn "`eig` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)" VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol)) @@ -504,8 +508,9 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2) mask = abs.(D' .- D) .< tol - gaugepart = view(aVdΔV, mask) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + gaugepart = norm(view(aVdΔV, mask)) + gaugepart < tol || + @warn "`eigh` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)" aVdΔV .*= safe_inv.(D' .- D, tol) @@ -567,8 +572,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, Q2 = view(Q, :, (p + 1):m) ΔQ2 = view(ΔQ, :, (p + 1):m) Q1dΔQ2 = Q1' * ΔQ2 - gaugepart = mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + gaugepart = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + gaugepart < tol || + @warn "`qr` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)" mul!(ΔA1, Q2, Q1dΔQ2', -1, 1) end rdiv!(ΔA1, UpperTriangular(R11)') @@ -621,8 +627,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, Q2 = view(Q, (p + 1):n, :) ΔQ2 = view(ΔQ, (p + 1):n, :) ΔQ2Q1d = ΔQ2 * Q1' - gaugepart = mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1) - norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice" + gaugepart = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)) + gaugepart < tol || + @warn "`lq` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)" mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1) end ldiv!(LowerTriangular(L11)', ΔA1)