From 4476595848d7d010a3252417090d52b42a6f5784 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 20 Nov 2023 17:42:59 +0100 Subject: [PATCH] =?UTF-8?q?Add=20rrule=20=E2=8A=97=20fixes=20#88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ext/TensorKitChainRulesCoreExt.jl | 32 +++++++++++++++++++++++++++++++ test/ad.jl | 4 ++++ 2 files changed, 36 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 01e05545..5b7633f5 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -99,6 +99,38 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) return a * b, times_pullback end +function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap) + C = A ⊗ B + projectA = ProjectTo(A) + projectB = ProjectTo(B) + function otimes_pullback(ΔC_) + ΔC = unthunk(ΔC_) + pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), + ((codomainind(B) .+ numout(A))..., + (domainind(B) .+ (numin(A) + numout(A)))...)) + dA_ = @thunk begin + ipA = (codomainind(A), domainind(A)) + pB = (allind(B), ()) + dA = zerovector(A, + TensorOperations.promote_contract(scalartype(ΔC), + scalartype(B))) + dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C) + return projectA(dA) + end + dB_ = @thunk begin + ipB = (codomainind(B), domainind(B)) + pA = ((), allind(A)) + dB = zerovector(B, + TensorOperations.promote_contract(scalartype(ΔC), + scalartype(A))) + dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N) + return projectB(dB) + end + return NoTangent(), dA_, dB_ + end + return C, otimes_pullback +end + function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool=false) function permute_pullback(Δtdst) diff --git a/test/ad.jl b/test/ad.jl index 8667364f..929c1b8d 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -127,6 +127,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(*, A, C) test_rrule(permute, A, ((1, 3, 2), (5, 4))) + + D = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3]) + E = TensorMap(randn, T, V[4] ← V[5]) + test_rrule(⊗, D, E) end @testset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)