From 69732056c3b519329e17bee0199128323736e013 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 14 Nov 2024 12:01:00 -0500 Subject: [PATCH] Improvements for `similar` --- .../abstractblocktensor/abstractarray.jl | 19 ++++++++++++++++++- src/tensors/sparseblocktensor.jl | 12 ------------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/tensors/abstractblocktensor/abstractarray.jl b/src/tensors/abstractblocktensor/abstractarray.jl index 6690f14..b34a3e1 100644 --- a/src/tensors/abstractblocktensor/abstractarray.jl +++ b/src/tensors/abstractblocktensor/abstractarray.jl @@ -190,13 +190,20 @@ end # end Base.similar(t::AbstractBlockTensorMap) = similar(t, eltype(t), space(t)) +Base.similar(t::AbstractBlockTensorMap, P::TensorMapSumSpace) = similar(t, eltype(t), P) # make sure tensormap specializations are not used for sumspaces: function Base.similar( t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSumSpace{S} ) where {S,TorA} if TorA <: AbstractTensorMap - return BlockTensorMap{TorA}(undef_blocks, P) + # might need to change the type of the tensor map to account for the new space + TT = similar_tensormaptype(TorA, P) + return if issparse(t) + SparseBlockTensorMap{TT}(undef, P) + else + BlockTensorMap{TT}(undef, P) + end elseif TorA <: Number T = TorA A = TensorKit.similarstoragetype(t, T) @@ -212,6 +219,16 @@ function Base.similar( return BlockTensorMap{TT}(undef, P) end +function similar_tensormaptype( + T::Type{<:AbstractTensorMap}, P::TensorMapSumSpace{S} +) where {S} + if isconcretetype(T) + return tensormaptype(S, numout(P), numin(P), storagetype(T)) + else + return AbstractTensorMap{scalartype(T),S,numout(P),numin(P)} + end +end + # implementation in type domain function Base.similar(::Type{T}, P::TensorMapSumSpace) where {T<:AbstractBlockTensorMap} return T(undef_blocks, P) diff --git a/src/tensors/sparseblocktensor.jl b/src/tensors/sparseblocktensor.jl index 4d13d88..d2a684e 100644 --- a/src/tensors/sparseblocktensor.jl +++ b/src/tensors/sparseblocktensor.jl @@ -104,18 +104,6 @@ function sprand(::Type{T}, V::VectorSpace, p::Real) where {T<:Number} return sprand(T, V ← one(V), p) end -# specific implementation for SparseBlockTensorMap with Sumspace -> returns `SparseBlockTensorMap` -function Base.similar( - ::SparseBlockTensorMap{TT}, TorA::Type, space::TensorMapSumSpace{S} -) where {TT,S} - if TorA <: AbstractTensorMap - TT′ = TorA - else - TT′ = tensormaptype(S, numout(space), numin(space), TorA) - end - return SparseBlockTensorMap{TT′}(undef_blocks, space) -end - # Properties # ---------- TK.space(t::SparseBlockTensorMap) = t.space