diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 27208c3a..61f4d849 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -1,17 +1,27 @@ module TensorKitChainRulesCoreExt using TensorOperations +using VectorInterface using TensorKit using ChainRulesCore using LinearAlgebra using TupleTools +import TensorOperations as TO +using TensorOperations: Backend, promote_contract +using VectorInterface: promote_scale, promote_add + +ext = @static if isdefined(Base, :get_extension) + Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) +else + TensorOperations.TensorOperationsChainRulesCoreExt +end +const _conj = ext._conj +const trivtuple = ext.trivtuple + # Utility # ------- -_conj(conjA::Symbol) = conjA == :C ? :N : :C -trivtuple(N) = ntuple(identity, N) - function _repartition(p::IndexTuple, N₁::Int) length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) @@ -104,6 +114,8 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe projectA = ProjectTo(A) projectB = ProjectTo(B) function otimes_pullback(ΔC_) + # TODO: this rule is probably better written in terms of inner products, + # using planarcontract and adjoint tensormaps would remove the twists. ΔC = unthunk(ΔC_) pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), ((codomainind(B) .+ numout(A))..., @@ -111,19 +123,17 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe dA_ = @thunk begin ipA = (codomainind(A), domainind(A)) pB = (allind(B), ()) - dA = zerovector(A, - TensorOperations.promote_contract(scalartype(ΔC), - scalartype(B))) - dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C) + dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) + tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) + dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C) return projectA(dA) end dB_ = @thunk begin ipB = (codomainind(B), domainind(B)) pA = ((), allind(A)) - dB = zerovector(B, - TensorOperations.promote_contract(scalartype(ΔC), - scalartype(A))) - dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N) + dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) + tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) + dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N) return projectB(dB) end return NoTangent(), dA_, dB_ @@ -653,4 +663,150 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) end +function ChainRulesCore.rrule(::typeof(TO.tensorcontract!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, + B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol, + α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...) + + projectA = ProjectTo(A) + projectB = ProjectTo(B) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + ipC = invperm(linearize(pC)) + pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))), + TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) + + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipA = (invperm(linearize(pA)), ()) + conjΔC = conjA == :C ? :C : :N + conjB′ = conjA == :C ? conjB : _conj(conjB) + _dA = zerovector(A, + promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + tB = twist(B, + TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]))) + _dA = tensorcontract!(_dA, ipA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + conjA == :C ? α : conj(α), Zero(), backend...) + return projectA(_dA) + end + dB = @thunk begin + ipB = (invperm(linearize(pB)), ()) + conjΔC = conjB == :C ? :C : :N + conjA′ = conjB == :C ? conjA : _conj(conjA) + _dB = zerovector(B, + promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + tA = twist(A, + TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]))) + _dB = tensorcontract!(_dB, ipB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + conjB == :C ? α : conj(α), Zero(), backend...) + return projectB(_dB) + end + dα = @thunk begin + # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB + AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB) + return projectα(inner(AB, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), + dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(TO.tensoradd!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, conjA::Symbol, + α::Number, β::Number, backend::Backend...) where {S} + C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm(linearize(pC)) + _dA = zerovector(A, promote_add(ΔC, α)) + _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(), + backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this is an inner product implemented as a contraction + # for non-symmetric tensors this might be more efficient like this, + # but for symmetric tensors an intermediate object will anyways be created + # and then it might be more efficient to use an addition and inner product + tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)), + _conj(conjA), tΔC, + (trivtuple(TO.numind(pC)), + ()), :N, One(), backend...)) + return projectα(_dα) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend... + end + + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, + pC::Index2Tuple, A::AbstractTensorMap{S}, + pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...)) + E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + _dA = zerovector(A, promote_scale(ΔC, α)) + _dA = tensorproduct!(_dA, (ipC, ()), ΔC, + (trivtuple(TO.numind(pC)), ()), conjA, E, + ((), trivtuple(TO.numind(pA))), conjA, + conjA == :N ? conj(α) : α, Zero(), backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = tensortrace(pC, A, pA, conjA) + return projectα(inner(At, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + + return C′, pullback +end + end diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index 43ad110d..14776038 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -87,8 +87,8 @@ end # Show #------ -function Base.summary(t::AdjointTensorMap) - return print("AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") +function Base.summary(io::IO, t::AdjointTensorMap) + return print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") end function Base.show(io::IO, t::AdjointTensorMap{S}) where {S<:IndexSpace} if get(io, :compact, false) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 0c202245..293d2705 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -226,8 +226,8 @@ end # Twist """ - twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) - -> t + twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t + twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t Apply a twist to the `i`th index of `t`, storing the result in `t`. If `inv=true`, use the inverse twist. @@ -248,17 +248,31 @@ function twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) end return t end +function twist!(t::AbstractTensorMap, is; inv::Bool=false) + if !all(in(allind(t)), is) + msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices." + throw(ArgumentError(msg)) + end + (BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t + N₁ = numout(t) + for (f₁, f₂) in fusiontrees(t) + θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is) + inv && (θ = θ') + rmul!(t[f₁, f₂], θ) + end + return t +end """ - twist(t::AbstractTensorMap, i::Int; inv::Bool=false) - -> t + twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst + twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst -Apply a twist to the `i`th index of `t` and return the result as a new tensor. +Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor. If `inv=true`, use the inverse twist. See [`twist!`](@ref) for storing the result in place. """ -twist(t::AbstractTensorMap, i::Int; inv::Bool=false) = twist!(copy(t), i; inv=inv) +twist(t::AbstractTensorMap, i; inv::Bool=false) = twist!(copy(t), i; inv) # Fusing and splitting # TODO: add functionality for easy fusing and splitting of tensor indices diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index d0538949..233b82dd 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -679,8 +679,8 @@ end # Show #------ -function Base.summary(t::TensorMap) - return print("TensorMap(", space(t), ")") +function Base.summary(io::IO, t::TensorMap) + return print(io, "TensorMap(", space(t), ")") end function Base.show(io::IO, t::TensorMap{S}) where {S<:IndexSpace} if get(io, :compact, false) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index a2b6f171..1d3c7f30 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -285,13 +285,7 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S}, end A′ = permute(A, (oindA, cindA); copy=copyA) B′ = permute(B, (cindB, oindB)) - if BraidingStyle(sectortype(S)) isa Fermionic - for i in domainind(A′) - if !isdual(space(A′, i)) - A′ = twist!(A′, i) - end - end - end + A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′))) ipC = TupleTools.invperm((p₁..., p₂...)) oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁)) oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂)) diff --git a/test/ad.jl b/test/ad.jl index d686a834..41906932 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -4,7 +4,13 @@ using Random using FiniteDifferences using LinearAlgebra -## Test utility +const _repartition = @static if isdefined(Base, :get_extension) + Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition +else + TensorKit.TensorKitChainRulesCoreExt._repartition +end + +# Test utility # ------------- function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap) return TensorMap(randn, scalartype(x), space(x)) @@ -20,35 +26,23 @@ function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap} return vec, x -> T(from_vec(x), codomain(t), domain(t)) end function FiniteDifferences.to_vec(t::AbstractTensorMap) - vec = mapreduce(vcat, blocks(t)) do (c, b) - if scalartype(t) <: Real - return reshape(b, :) .* sqrt(dim(c)) - else - v = reshape(b, :) .* sqrt(dim(c)) - return vcat(real(v), imag(v)) - end + vec = mapreduce(vcat, blocks(t); init=scalartype(t)[]) do (c, b) + return reshape(b, :) .* sqrt(dim(c)) end + vec_real = scalartype(t) <: Real ? vec : collect(reinterpret(real(scalartype(t)), vec)) - function from_vec(x) + function from_vec(x_real) + x = scalartype(t) <: Real ? x_real : reinterpret(scalartype(t), x_real) t′ = similar(t) - T = scalartype(t) ctr = 0 for (c, b) in blocks(t′) n = length(b) - if T <: Real - copyto!(b, reshape(x[(ctr + 1):(ctr + n)], size(b)) ./ sqrt(dim(c))) - else - v = x[(ctr + 1):(ctr + 2n)] - copyto!(b, - complex.(x[(ctr + 1):(ctr + n)], x[(ctr + n + 1):(ctr + 2n)]) ./ - sqrt(dim(c))) - end - ctr += T <: Real ? n : 2n + copyto!(b, reshape(view(x, ctr .+ (1:n)), size(b)) ./ sqrt(dim(c))) + ctr += n end return t′ end - - return vec, from_vec + return vec_real, from_vec end FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t)) @@ -63,6 +57,12 @@ end precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2 precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-6 +function randindextuple(N::Int, k::Int=rand(0:N)) + @assert 0 ≤ k ≤ N + _p = randperm(N) + return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) +end + # rrules for functions that destroy inputs # ---------------------------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...) @@ -111,20 +111,25 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), ℂ[Z2Irrep](0 => 3, 1 => 2)', ℂ[Z2Irrep](0 => 2, 1 => 3), ℂ[Z2Irrep](0 => 2, 1 => 2)), + (ℂ[FermionParity](0 => 1, 1 => 1), + ℂ[FermionParity](0 => 1, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 3), + ℂ[FermionParity](0 => 2, 1 => 2)), (ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 2), ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 2)'), - (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), - ℂ[SU2Irrep](0 => 2, 1 => 1), + (ℂ[SU2Irrep](0 => 2, 1 // 2 => 1), + ℂ[SU2Irrep](0 => 1, 1 => 1), ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', - ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), + ℂ[SU2Irrep](1 // 2 => 2), ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)')) -@testset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in - Vlist - @testset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64) +@timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in + Vlist + @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64) A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, space(A)) @@ -146,7 +151,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(⊗, D, E) end - @testset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64) + @timedtestset "Linear Algebra part II with scalartype $T" for T in (Float64, ComplexF64) for i in 1:3 E = TensorMap(randn, T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) test_rrule(LinearAlgebra.tr, E) @@ -157,71 +162,99 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(LinearAlgebra.norm, A, 2) end - @testset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64) + @timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64) atol = precision(T) rtol = precision(T) - @testset "tensortrace!" begin - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[1] ⊗ V[5]) - pC = ((3, 5), (2,)) - pA = ((1,), (4,)) - α = randn(T) - β = randn(T) - - C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false)) - test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol) - - C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false)) - test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol) + @timedtestset "tensortrace!" begin + for _ in 1:5 + k1 = rand(0:3) + k2 = k1 == 3 ? 1 : rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = TensorMap(randn, T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + for conjA in (:N, :C) + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, conjA, false)) + test_rrule(tensortrace!, C, p, A, q, conjA, α, β; atol, rtol) + end + end end - @testset "tensoradd!" begin - p = ((1, 3, 2), (5, 4)) - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) + @timedtestset "tensoradd!" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5]) α = randn(T) β = randn(T) - test_rrule(tensoradd!, C, p, A, :N, α, β; atol, rtol) - - C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false)) - test_rrule(tensoradd!, C, p, A, :C, α, β; atol, rtol) - end - @testset "tensorcontract!" begin - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - B = TensorMap(randn, T, V[3] ⊗ V[1]' ← V[2]) - pC = ((3, 2), (4, 1)) - pA = ((2, 4, 5), (1, 3)) - pB = ((2, 1), (3,)) - α = randn(T) - β = randn(T) + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(length(V)) - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, - B, pB, :N, false)) - test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β; atol, rtol) + C1 = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) + test_rrule(tensoradd!, C1, p, A, :N, α, β; atol, rtol) - A2 = TensorMap(randn, T, V[1]' ⊗ V[2]' ← V[3]' ⊗ V[4]' ⊗ V[5]') - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, - B, pB, :N, false)) - test_rrule(tensorcontract!, C, pC, A2, pA, :C, B, pB, :N, α, β; atol, rtol) + C2 = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false)) + test_rrule(tensoradd!, C2, p, A, :C, α, β; atol, rtol) - B2 = TensorMap(randn, T, V[3]' ⊗ V[1] ← V[2]') - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, - B2, pB, :C, false)) - test_rrule(tensorcontract!, C, pC, A, pA, :N, B2, pB, :C, α, β; atol, rtol) + A = rand(Bool) ? C1 : C2 + end + end - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, - B2, pB, :C, false)) - test_rrule(tensorcontract!, C, pC, A2, pA, :C, B2, pB, :C, α, β; atol, rtol) + @timedtestset "tensorcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init=one(V[1])) + + for conjA in (:N, :C), conjB in (:N, :C) + A = TensorMap(randn, T, + permute(V1 ← (conjA === :C ? V2_conj : V2), ipA)) + B = TensorMap(randn, T, + permute((conjB === :C ? V2_conj : V2) ← V3, ipB)) + C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, + conjA, + B, pB, conjB, + false)) + test_rrule(tensorcontract!, C, pAB, + A, pA, conjA, B, pB, conjB, + α, β; atol, rtol) + end + end end - @testset "tensorscalar" begin + @timedtestset "tensorscalar" begin A = Tensor(randn, T, ProductSpace{typeof(V[1]),0}()) test_rrule(tensorscalar, A) end end - @testset "Factorizations with scalartype $T" for T in (Float64, ComplexF64) + @timedtestset "Factorizations with scalartype $T" for T in (Float64, ComplexF64) A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, space(A)') C = TensorMap(randn, T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])