diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 47036534..6ff4b791 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -173,7 +173,7 @@ function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) p == 2 || error("currently only implemented for p = 2") n = norm(a, p) function norm_pullback(Δn) - return NoTangent(), a * (Δn' + Δn) / (n * 2 + eps(real(eltype(a)))), NoTangent() + return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() end return n, norm_pullback end