Skip to content

Commit

Permalink
Add repartition(!)(::AbstractTensorMap) (#116)
Browse files Browse the repository at this point in the history
* Add `repartition(!)` for TensorMaps

* Add tests for `repartition(!)`

* Improve type-inference

* Circumvent `@constprop` in Julia 1.6

* Update documentation
  • Loading branch information
lkdvos authored Apr 11, 2024
1 parent 27bf150 commit 370dd92
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,14 @@ scaling, as well as the selection of a custom backend.
permute(t::AbstractTensorMap{S}, (p₁, p₂)::Index2Tuple{N₁,N₂}; copy::Bool=false) where {S,N₁,N₂}
braid(t::AbstractTensorMap{S}, (p₁, p₂)::Index2Tuple, levels::IndexTuple; copy::Bool=false) where {S}
transpose(::AbstractTensorMap, ::Index2Tuple)
repartition(::AbstractTensorMap, ::Int, ::Int)
twist(::AbstractTensorMap, ::Int)
```
```@docs
permute!(tdst::AbstractTensorMap{S,N₁,N₂}, tsrc::AbstractTensorMap{S}, p::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
braid!
transpose!
repartition!(::AbstractTensorMap{S}, ::AbstractTensorMap{S}) where {S}
twist!
```
```@docs
Expand Down
3 changes: 2 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ export leftorth, rightorth, leftnull, rightnull,
leftorth!, rightorth!, leftnull!, rightnull!,
tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!,
isposdef, isposdef!, ishermitian, sylvester
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
repartition!
export catdomain, catcodomain

export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQpos,
Expand Down
9 changes: 9 additions & 0 deletions src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,12 @@ function _kron(A, B)
end
return C
end

# Compat implementation:
@static if VERSION < v"1.7"
macro constprop(setting, ex)
return esc(ex)
end
else
using Base: @constprop
end
40 changes: 40 additions & 0 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,46 @@ function LinearAlgebra.transpose(t::AdjointTensorMap{S},
return adjoint(transpose(adjoint(t), (p₁′, p₂′); copy=copy))
end

"""
repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S} -> tdst
Write into `tdst` the result of repartitioning the indices of `tsrc`. This is just a special
case of a transposition that only changes the number of in- and outgoing indices.
See [`repartition`](@ref) for creating a new tensor.
"""
function repartition!(tdst::AbstractTensorMap{S}, tsrc::AbstractTensorMap{S}) where {S}
numind(tsrc) == numind(tdst) ||
throw(ArgumentError("tsrc and tdst should have an equal amount of indices"))
all_inds = (codomainind(tsrc)..., reverse(domainind(tsrc))...)
p₁ = ntuple(i -> all_inds[i], numout(tdst))
p₂ = reverse(ntuple(i -> all_inds[i + numout(tdst)], numin(tdst)))
return transpose!(tdst, tsrc, (p₁, p₂))
end

"""
repartition(tsrc::AbstractTensorMap{S}, N₁::Int, N₂::Int; copy::Bool=false) where {S}
-> tdst::AbstractTensorMap{S,N₁,N₂}
Return tensor `tdst` obtained by repartitioning the indices of `t`.
The codomain and domain of `tdst` correspond to the first `N₁` and last `N₂` spaces of `t`,
respectively.
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
To repartition into an existing destination, see [repartition!](@ref).
"""
@constprop :aggressive function repartition(t::AbstractTensorMap, N₁::Int,
N₂::Int=numind(t) - N₁;
copy::Bool=false)
N₁ + N₂ == numind(t) ||
throw(ArgumentError("Invalid repartition: $(numind(t)) to ($N₁, $N₂)"))
all_inds = (codomainind(t)..., reverse(domainind(t))...)
p₁ = ntuple(i -> all_inds[i], N₁)
p₂ = reverse(ntuple(i -> all_inds[i + N₁], N₂))
return transpose(t, (p₁, p₂); copy)
end

# Twist
"""
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false)
Expand Down
16 changes: 15 additions & 1 deletion test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,36 @@ for V in spacelist
t2′ = permute(t′, (p1, p2))
@test dot(t2′, t2) dot(t′, t) dot(transpose(t2′), transpose(t2))
end

t3 = VERSION < v"1.7" ? repartition(t, k) :
@constinferred repartition(t, $k)
@test norm(t3) norm(t)
t3′ = @constinferred repartition!(similar(t3), t′)
@test norm(t3′) norm(t′)
@test dot(t′, t) dot(t3′, t3)
end
end
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
@timedtestset "Permutations: test via conversion" begin
W = V1 V2 V3 V4 V5
t = Tensor(rand, ComplexF64, W)
a = convert(Array, t)
for k in 0:5
for p in permutations(1:5)
p1 = ntuple(n -> p[n], k)
p2 = ntuple(n -> p[k + n], 5 - k)
t2 = permute(t, (p1, p2))
a2 = convert(Array, t2)
@test a2 permutedims(convert(Array, t), (p1..., p2...))
@test a2 permutedims(a, (p1..., p2...))
@test convert(Array, transpose(t2))
permutedims(a2, (5, 4, 3, 2, 1))
end

t3 = repartition(t, k)
a3 = convert(Array, t3)
@test a3 permutedims(a,
(ntuple(identity, k)...,
reverse(ntuple(i -> i + k, 5 - k))...))
end
end
end
Expand Down

0 comments on commit 370dd92

Please sign in to comment.