diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 514d4749..6ff4b791 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -172,7 +172,9 @@ end function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) p == 2 || error("currently only implemented for p = 2") n = norm(a, p) - norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent() + function norm_pullback(Δn) + return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() + end return n, norm_pullback end