Skip to content

Commit

Permalink
Improve warnings in AD rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed May 8, 2024
1 parent 370dd92 commit f33e3c9
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,14 +385,17 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(Sp' .- Sp) .< tol
gaugepart = view(aUΔU, mask) + view(aVΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
gaugepart = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
gaugepart < tol ||
@warn "`svd` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)"
if p > r
rprange = (r + 1):p
norm(view(aUΔU, rprange, rprange), Inf) < tol ||
@warn "cotangents sensitive to gauge choice"
norm(view(aVΔV, rprange, rprange), Inf) < tol ||
@warn "cotangents sensitive to gauge choice"
aUΔUpart_norm = norm(view(aUΔU, rprange, rprange), Inf)
aUΔUpart_norm < tol ||

Check warning on line 394 in ext/TensorKitChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt.jl#L393-L394

Added lines #L393 - L394 were not covered by tests
@warn "`svd` cotangents sensitive to gauge choice: (|aUΔU| = $aUΔUpart_norm)"
aVΔVpart_norm = norm(view(aVΔV, rprange, rprange), Inf)
aVΔVpart_norm < tol ||

Check warning on line 397 in ext/TensorKitChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt.jl#L396-L397

Added lines #L396 - L397 were not covered by tests
@warn "`svd` cotangents sensitive to gauge choice: (|aVΔV| = $aVΔVpart_norm)"
end

UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
Expand Down Expand Up @@ -461,8 +464,9 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix
VdΔV = V' * ΔV

mask = abs.(transpose(D) .- D) .< tol
gaugepart = view(VdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
gaugepart = norm(view(VdΔV, mask), Inf)
gaugepart < tol ||
@warn "`eig` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)"

VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol))

Expand Down Expand Up @@ -504,8 +508,9 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)

mask = abs.(D' .- D) .< tol
gaugepart = view(aVdΔV, mask)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
gaugepart = norm(view(aVdΔV, mask))
gaugepart < tol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)"

aVdΔV .*= safe_inv.(D' .- D, tol)

Expand Down Expand Up @@ -567,8 +572,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
Q2 = view(Q, :, (p + 1):m)
ΔQ2 = view(ΔQ, :, (p + 1):m)
Q1dΔQ2 = Q1' * ΔQ2
gaugepart = mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
gaugepart = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
gaugepart < tol ||

Check warning on line 576 in ext/TensorKitChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt.jl#L575-L576

Added lines #L575 - L576 were not covered by tests
@warn "`qr` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)"
mul!(ΔA1, Q2, Q1dΔQ2', -1, 1)
end
rdiv!(ΔA1, UpperTriangular(R11)')
Expand Down Expand Up @@ -621,8 +627,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
Q2 = view(Q, (p + 1):n, :)
ΔQ2 = view(ΔQ, (p + 1):n, :)
ΔQ2Q1d = ΔQ2 * Q1'
gaugepart = mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)
norm(gaugepart, Inf) < tol || @warn "cotangents sensitive to gauge choice"
gaugepart = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1))
gaugepart < tol ||

Check warning on line 631 in ext/TensorKitChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt.jl#L630-L631

Added lines #L630 - L631 were not covered by tests
@warn "`lq` cotangents sensitive to gauge choice: (|ΔGauge| = $gaugepart)"
mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1)
end
ldiv!(LowerTriangular(L11)', ΔA1)
Expand Down

0 comments on commit f33e3c9

Please sign in to comment.