From 7ee775bb077d9418dc1988d5b873f1bf5a4e20e1 Mon Sep 17 00:00:00 2001 From: tangwei94 <34451674+tangwei94@users.noreply.github.com> Date: Mon, 13 May 2024 01:17:38 +0200 Subject: [PATCH] small fix in backward rule for leftorth, rightorth (#123) * small fix in backward rule for leftorth, rightorth * Apply suggestions from code review Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com> --------- Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com> --- ext/TensorKitChainRulesCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index b36976fd..0d1d0825 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -541,7 +541,7 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, M = similar(R, (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, +1) + ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M)) @@ -595,7 +595,7 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, M = similar(L, (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1) - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, +1) + ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M))