diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index bf7c5526..a5342483 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -178,7 +178,7 @@ function svd_rev(U::AbstractMatrix, S::AbstractMatrix, V::AbstractMatrix, ΔU, F = _invert_S²(S, tol) S⁻¹ = pinv(S; atol=tol) - term = Diagonal(diag(ΔS)) + term = ΔS isa ZeroTangent ? ΔS : Diagonal(diag(ΔS)) J = F .* (U' * ΔU) term += (J + J') * S