Skip to content

Commit

Permalink
Changes for TensorOperations 4.0.6
Browse files Browse the repository at this point in the history
tensorscalar now has a `rrule`
  • Loading branch information
lkdvos committed Sep 29, 2023
1 parent 67a4453 commit 31b6523
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TensorKitChainRulesCoreExt = "ChainRulesCore"
HalfIntegers = "1"
LRUCache = "1.0.2"
Strided = "2"
TensorOperations = "4.0.5"
TensorOperations = "4.0.6"
TupleTools = "1.1"
VectorInterface = "0.4"
WignerSymbols = "1,2"
Expand Down
5 changes: 0 additions & 5 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,6 @@ function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Ind
return permute(tsrc, p), permute_pullback
end

function ChainRulesCore.rrule(::typeof(scalar), t::AbstractTensorMap)
scalar_pullback(Δc) = NoTangent(), fill!(similar(t), unthunk(Δc))
return scalar(t), scalar_pullback
end

# LinearAlgebra
# -------------

Expand Down
28 changes: 17 additions & 11 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ function _randomize!(a::TensorMap)
end

# Float32 and finite differences don't mix well
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1e-2
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1e-8
precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2
precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-8

# rrules for functions that destroy inputs
# ----------------------------------------
Expand Down Expand Up @@ -162,36 +162,37 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(LinearAlgebra.adjoint, A)
test_rrule(LinearAlgebra.norm, A, 2)
end

@testset "TensorOperations ($T)" for T in (Float64, ComplexF64)
atol = precision(T)
rtol = precision(T)

@testset "tensortrace!" begin
A = TensorMap(randn, T, V[1] V[2] V[3] V[1] V[5])
pC = ((3, 5), (2,))
pA = ((1,), (4,))
α = randn(T)
β = randn(T)

C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false))
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol)

C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false))
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol)
end

@testset "tensoradd!" begin
p = ((1, 3, 2), (5, 4))
A = TensorMap(randn, T, V[1] V[2] V[3] V[4] V[5])
C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false))
α = randn(T)
β = randn(T)
test_rrule(tensoradd!, C, p, A, :N, α, β; atol, rtol)

C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false))
test_rrule(tensoradd!, C, p, A, :C, α, β; atol, rtol)
end

@testset "tensorcontract!" begin
A = TensorMap(randn, T, V[1] V[2] V[3] V[4] V[5])
B = TensorMap(randn, T, V[3] V[1]' V[2])
Expand All @@ -200,11 +201,11 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
pB = ((2, 1), (3,))
α = randn(T)
β = randn(T)
C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N,

C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N,
B, pB, :N, false))
test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β; atol, rtol)

A2 = TensorMap(randn, T, V[1]' V[2]' V[3]' V[4]' V[5]')
C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C,
B, pB, :N, false))
Expand All @@ -219,6 +220,11 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
B2, pB, :C, false))
test_rrule(tensorcontract!, C, pC, A2, pA, :C, B2, pB, :C, α, β; atol, rtol)
end

@testset "tensorscalar" begin
A = Tensor(randn, T, ProductSpace{typeof(V[1]),0}())
test_rrule(tensorscalar, A)
end
end

@testset "Factorizations ($T)" for T in (Float64, ComplexF64)
Expand Down

0 comments on commit 31b6523

Please sign in to comment.