Skip to content

Commit

Permalink
Improvements for similar
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 14, 2024
1 parent 6abebda commit 6973205
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
19 changes: 18 additions & 1 deletion src/tensors/abstractblocktensor/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,20 @@ end
# end

Base.similar(t::AbstractBlockTensorMap) = similar(t, eltype(t), space(t))
Base.similar(t::AbstractBlockTensorMap, P::TensorMapSumSpace) = similar(t, eltype(t), P)

# make sure tensormap specializations are not used for sumspaces:
function Base.similar(
t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSumSpace{S}
) where {S,TorA}
if TorA <: AbstractTensorMap
return BlockTensorMap{TorA}(undef_blocks, P)
# might need to change the type of the tensor map to account for the new space
TT = similar_tensormaptype(TorA, P)
return if issparse(t)
SparseBlockTensorMap{TT}(undef, P)
else
BlockTensorMap{TT}(undef, P)
end
elseif TorA <: Number
T = TorA
A = TensorKit.similarstoragetype(t, T)
Expand All @@ -212,6 +219,16 @@ function Base.similar(
return BlockTensorMap{TT}(undef, P)
end

function similar_tensormaptype(
T::Type{<:AbstractTensorMap}, P::TensorMapSumSpace{S}
) where {S}
if isconcretetype(T)
return tensormaptype(S, numout(P), numin(P), storagetype(T))
else
return AbstractTensorMap{scalartype(T),S,numout(P),numin(P)}
end
end

# implementation in type domain
function Base.similar(::Type{T}, P::TensorMapSumSpace) where {T<:AbstractBlockTensorMap}
return T(undef_blocks, P)
Expand Down
12 changes: 0 additions & 12 deletions src/tensors/sparseblocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,6 @@ function sprand(::Type{T}, V::VectorSpace, p::Real) where {T<:Number}
return sprand(T, V one(V), p)
end

# specific implementation for SparseBlockTensorMap with Sumspace -> returns `SparseBlockTensorMap`
function Base.similar(
::SparseBlockTensorMap{TT}, TorA::Type, space::TensorMapSumSpace{S}
) where {TT,S}
if TorA <: AbstractTensorMap
TT′ = TorA
else
TT′ = tensormaptype(S, numout(space), numin(space), TorA)
end
return SparseBlockTensorMap{TT′}(undef_blocks, space)
end

# Properties
# ----------
TK.space(t::SparseBlockTensorMap) = t.space
Expand Down

0 comments on commit 6973205

Please sign in to comment.