Skip to content

Commit

Permalink
Add and test rrules for real and imag (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos authored Dec 6, 2024
1 parent 6387f26 commit 40e74d7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
16 changes: 16 additions & 0 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,19 @@ function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
end
return n, norm_pullback
end

function ChainRulesCore.rrule(::typeof(real), a::AbstractTensorMap)
a_real = real(a)
real_pullback(Δa) = NoTangent(), eltype(a) <: Real ? Δa : complex(unthunk(Δa))
return a_real, real_pullback
end

function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
a_imag = imag(a)
function imag_pullback(Δa)
Δa′ = unthunk(Δa)
return NoTangent(),
eltype(a) <: Real ? ZeroTangent() : complex(zerovector(Δa′), Δa′)
end
return a_imag, imag_pullback
end
3 changes: 3 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
A = randn(T, V[1] V[2] V[3] V[4] V[5])
B = randn(T, space(A))

test_rrule(real, A)
test_rrule(imag, A)

test_rrule(+, A, B)
test_rrule(-, A)
test_rrule(-, A, B)
Expand Down

0 comments on commit 40e74d7

Please sign in to comment.