From e912eae099fbc7313ef7f8e9fbd8f59b50dbf6d2 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 17 Oct 2023 19:48:05 +0200 Subject: [PATCH] Fix missing kwarg in rrule `of `permute` --- ext/TensorKitChainRulesCoreExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 6abd207f..bf7c5526 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -98,12 +98,13 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) return a * b, times_pullback end -function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple) +function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; + copy::Bool=false) function permute_pullback(Δtdst) invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) - return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent() + return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() end - return permute(tsrc, p), permute_pullback + return permute(tsrc, p; copy=true), permute_pullback end # LinearAlgebra