From 25c06e54461edf265bac50186d8360f1ff804588 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 13 Sep 2023 16:44:12 +0200 Subject: [PATCH] Various AD bugfixes --- ext/TensorKitChainRulesCoreExt.jl | 20 ++++++++++---------- src/tensors/tensoroperations.jl | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 008468ac..920fcbca 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -80,16 +80,16 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) return a * b, times_pullback end -function ChainRulesCore.rrule(::typeof(permute), t::AbstractTensorMap, p::Index2Tuple) - function permute_pullback(c) - invpt = _repartition(TupleTools.invperm(linearize(p)), t) - return NoTangent(), permute(c, invpt), NoTangent() +function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple) + function permute_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent() end - return permute(t, p), permute_pullback + return permute(tsrc, p), permute_pullback end function ChainRulesCore.rrule(::typeof(scalar), t::AbstractTensorMap) - scalar_pullback(Δc) = NoTangent(), fill!(similar(t), Δc) + scalar_pullback(Δc) = NoTangent(), fill!(similar(t), unthunk(Δc)) return scalar(t), scalar_pullback end @@ -102,7 +102,7 @@ function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) end function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) - adjoint_pullback(Δadjoint) = NoTangent(), adjoint(Δadjoint) + adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) return adjoint(A), adjoint_pullback end @@ -124,7 +124,7 @@ end function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kwargs...) T = eltype(t) - U, S, V = tsvd(t; kwargs...) + U, S, V, ϵ = tsvd(t; kwargs...) F = similar(S) for (k, dst) in blocks(F) @@ -171,10 +171,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kw ∂t += U * pinv(S) * dV * (one(prv) - prv) end - return NoTangent(), ∂t, fill(NoTangent(), length(kwargs))... + return NoTangent(), ∂t end - return (U, S, V), tsvd_pullback + return (U, S, V, ϵ), tsvd_pullback end function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 824baa54..1e3dfd42 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -26,10 +26,10 @@ function _canonicalize(p::Index2Tuple{N₁,N₂}, ::AbstractTensorMap{<:IndexSpace,N₁,N₂}) where {N₁,N₂} return p end -function _canonicalize(p::Index2Tuple, t::AbstractTensorMap) - p′ = linearize(p) - p₁ = TupleTools.getindices(p′, codomainind(t)) - p₂ = TupleTools.getindices(p′, domainind(t)) +_canonicalize(p::Index2Tuple, t::AbstractTensorMap) = _canonicalize(linearize(p), t) +function _canonicalize(p::IndexTuple, t::AbstractTensorMap) + p₁ = TupleTools.getindices(p, codomainind(t)) + p₂ = TupleTools.getindices(p, domainind(t)) return (p₁, p₂) end