Skip to content

Commit

Permalink
Refactor ChainRulesCoreExt into separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jun 27, 2024
1 parent 8b41fe3 commit fcad6a9
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 351 deletions.
28 changes: 28 additions & 0 deletions ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

ext = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
TensorOperations.TensorOperationsChainRulesCoreExt
end
const _conj = ext._conj
const trivtuple = ext.trivtuple

include("utility.jl")
include("constructors.jl")
include("linalg.jl")
include("tensoroperations.jl")
include("factorizations.jl")

end
47 changes: 47 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.id(args...)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...)
function TensorMap_pullback(Δt)
∂d = convert(Array, Δt)
return NoTangent(), ∂d, fill(NoTangent(), length(args))...

Check warning on line 10 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L7-L10

Added lines #L7 - L10 were not covered by tests
end
return TensorMap(d, args...), TensorMap_pullback

Check warning on line 12 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L12

Added line #L12 was not covered by tests
end

function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
copy_pullback(Δt) = NoTangent(), Δt
return copy(t), copy_pullback

Check warning on line 17 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L15-L17

Added lines #L15 - L17 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},

Check warning on line 20 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L20

Added line #L20 was not covered by tests
t::AbstractTensorMap)
A = convert(T, t)
function convert_pullback(ΔA)
∂t = TensorMap(ΔA, codomain(t), domain(t))
return NoTangent(), NoTangent(), ∂t

Check warning on line 25 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L22-L25

Added lines #L22 - L25 were not covered by tests
end
return A, convert_pullback

Check warning on line 27 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L27

Added line #L27 was not covered by tests
end

function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
out = convert(Dict, t)
function convert_pullback(c)
if haskey(c, :data) # :data is the only thing for which this dual makes sense
dual = copy(out)
dual[:data] = c[:data]
return (NoTangent(), NoTangent(), convert(TensorMap, dual))

Check warning on line 36 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L30-L36

Added lines #L30 - L36 were not covered by tests
else
# instead of zero(t) you can also return ZeroTangent(), which is type unstable
return (NoTangent(), NoTangent(), zero(t))

Check warning on line 39 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L39

Added line #L39 was not covered by tests
end
end
return out, convert_pullback

Check warning on line 42 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L42

Added line #L42 was not covered by tests
end
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},

Check warning on line 44 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L44

Added line #L44 was not covered by tests
t::Dict{Symbol,Any})
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))

Check warning on line 46 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L46

Added line #L46 was not covered by tests
end
Loading

0 comments on commit fcad6a9

Please sign in to comment.