Skip to content

Commit

Permalink
insertleftunit, insertrightunit and removeunit (#187)
Browse files Browse the repository at this point in the history
* extend `insertunit` to `HomSpace`

* extend `insertunit` to `AbstractTensorMap`

* Add `insertunit` tests for `HomSpace`

* Add `insertunit` tests for `TensorMap`

* Add `removeunit` functionality

* Add tests `removeunit`

* fixup! Add `insertunit` tests for `TensorMap`

* improve type stability

* Rewrite in terms of `insertleftunit` and `insertrightunit`

* also update docs

* fix missing kwargs

* update checks in hope of fixing type ambiguity

* type stability changes
  • Loading branch information
lkdvos authored Dec 18, 2024
1 parent c041bfe commit ef42572
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 12 deletions.
3 changes: 2 additions & 1 deletion docs/src/lib/spaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ fuse
ismonomorphic
isepimorphic
isisomorphic
insertunit
```

There are also specific methods for `HomSpace` instances, that are used in determining
Expand All @@ -116,4 +115,6 @@ the resuling `HomSpace` after applying certain tensor operations.
TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S}
insertleftunit(::HomSpace, ::Int)
insertrightunit(::HomSpace, ::Int)
```
4 changes: 3 additions & 1 deletion docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ braid(::AbstractTensorMap, ::Index2Tuple, ::IndexTuple; ::Bool)
transpose(::AbstractTensorMap, ::Index2Tuple; ::Bool)
repartition(::AbstractTensorMap, ::Int, ::Int; ::Bool)
twist(::AbstractTensorMap, ::Int; ::Bool)
insertleftunit(::AbstractTensorMap, ::Int)
insertrightunit(::AbstractTensorMap, ::Int)
```

```@docs
Expand Down Expand Up @@ -224,4 +226,4 @@ and only accept the `TensorMap` object as well as the method-specific algorithm
arguments.


TODO: document svd truncation types
TODO: document svd truncation types
3 changes: 2 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ export TruncationScheme
export SpaceMismatch, SectorMismatch, IndexError # error types

# general vector space methods
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, insertunit, oplus
export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus,
insertleftunit, insertrightunit, removeunit

# partial order for vector spaces
export infimum, supremum, isisomorphic, ismonomorphic, isepimorphic
Expand Down
3 changes: 3 additions & 0 deletions src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ function _kron(A, B)
return C
end

@noinline _boundserror(P, i) = throw(BoundsError(P, i))
@noinline _nontrivialspaceerror(P, i) = throw(ArgumentError(lazy"Attempting to remove a non-trivial space $(P[i])"))

# Compat implementation:
@static if VERSION < v"1.7"
macro constprop(setting, ex)
Expand Down
2 changes: 2 additions & 0 deletions src/auxiliary/deprecate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ end

Base.@deprecate EuclideanProduct() EuclideanInnerProduct()

Base.@deprecate insertunit(P::ProductSpace, args...; kwargs...) insertleftunit(args...; kwargs...)

#! format: on
50 changes: 50 additions & 0 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,56 @@ function compose(W::HomSpace{S}, V::HomSpace{S}) where {S}
return HomSpace(codomain(W), domain(V))
end

"""
insertleftunit(W::HomSpace, i::Int=numind(W) + 1; conj=false, dual=false)
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
More specifically, adds a left monoidal unit or its dual.
See also [`insertrightunit`](@ref), [`removeunit`](@ref).
"""
@constprop :aggressive function insertleftunit(W::HomSpace, i::Int=numind(W) + 1;
conj::Bool=false, dual::Bool=false)
if i numout(W)
return insertleftunit(codomain(W), i; conj, dual) domain(W)
else
return codomain(W) insertleftunit(domain(W), i - numout(W); conj, dual)
end
end

"""
insertrightunit(W::HomSpace, i::Int=numind(W); conj=false, dual=false)
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
More specifically, adds a right monoidal unit or its dual.
See also [`insertleftunit`](@ref), [`removeunit`](@ref).
"""
@constprop :aggressive function insertrightunit(W::HomSpace, i::Int=numind(W);
conj::Bool=false, dual::Bool=false)
if i numout(W)
return insertrightunit(codomain(W), i; conj, dual) domain(W)
else
return codomain(W) insertrightunit(domain(W), i - numout(W); conj, dual)
end
end

"""
removeunit(P::HomSpace, i::Int)
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
For this to work, that factor has to be isomorphic to the field of scalars.
This operation undoes the work of [`insertleftunit`](@ref) or [`insertrightunit`](@ref).
"""
@constprop :aggressive function removeunit(P::HomSpace, i::Int)
if i numout(P)
return removeunit(codomain(P), i) domain(P)
else
return codomain(P) removeunit(domain(P), i - numout(P))
end
end

# Block and fusion tree ranges: structure information for building tensors
#--------------------------------------------------------------------------
struct FusionBlockStructure{I,N,F₁,F₂}
Expand Down
48 changes: 41 additions & 7 deletions src/spaces/productspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,15 @@ fuse(P::ProductSpace{S,0}) where {S<:ElementarySpace} = oneunit(S)
fuse(P::ProductSpace{S}) where {S<:ElementarySpace} = fuse(P.spaces...)

"""
insertunit(P::ProductSpace, i::Int = length(P)+1; dual = false, conj = false)
insertleftunit(P::ProductSpace, i::Int=length(P) + 1; conj=false, dual=false)
For `P::ProductSpace{S,N}`, this adds an extra tensor product factor at position
`1 <= i <= N+1` (last position by default) which is just the `S`-equivalent of the
underlying field of scalars, i.e. `oneunit(S)`. With the keyword arguments, one can choose
to insert the conjugated or dual space instead, which are all isomorphic to the field of
scalars.
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
More specifically, adds a left monoidal unit or its dual.
See also [`insertrightunit`](@ref), [`removeunit`](@ref).
"""
function insertunit(P::ProductSpace, i::Int=length(P) + 1; dual=false, conj=false)
function insertleftunit(P::ProductSpace, i::Int=length(P) + 1;
conj::Bool=false, dual::Bool=false)
u = oneunit(spacetype(P))
if dual
u = TensorKit.dual(u)
Expand All @@ -265,6 +265,40 @@ function insertunit(P::ProductSpace, i::Int=length(P) + 1; dual=false, conj=fals
return ProductSpace(TupleTools.insertafter(P.spaces, i - 1, (u,)))
end

"""
insertrightunit(P::ProductSpace, i::Int=lenght(P); conj=false, dual=false)
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
More specifically, adds a right monoidal unit or its dual.
See also [`insertleftunit`](@ref), [`removeunit`](@ref).
"""
function insertrightunit(P::ProductSpace, i::Int=length(P);
conj::Bool=false, dual::Bool=false)
u = oneunit(spacetype(P))
if dual
u = TensorKit.dual(u)
end
if conj
u = TensorKit.conj(u)
end
return ProductSpace(TupleTools.insertafter(P.spaces, i, (u,)))
end

"""
removeunit(P::ProductSpace, i::Int)
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
For this to work, that factor has to be isomorphic to the field of scalars.
This operation undoes the work of [`insertunit`](@ref).
"""
function removeunit(P::ProductSpace, i::Int)
1 i length(P) || _boundserror(P, i)
isisomorphic(P[i], oneunit(P[i])) || _nontrivialspaceerror(P, i)
return ProductSpace{spacetype(P)}(TupleTools.deleteat(P.spaces, i))
end

# Functionality for extracting and iterating over spaces
#--------------------------------------------------------
Base.length(P::ProductSpace) = length(P.spaces)
Expand Down
78 changes: 78 additions & 0 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,84 @@ See [`twist!`](@ref) for storing the result in place.
"""
twist(t::AbstractTensorMap, i; inv::Bool=false) = twist!(copy(t), i; inv)

"""
insertleftunit(tsrc::AbstractTensorMap, i::Int=numind(t) + 1;
conj=false, dual=false, copy=false) -> tdst
Insert a trivial vector space, isomorphic to the underlying field, at position `i`.
More specifically, adds a left monoidal unit or its dual.
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
See also [`insertrightunit`](@ref) and [`removeunit`](@ref).
"""
@constprop :aggressive function insertleftunit(t::AbstractTensorMap,
i::Int=numind(t) + 1; copy::Bool=true,
conj::Bool=false, dual::Bool=false)
W = insertleftunit(space(t), i; conj, dual)
tdst = similar(t, W)
for (c, b) in blocks(t)
copy!(block(tdst, c), b)
end
return tdst
end
@constprop :aggressive function insertleftunit(t::TensorMap, i::Int=numind(t) + 1;
copy::Bool=false,
conj::Bool=false, dual::Bool=false)
W = insertleftunit(space(t), i; conj, dual)
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
end

"""
insertrightunit(tsrc::AbstractTensorMap, i::Int=numind(t);
conj=false, dual=false, copy=false) -> tdst
Insert a trivial vector space, isomorphic to the underlying field, after position `i`.
More specifically, adds a right monoidal unit or its dual.
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
See also [`insertleftunit`](@ref) and [`removeunit`](@ref).
"""
@constprop :aggressive function insertrightunit(t::AbstractTensorMap, i::Int=numind(t);
copy::Bool=true, kwargs...)
W = insertrightunit(space(t), i; kwargs...)
tdst = similar(t, W)
for (c, b) in blocks(t)
copy!(block(tdst, c), b)
end
return tdst
end
@constprop :aggressive function insertrightunit(t::TensorMap, i::Int=numind(t);
copy::Bool=false, kwargs...)
W = insertrightunit(space(t), i; kwargs...)
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
end

"""
removeunit(tsrc::AbstractTensorMap, i::Int; copy=false) -> tdst
This removes a trivial tensor product factor at position `1 ≤ i ≤ N`.
For this to work, that factor has to be isomorphic to the field of scalars.
If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made.
This operation undoes the work of [`insertunit`](@ref).
"""
@constprop :aggressive function removeunit(t::TensorMap, i::Int; copy::Bool=false)
W = removeunit(space(t), i)
return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W)
end
@constprop :aggressive function removeunit(t::AbstractTensorMap, i::Int;
copy::Bool=true)
W = removeunit(space(t), i)
tdst = similar(t, W)
for (c, b) in blocks(t)
copy!(block(tdst, c), b)
end
return tdst
end

# Fusing and splitting
# TODO: add functionality for easy fusing and splitting of tensor indices

Expand Down
26 changes: 24 additions & 2 deletions test/spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ println("------------------------------------")
@test @constinferred((V1 V2, V3 V4)) == P
@test @constinferred((V1, V2, V3 V4)) == P
@test @constinferred((V1, V2 V3, V4)) == P
@test @constinferred(insertunit(P, 3)) == V1 * V2 * oneunit(V1) * V3 * V4
@test V1 * V2 * oneunit(V1) * V3 * V4 ==
@constinferred(insertleftunit(P, 3)) ==
@constinferred(insertrightunit(P, 2))
@test @constinferred(removeunit(V1 * V2 * oneunit(V1)' * V3 * V4, 3)) == P
@test fuse(V1, V2', V3) V1 V2' V3
@test fuse(V1, V2', V3) V1 V2' V3
@test fuse(V1, V2', V3) V1 V2' V3
Expand Down Expand Up @@ -338,7 +341,10 @@ println("------------------------------------")
@test @constinferred(*(V1, V2, V3)) == P
@test @constinferred((V1, V2, V3)) == P
@test @constinferred(adjoint(P)) == dual(P) == V3' V2' V1'
@test @constinferred(insertunit(P, 3; conj=true)) == V1 * V2 * oneunit(V1)' * V3
@test V1 * V2 * oneunit(V1)' * V3 ==
@constinferred(insertleftunit(P, 3; conj=true)) ==
@constinferred(insertrightunit(P, 2; conj=true))
@test P == @constinferred(removeunit(insertleftunit(P, 3), 3))
@test fuse(V1, V2', V3) V1 V2' V3
@test fuse(V1, V2', V3) V1 V2' V3 fuse(V1 V2' V3)
@test fuse(V1, V2') V3 V1 V2' V3
Expand Down Expand Up @@ -419,5 +425,21 @@ println("------------------------------------")
@test W == @constinferred permute(W, ((1, 2), (3, 4, 5)))
@test permute(W, ((2, 4, 5), (3, 1))) == (V2 V4' V5' V3 V1')
@test (V1 V2 V1 V2) == @constinferred TensorKit.compose(W, W')
@test (V1 V2 V3 V4 V5 oneunit(V5)) ==
@constinferred(insertleftunit(W)) ==
@constinferred(insertrightunit(W))
@test @constinferred(removeunit(insertleftunit(W), $(numind(W) + 1))) == W
@test (V1 V2 V3 V4 V5 oneunit(V5)') ==
@constinferred(insertleftunit(W; conj=true)) ==
@constinferred(insertrightunit(W; conj=true))
@test (oneunit(V1) V1 V2 V3 V4 V5) ==
@constinferred(insertleftunit(W, 1)) ==
@constinferred(insertrightunit(W, 0))
@test (V1 V2 oneunit(V1) V3 V4 V5) ==
@constinferred(insertrightunit(W, 2))
@test (V1 V2 oneunit(V1) V3 V4 V5) == @constinferred(insertleftunit(W, 3))
@test @constinferred(removeunit(insertleftunit(W, 3), 3)) == W
@test @constinferred(insertrightunit(one(V1) V1, 0)) == (oneunit(V1) V1)
@test_throws BoundsError insertleftunit(one(V1) V1, 0)
end
end
32 changes: 32 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,38 @@ for V in spacelist
@test w * w' == (w * w')^2
end
end
@timedtestset "Trivial spaces" begin
W = V1 V2 V3 V4 V5
for T in (Float32, ComplexF64)
t = @constinferred rand(T, W)
t2 = @constinferred insertleftunit(t)
@test t2 == @constinferred insertrightunit(t)
@test numind(t2) == numind(t) + 1
@test space(t2) == insertleftunit(space(t))
@test scalartype(t2) === T
@test t.data === t2.data
@test @constinferred(removeunit(t2, $(numind(t2)))) == t
t3 = @constinferred insertleftunit(t; copy=true)
@test t3 == @constinferred insertrightunit(t; copy=true)
@test t.data !== t3.data
for (c, b) in blocks(t)
@test b == block(t3, c)
end
@test @constinferred(removeunit(t3, $(numind(t3)))) == t
t4 = @constinferred insertrightunit(t, 3; dual=true)
@test numin(t4) == numin(t) && numout(t4) == numout(t) + 1
for (c, b) in blocks(t)
@test b == block(t4, c)
end
@test @constinferred(removeunit(t4, 4)) == t
t5 = @constinferred insertleftunit(t, 4; dual=true)
@test numin(t5) == numin(t) + 1 && numout(t5) == numout(t)
for (c, b) in blocks(t)
@test b == block(t5, c)
end
@test @constinferred(removeunit(t5, 4)) == t
end
end
if hasfusiontensor(I)
@timedtestset "Basic linear algebra: test via conversion" begin
W = V1 V2 V3 V4 V5
Expand Down

0 comments on commit ef42572

Please sign in to comment.