Refactor ChainRulesCoreExt into separate files #133

Merged
merged 13 commits into from
Jul 2, 2024
4 changes: 3 additions & 1 deletion ext/TensorKitChainRulesCoreExt/constructors.jl
@@ -17,8 +17,10 @@
return copy(t), copy_pullback
end

+# this rule does not work for generic symmetries, as we currently have no way to
+# project back onto the symmetric subspace
lkdvos marked this conversation as resolved.
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
-t::AbstractTensorMap)
+t::TrivialTensorMap)
A = convert(T, t)
function convert_pullback(ΔA)
∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t))
@@ -27,21 +29,21 @@
return A, convert_pullback
end

function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
out = convert(Dict, t)
function convert_pullback(c)
if haskey(c, :data) # :data is the only thing for which this dual makes sense
Owner

How about `dual = typeof(out)(:data => c[:data])`?

Collaborator Author

I don't think that works, because then the spaces are missing in the dictionary-to-`TensorMap` converter. Maybe the comment is a bit misleading: all fields in the dictionary are required, but Zygote tends to drop fields that do not contribute (i.e. codomain and domain). Thus, this uses the dictionary from the forward pass with the data from the backward pass.

Owner

Ok. I wanted to avoid copying the data from `out`, but probably `copy(out)` is a shallow copy that does not duplicate `out[:data]`?

Collaborator Author

Indeed, it should really only duplicate the pointer, which I think is acceptable. On top of that, I am not sure this rule is ever used or useful anyway. I think I copied this at some point from Maarten, but presumably that was also only introduced for testing purposes, as it also suffers from the weird interplay of "inner product on the parameters" for non-abelian symmetries.
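
For reference, a minimal sketch of the shallow-copy behaviour discussed in this thread, using only Base Julia. The dictionaries here are hypothetical stand-ins: the keys `:codomain` and `:domain` and their string values are placeholders, not the actual output of `convert(Dict, t)`, and `c` only imitates a cotangent from which the non-contributing fields have been dropped.

```julia
# Hypothetical stand-in for the dictionary built in the forward pass.
# Strings replace the real space objects purely for illustration.
out = Dict{Symbol,Any}(:codomain => "V1", :domain => "V2",
                       :data => [1.0, 2.0, 3.0])

# Hypothetical cotangent as an AD engine might deliver it: only :data is present.
c = Dict{Symbol,Any}(:data => [0.1, 0.2, 0.3])

# `copy` on a Dict is shallow: a new Dict whose values are the same objects.
dual = copy(out)
shares_data = dual[:data] === out[:data]   # same array, nothing duplicated

# Rebinding the key in `dual` does not touch `out`.
dual[:data] = c[:data]
untouched = out[:data] == [1.0, 2.0, 3.0]

# `dual` keeps every field from the forward pass, with the data from the backward pass.
complete = all(haskey(dual, k) for k in (:codomain, :domain, :data))
```

Under these assumptions the only allocation is the new `Dict` itself; the array behind `out[:data]` is never copied, and `dual` still carries the fields that a dictionary holding only `:data` would lack.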

dual = copy(out)
dual[:data] = c[:data]
return (NoTangent(), NoTangent(), convert(TensorMap, dual))

Codecov (codecov/patch) warning: added lines #L32-L38 in ext/TensorKitChainRulesCoreExt/constructors.jl were not covered by tests.
else
# instead of zero(t) you can also return ZeroTangent(), which is type unstable
return (NoTangent(), NoTangent(), zero(t))

Codecov (codecov/patch) warning: added line #L41 in ext/TensorKitChainRulesCoreExt/constructors.jl was not covered by tests.
end
end
return out, convert_pullback

Codecov (codecov/patch) warning: added line #L44 in ext/TensorKitChainRulesCoreExt/constructors.jl was not covered by tests.
end
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},

Codecov (codecov/patch) warning: added line #L46 in ext/TensorKitChainRulesCoreExt/constructors.jl was not covered by tests.
t::Dict{Symbol,Any})
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))

Codecov (codecov/patch) warning: added line #L48 in ext/TensorKitChainRulesCoreExt/constructors.jl was not covered by tests.
end
6 changes: 4 additions & 2 deletions test/ad.jl
@@ -143,8 +143,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(copy, T1)
test_rrule(copy, T2)

-test_rrule(convert, Array, T1)
-test_rrule(TensorMap, convert(Array, T1), space(T1))
+T1 isa TrivialTensorMap && test_rrule(convert, Array, T1)
+T2 isa TrivialTensorMap && test_rrule(TensorMap, convert(Array, T1), space(T1))
+# TODO: can we make these methods/tests work for generic symmetries?
+# the main problem here is finitedifferencing generates non-symmetric entries
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64)