From 225a423404f985f8d1d3f97431101ec69a13c68c Mon Sep 17 00:00:00 2001 From: tangwei94 <34451674+tangwei94@users.noreply.github.com> Date: Sun, 12 May 2024 22:35:02 +0200 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com> --- ext/TensorKitChainRulesCoreExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 3a3c354b..0d1d0825 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -539,9 +539,9 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔR1 = view(ΔR, 1:p, :) ΔR11 = view(ΔR, 1:p, 1:p) - M = zeros(eltype(R), (p, p)) + M = similar(R, (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, +1) + ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M)) @@ -593,9 +593,9 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL1 = view(ΔL, :, 1:p) ΔR11 = view(ΔL, 1:p, 1:p) - M = zeros(eltype(L), (p, p)) + M = similar(L, (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1) - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, +1) + ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = view(M, diagind(M))