Skip to content

Commit

Permalink
Improve tensortrace rrule tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed May 29, 2024
1 parent 64a6f40 commit c3dc374
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ end
precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2
precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-6

function randindextuple(N::Int)
k = rand(0:N)
function randindextuple(N::Int, k::Int=rand(0:N))
@assert 0 k N
_p = randperm(N)
return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
end

# rrules for functions that destroy inputs
# ----------------------------------------
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
Expand Down Expand Up @@ -174,17 +175,25 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
rtol = precision(T)

@testset "tensortrace!" begin
A = TensorMap(randn, T, V[1] V[2] V[3] V[1] V[5])
pC = ((3, 5), (2,))
pA = ((1,), (4,))
α = randn(T)
β = randn(T)
for _ in 1:5
k1 = rand(1:3)
k2 = k1 == 3 ? 1 : rand(1:2)
V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1))
V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2))

C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false))
test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol)
(_p, _q) = randindextuple(k1 + 2 * k2, k1)
p = _repartition(_p, rand(0:k1))
q = _repartition(_q, k2)
ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2)))
A = TensorMap(randn, T, permute(prod(V1) prod(V2) prod(V2), ip))

C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false))
test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol)
α = randn(T)
β = randn(T)
for conjA in (:N, :C)
C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, conjA, false))
test_rrule(tensortrace!, C, p, A, q, conjA, α, β; atol, rtol)
end
end
end

@testset "tensoradd!" begin
Expand Down

0 comments on commit c3dc374

Please sign in to comment.