Skip to content

Commit

Permalink
Fix missing kwarg in rrule of permute`
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 17, 2023
1 parent 484a323 commit e912eae
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
return a * b, times_pullback
end

function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple)
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
copy::Bool=false)
function permute_pullback(Δtdst)
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc)
return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent()
return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent()
end
return permute(tsrc, p), permute_pullback
return permute(tsrc, p; copy=true), permute_pullback
end

# LinearAlgebra
Expand Down

0 comments on commit e912eae

Please sign in to comment.