Skip to content

Commit

Permalink
Restrict convert rrule to trivialtensormap
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jun 28, 2024
1 parent c15b7ab commit 76078de
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
return copy(t), copy_pullback
end

# this rule does not work for generic symmetries, as we currently have no way to
# project back onto the symmetric subspace
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
t::AbstractTensorMap)
t::TrivialTensorMap)
A = convert(T, t)
function convert_pullback(ΔA)
∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t))
Expand Down
6 changes: 4 additions & 2 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(copy, T1)
test_rrule(copy, T2)

test_rrule(convert, Array, T1)
test_rrule(TensorMap, convert(Array, T1), space(T1))
T1 isa TrivialTensorMap && test_rrule(convert, Array, T1)
T2 isa TrivialTensorMap && test_rrule(TensorMap, convert(Array, T1), space(T1))
# TODO: can we make these methods/tests work for generic symmetries?
# the main problem here is finitedifferencing generates non-symmetric entries
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)
Expand Down

0 comments on commit 76078de

Please sign in to comment.