From 76078de3a919ac1491d7f4578652ab4dfc5b97cb Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 28 Jun 2024 11:13:47 +0200 Subject: [PATCH] Restrict convert rrule to trivialtensormap --- ext/TensorKitChainRulesCoreExt/constructors.jl | 4 +++- test/ad.jl | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index 49d3f7ba..d63c35aa 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -17,8 +17,10 @@ function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) return copy(t), copy_pullback end +# this rule does not work for generic symmetries, as we currently have no way to +# project back onto the symmetric subspace function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array}, - t::AbstractTensorMap) + t::TrivialTensorMap) A = convert(T, t) function convert_pullback(ΔA) ∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t)) diff --git a/test/ad.jl b/test/ad.jl index b4679972..d2d67b3b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -143,8 +143,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(copy, T1) test_rrule(copy, T2) - test_rrule(convert, Array, T1) - test_rrule(TensorMap, convert(Array, T1), space(T1)) + T1 isa TrivialTensorMap && test_rrule(convert, Array, T1) + T2 isa TrivialTensorMap && test_rrule(TensorMap, convert(Array, T1), space(T1)) + # TODO: can we make these methods/tests work for generic symmetries? + # the main problem here is finitedifferencing generates non-symmetric entries end @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)