Skip to content

Commit

Permalink
and one more bug in svd rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 17, 2023
1 parent 9b08717 commit 35a582f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function svd_pullback(U::AbstractMatrix, S::AbstractVector, Vd::AbstractMatrix,
ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
end
else
VrΔV = fill!(similar(V, (r - p, p)), 0)
VrΔV = fill!(similar(Vd, (r - p, p)), 0)
end

X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
Expand Down
4 changes: 2 additions & 2 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),

allS = mapreduce(x -> diag(x[2]), vcat, blocks(S))
truncval = (maximum(allS) + minimum(allS)) / 2
U, S, V, ϵ = tsvd(A; trunc=truncbelow(truncval))
U, S, V, ϵ = tsvd(A; trunc=truncerr(truncval))
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
ΔV = TensorMap(randn, scalartype(V), space(V))
Expand All @@ -246,7 +246,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
end
end
test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncbelow(truncval)))
fkwargs=(; trunc=truncerr(truncval)))
end

let (U, S, V, ϵ) = tsvd(B)
Expand Down

0 comments on commit 35a582f

Please sign in to comment.