Skip to content

Commit

Permalink
Add copy methods and more mixed TensorOperations methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Dec 4, 2023
1 parent c8c98a6 commit 7761009
Showing 1 changed file with 61 additions and 19 deletions.
80 changes: 61 additions & 19 deletions src/blocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,21 @@ function Base.convert(::Type{<:AbstractTensorMap{S,N₁,N₂}}, t::BlockTensorMa
return tdst
end

# Utility
# -------

function Base.copy(t::BlockTensorMap{S,N₁,N₂,T,N}) where {S,N₁,N₂,T,N}
return BlockTensorMap{S,N₁,N₂,T,N}(copy(t.data), codomain(t), domain(t))
end
function Base.deepcopy(t::BlockTensorMap{S,N₁,N₂,T,N}) where {S,N₁,N₂,T,N}
return BlockTensorMap{S,N₁,N₂,T,N}(deepcopy(t.data), codomain(t), domain(t))
end

# TensorKit Interface
# -------------------

TK.spacetype(::Union{T,Type{T}}) where {S,T<:BlockTensorMap{S}} = S
function TK.sectortype(::Union{T,Type{T}}) where {S,T<:BlockTensorMap{S}}
TK.spacetype(::Union{T,Type{<:T}}) where {S,T<:BlockTensorMap{S}} = S
function TK.sectortype(::Union{T,Type{<:T}}) where {S,T<:BlockTensorMap{S}}
return sectortype(S)
end
TK.storagetype(::Union{B,Type{B}}) where {T,B<:BlockTensorArray{T}} = storagetype(T)
Expand All @@ -338,9 +348,7 @@ TK.similarstoragetype(t::BlockTensorMap, T) = TK.similarstoragetype(typeof(t), T

TK.numout(::Union{T,Type{T}}) where {S,N₁,T<:BlockTensorMap{S,N₁}} = N₁
TK.numin(::Union{T,Type{T}}) where {S,N₁,N₂,T<:BlockTensorMap{S,N₁,N₂}} = N₂
function TK.numind(::Union{T,Type{T}}) where {S,N₁,N₂,T<:BlockTensorMap{S,N₁,N₂}}
return N₁ + N₂
end
TK.numind(::Union{T,Type{T}}) where {S,N₁,N₂,T<:BlockTensorMap{S,N₁,N₂}} = N₁ + N₂

TK.codomain(t::BlockTensorMap) = t.codom
TK.domain(t::BlockTensorMap) = t.dom
Expand Down Expand Up @@ -591,9 +599,6 @@ function VI.inner(x::BlockTensorMap, y::BlockTensorMap)
return s
end

# TODO: this is type-piracy!
# VI.scalartype(::Type{Union{A,B}}) where {A,B} = Union{scalartype(A), scalartype(B)}

# TensorOperations
# ----------------

Expand Down Expand Up @@ -767,6 +772,41 @@ function TO.tensorstructure(t::BlockTensorMap, iA::Int, conjA::Symbol)
return conjA == :N ? space(t, iA) : conj(space(t, iA))
end

function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂},
A::BlockTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::BlockTensorMap{S}, pB::Index2Tuple,
conjB::Symbol) where {S,N₁,N₂}
spaces1 = TO.flag2op(conjA).(space.(Ref(A), pA[1]))
spaces2 = TO.flag2op(conjB).(space.(Ref(B), pB[2]))
spaces = (spaces1..., spaces2...)
cod = ProductSumSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSumSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end
function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂},
A::BlockTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple,
conjB::Symbol) where {S,N₁,N₂}
spaces1 = TO.flag2op(conjA).(space.(Ref(A), pA[1]))
spaces2 = TO.flag2op(conjB).(space.(Ref(B), pB[2]))
spaces = (spaces1..., spaces2...)
cod = ProductSumSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSumSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end
function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂},
A::AbstractTensorMap{S}, pA::Index2Tuple,
conjA::Symbol,
B::BlockTensorMap{S}, pB::Index2Tuple,
conjB::Symbol) where {S,N₁,N₂}
spaces1 = TO.flag2op(conjA).(space.(Ref(A), pA[1]))
spaces2 = TO.flag2op(conjB).(space.(Ref(B), pB[2]))
spaces = (spaces1..., spaces2...)
cod = ProductSumSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSumSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end

function TO.checkcontractible(tA::BlockTensorMap{S}, iA::Int, conjA::Symbol,
tB::BlockTensorMap{S}, iB::Int, conjB::Symbol,
label) where {S}
Expand All @@ -780,6 +820,18 @@ end
# PlanarOperations
# ----------------

function TK.BraidingTensor(V1::SumSpace{S}, V2::SumSpace{S}) where {S}
tdst = BlockTensorMap{S,2,2,TK.BraidingTensor{S,Matrix{ComplexF64}}}(undef, V2 V1, V1 V2)
for I in CartesianIndices(tdst)
if I[1] == I[4] && I[2] == I[3]
V = getsubspace(space(tdst), I)
@assert domain(V)[2] == codomain(V)[1] && domain(V)[1] == codomain(V)[2]
tdst[I] = TK.BraidingTensor(V[2], V[1])
end
end
return tdst
end

function TK.planaradd!(C::BlockTensorMap{S,N₁,N₂},
A::BlockTensorMap{S},
p::Index2Tuple{N₁,N₂},
Expand Down Expand Up @@ -955,17 +1007,7 @@ for (T1, T2) in
convert(BlockTensorMap, B), pB, conjB)
end

@eval function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂}, A::$T1{S},
pA::Index2Tuple, conjA::Symbol,
B::$T2{S},
pB::Index2Tuple, conjB::Symbol) where {S,N₁,N₂}
spaces1 = TO.flag2op(conjA).(space.(Ref(A), pA[1]))
spaces2 = TO.flag2op(conjB).(space.(Ref(B), pB[2]))
spaces = (spaces1..., spaces2...)
cod = ProductSumSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSumSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end

end

if !(T1 === :BlockTensorMap && T2 === :BlockTensorMap)
Expand Down

0 comments on commit 7761009

Please sign in to comment.