From 3d56e50a617703cff10ca834ccdf56361a5a68a1 Mon Sep 17 00:00:00 2001 From: Wei Tang Date: Sun, 12 May 2024 19:23:27 +0200 Subject: [PATCH 1/2] small fix in backward rule for leftorth, rightorth --- ext/TensorKitChainRulesCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index b36976fd..3a3c354b 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -539,7 +539,7 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔR1 = view(ΔR, 1:p, :) ΔR11 = view(ΔR, 1:p, 1:p) - M = similar(R, (p, p)) + M = zeros(eltype(R), (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, +1) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) @@ -593,7 +593,7 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL1 = view(ΔL, :, 1:p) ΔR11 = view(ΔL, 1:p, 1:p) - M = similar(L, (p, p)) + M = zeros(eltype(L), (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1) ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, +1) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) From 225a423404f985f8d1d3f97431101ec69a13c68c Mon Sep 17 00:00:00 2001 From: tangwei94 <34451674+tangwei94@users.noreply.github.com> Date: Sun, 12 May 2024 22:35:02 +0200 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com> --- ext/TensorKitChainRulesCoreExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 3a3c354b..0d1d0825 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -539,9 +539,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔR1 = view(ΔR, 1:p, :) ΔR11 = view(ΔR, 1:p, 1:p) - M = zeros(eltype(R), (p, p)) + M = similar(R, (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, +1) + ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M)) @@ -593,9 +593,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL1 = view(ΔL, :, 1:p) ΔR11 = view(ΔL, 1:p, 1:p) - M = zeros(eltype(L), (p, p)) + M = similar(L, (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1) - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, +1) + ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M))