Skip to content

Commit

Permalink
Various improvements and fixes (#149)
Browse files Browse the repository at this point in the history
* Improvements for BraidingTensor

* Fix catdomain and catcodomain

* scalartype from AbstractTensorMap type instead of storagetype

* improve type stability allocator

* add `similar(::Type{AbstractTensorMap}, ...)`
  • Loading branch information
lkdvos authored Aug 28, 2024
1 parent 0acfb13 commit 9e78a03
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export CompositeSpace, ProductSpace # composite spaces
export FusionTree
export IndexSpace, TensorSpace, TensorMapSpace
export AbstractTensorMap, AbstractTensor, TensorMap, Tensor, TrivialTensorMap # tensors and tensor properties
export BraidingTensor
export TruncationScheme
export SpaceMismatch, SectorMismatch, IndexError # error types

Expand Down
9 changes: 9 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ function Base.similar(::AbstractTensorMap, ::Type{TorA},
return TT(undef, codomain(P), domain(P))
end

# implementation in type-domain
function Base.similar(::Type{TT}, P::TensorMapSpace) where {TT<:AbstractTensorMap}
return TensorMap{scalartype(TT)}(undef, P)
end
function Base.similar(::Type{TT}, cod::TensorSpace{S},
dom::TensorSpace{S}) where {TT<:AbstractTensorMap,S}
return TensorMap{scalartype(TT)}(undef, cod, dom)
end

# Equality and approximality
#----------------------------
function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap)
Expand Down
20 changes: 17 additions & 3 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ struct BraidingTensor{T,S} <: AbstractTensorMap{T,S,2,2}
# partial construction: only construct rowr and colr when needed
end
end
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool=false) where {T,S<:IndexSpace}
return BraidingTensor{T,S}(V1, V2, adjoint)
end
function BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
if BraidingStyle(sectortype(S)) isa SymmetricBraiding
return BraidingTensor{Float64,S}(V1, V2, adjoint)
Expand All @@ -38,7 +41,12 @@ end
function BraidingTensor(V::HomSpace, adjoint::Bool=false)
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[1], V[2], adjoint)
return BraidingTensor(V[2], V[1], adjoint)
end
function BraidingTensor{T}(V::HomSpace, adjoint::Bool=false) where {T}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor{T}(V[2], V[1], adjoint)
end
function Base.adjoint(b::BraidingTensor{T,S}) where {T,S}
return BraidingTensor{T,S}(b.V1, b.V2, !b.adjoint)
Expand All @@ -54,6 +62,10 @@ blocksectors(b::BraidingTensor) = blocksectors(b.V1 ⊗ b.V2)
hasblock(b::BraidingTensor, s::Sector) = s blocksectors(b)

function fusiontrees(b::BraidingTensor)
if sectortype(b) === Trivial
return ((nothing, nothing),)
end

codom = codomain(b)
dom = domain(b)
I = sectortype(b)
Expand All @@ -71,7 +83,6 @@ function fusiontrees(b::BraidingTensor)
offset1 = last(r)
end
end
dim1 = offset1
offset2 = 0
for s2 in sectors(dom)
for f₂ in fusiontrees(s2, c, map(isdual, dom.spaces))
Expand All @@ -80,7 +91,6 @@ function fusiontrees(b::BraidingTensor)
offset2 = last(r)
end
end
dim2 = offset2
push!(rowr, c => rowrc)
push!(colr, c => colrc)
end
Expand Down Expand Up @@ -124,6 +134,10 @@ end
return sreshape(StridedView(data), d)
end
end
@inline function Base.getindex(b::BraidingTensor, ::Nothing, ::Nothing)
sectortype(b) === Trivial || throw(SectorMismatch())
return getindex(b)
end

# efficient copy constructor
Base.copy(b::BraidingTensor) = b
Expand Down
10 changes: 6 additions & 4 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
end

