diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index b36976fd..0d1d0825 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -541,7 +541,7 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, M = similar(R, (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, +1) + ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M)) @@ -595,7 +595,7 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, M = similar(L, (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1) - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, +1) + ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M))