Skip to content

Commit

Permalink
Add rrule ⊗
Browse files Browse the repository at this point in the history
fixes #88
  • Loading branch information
lkdvos committed Nov 20, 2023
1 parent a17bcbe commit 452b1e3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
30 changes: 30 additions & 0 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,36 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
return a * b, times_pullback
end

function ChainRulesCore.rrule(::typeof(), A::AbstractTensorMap, B::AbstractTensorMap)
C = A B
projectA = ProjectTo(A)
projectB = ProjectTo(B)
function otimes_pullback(ΔC_)
ΔC = unthunk(ΔC_)
pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...),
((codomainind(B) .+ numout(A))...,
(domainind(B) .+ (numin(A) + numout(A)))...))
dA_ = @thunk begin
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
return projectA(dA)
end
dB_ = @thunk begin
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B, TensorOperations.promote_contract(scalartype(ΔC), scalartype(A)))
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
return projectB(dB)
end
return NoTangent(), dA_, dB_
end
return C, otimes_pullback
end

function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
copy::Bool=false)
function permute_pullback(Δtdst)
Expand Down
4 changes: 4 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(*, A, C)

test_rrule(permute, A, ((1, 3, 2), (5, 4)))

D = TensorMap(randn, T, V[1] V[2] V[3])
E = TensorMap(randn, T, V[4] V[5])
test_rrule(, D, E)
end

@testset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64)
Expand Down

0 comments on commit 452b1e3

Please sign in to comment.