# concatenate tensors
function catdomain(t1::T, t2::T) where {S,N₁,T<:AbstractTensorMap{<:Any,S,N₁,1}}
function catdomain(t1::TT, t2::TT) where {S,N₁,TT<:AbstractTensorMap{<:Any,S,N₁,1}}
codomain(t1) == codomain(t2) ||
throw(SpaceMismatch("codomains of tensors to concatenate must match:\n" *
"$(codomain(t1))$(codomain(t2))"))
Expand All @@ -411,14 +411,15 @@ function catdomain(t1::T, t2::T) where {S,N₁,T<:AbstractTensorMap{<:Any,S,N₁
throw(SpaceMismatch("cannot horizontally concatenate tensors whose domain has non-matching duality"))

V = V1 V2
t = TensorMap(undef, promote_type(scalartype(t1), scalartype(t2)), codomain(t1), V)
T = promote_type(scalartype(t1), scalartype(t2))
t = TensorMap{T}(undef, codomain(t1), V)
for c in sectors(V)
block(t, c)[:, 1:dim(V1, c)] .= block(t1, c)
block(t, c)[:, dim(V1, c) .+ (1:dim(V2, c))] .= block(t2, c)
end
return t
end
function catcodomain(t1::T, t2::T) where {S,N₂,T<:AbstractTensorMap{<:Any,S,1,N₂}}
function catcodomain(t1::TT, t2::TT) where {S,N₂,TT<:AbstractTensorMap{<:Any,S,1,N₂}}
domain(t1) == domain(t2) ||
throw(SpaceMismatch("domains of tensors to concatenate must match:\n" *
"$(domain(t1))$(domain(t2))"))
Expand All @@ -428,7 +429,8 @@ function catcodomain(t1::T, t2::T) where {S,N₂,T<:AbstractTensorMap{<:Any,S,1,
throw(SpaceMismatch("cannot vertically concatenate tensors whose codomain has non-matching duality"))

V = V1 V2
t = TensorMap(undef, promote_type(scalartype(t1), scalartype(t2)), V, domain(t1))
T = promote_type(scalartype(t1), scalartype(t2))
t = TensorMap{T}(undef, V, domain(t1))
for c in sectors(V)
block(t, c)[1:dim(V1, c), :] .= block(t1, c)
block(t, c)[dim(V1, c) .+ (1:dim(V2, c)), :] .= block(t2, c)
Expand Down
5 changes: 3 additions & 2 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function TO.tensoralloc(::Type{TT}, structure::TensorMapSpace{S,N₁,N₂}, iste
colr, coldims = _buildblockstructure(domain(structure), blocksectoriterator)
A = storagetype(TT)
blockallocator(c) = TO.tensoralloc(A, (rowdims[c], coldims[c]), istemp, allocator)
data = SectorDict(c => blockallocator(c) for c in blocksectoriterator)
data = SectorDict{sectortype(TT),A}(c => blockallocator(c) for c in blocksectoriterator)
return TT(data, codomain(structure), domain(structure), rowr, colr)
end

Expand Down Expand Up @@ -127,7 +127,8 @@ function TO.tensorcontract_type(TC,
B::AbstractTensorMap, ::Index2Tuple, ::Bool,
::Index2Tuple{N₁,N₂}) where {N₁,N₂}
M = similarstoragetype(A, TC)
M == similarstoragetype(B, TC) || throw(ArgumentError("incompatible storage types"))
M == similarstoragetype(B, TC) ||
throw(ArgumentError("incompatible storage types:\n$(M)$(similarstoragetype(B, TC))"))
spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types"))
return tensormaptype(spacetype(A), N₁, N₂, M)
end
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# scalartype
#------------
VectorInterface.scalartype(T::Type{<:AbstractTensorMap}) = scalartype(storagetype(T))
VectorInterface.scalartype(::Type{TT}) where {T,TT<:AbstractTensorMap{T}} = scalartype(T)

# zerovector & zerovector!!
#---------------------------
Expand Down
19 changes: 19 additions & 0 deletions test/planar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using TensorKit, TensorOperations, Test
using TensorKit: BraidingTensor
using TensorKit: planaradd!, planartrace!, planarcontract!
using TensorKit: PlanarTrivial, ℙ

Expand Down Expand Up @@ -29,6 +30,24 @@ function force_planar(tsrc::TensorMap{<:Any,<:GradedSpace})
return tdst
end

@testset "Braiding tensor" begin
V1 =^2 ^3 ^3 ^2
t1 = @constinferred BraidingTensor(V1)
@test space(t1) == V1
@test codomain(t1) == codomain(V1)
@test domain(t1) == domain(V1)
@test scalartype(t1) == Float64
@test storagetype(t1) == Matrix{Float64}
t2 = @constinferred BraidingTensor{ComplexF64}(V1)
@test scalartype(t2) == ComplexF64
@test storagetype(t2) == Matrix{ComplexF64}

V2 =^2 ^3 ^2 ^3
@test_throws SpaceMismatch BraidingTensor(V2)

@test adjoint(t1) isa BraidingTensor
end

@testset "planar methods" verbose = true begin
@testset "planaradd" begin
A = randn(ℂ^2 ^3 ^6 ^5 ^4)
Expand Down

0 comments on commit 9e78a03

Please sign in to comment.