Skip to content

Commit

Permalink
Add more BlockTensorMap constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 10, 2024
1 parent fd6dece commit f140cf0
Showing 1 changed file with 45 additions and 9 deletions.
54 changes: 45 additions & 9 deletions src/tensors/blocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,51 @@ function BlockTensorMap{TT}(
return BlockTensorMap{TT}(data, codom dom)
end

function BlockTensorMap(
f::Union{UndefInitializer,UndefBlocksInitializer}, space::TensorMapSumSpace{S,N₁,N₂}
) where {S,N₁,N₂}
TT = tensormaptype(S, N₁, N₂, Float64)
return BlockTensorMap{TT}(f, space)
end
function BlockTensorMap(
f::Union{UndefInitializer,UndefBlocksInitializer},
codom::ProductSumSpace,
dom::ProductSumSpace,
)
return BlockTensorMap(f, codom dom)
end

function BlockTensorMap(
data::Array{TT}, space::TensorMapSumSpace{S,N₁,N₂}
) where {S,N₁,N₂,TT<:AbstractTensorMap{<:Any,S,N₁,N₂}}
return BlockTensorMap{TT}(data, space)
end
function BlockTensorMap(
data::Array{<:AbstractTensorMap}, codom::ProductSumSpace, dom::ProductSumSpace
)
return BlockTensorMap(data, codom dom)
end

# AbstractBlockTensorMap -> BlockTensorMap
function BlockTensorMap(t::AbstractBlockTensorMap)
t isa BlockTensorMap && return t # TODO: should this copy?
tdst = BlockTensorMap{eltype(t)}(undef_blocks, space(t))
for I in eachindex(t)
tdst[I] = t[I]
end
return tdst
end

# AbstractTensorMap -> BlockTensorMap
function BlockTensorMap(t::AbstractTensorMap, space::TensorMapSumSpace)
TT = tensormaptype(spacetype(t), numout(t), numin(t), storagetype(t))
tdst = BlockTensorMap{TT}(undef, space)
for (f₁, f₂) in fusiontrees(tdst)
tdst[f₁, f₂] .= t[f₁, f₂]
end
return tdst
end

# Convenience constructors
# ------------------------
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
Expand Down Expand Up @@ -85,15 +130,6 @@ for randfun in (:rand, :randn, :randexp)
end
end

function BlockTensorMap(t::AbstractBlockTensorMap)
t isa BlockTensorMap && return t # TODO: should this copy?
tdst = BlockTensorMap{eltype(t)}(undef_blocks, codomain(t), domain(t))
for I in eachindex(t)
tdst[I] = t[I]
end
return tdst
end

# Properties
# ----------
Base.eltype(::Type{<:BlockTensorMap{TT}}) where {TT} = TT
Expand Down

0 comments on commit f140cf0

Please sign in to comment.