From ad17c78e36c3c9ece6167d0c9c92a43a3630e7f4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 14 Nov 2024 12:01:00 -0500 Subject: [PATCH] Improvements for `similar` --- .../abstractblocktensor/abstractarray.jl | 23 ++++++++++++++++--- src/tensors/sparseblocktensor.jl | 12 ---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/tensors/abstractblocktensor/abstractarray.jl b/src/tensors/abstractblocktensor/abstractarray.jl index 6690f14..a12f255 100644 --- a/src/tensors/abstractblocktensor/abstractarray.jl +++ b/src/tensors/abstractblocktensor/abstractarray.jl @@ -190,13 +190,20 @@ end # end Base.similar(t::AbstractBlockTensorMap) = similar(t, eltype(t), space(t)) +Base.similar(t::AbstractBlockTensorMap, P::TensorMapSumSpace) = similar(t, eltype(t), P) # make sure tensormap specializations are not used for sumspaces: function Base.similar( t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSumSpace{S} ) where {S,TorA} if TorA <: AbstractTensorMap - return BlockTensorMap{TorA}(undef_blocks, P) + # might need to change the type of the tensor map to account for the new space + TT = similar_tensormaptype(TorA, P) + return if issparse(t) + SparseBlockTensorMap{TT}(undef, P) + else + BlockTensorMap{TT}(undef, P) + end elseif TorA <: Number T = TorA A = TensorKit.similarstoragetype(t, T) @@ -209,10 +216,20 @@ function Base.similar( N₁ = length(codomain(P)) N₂ = length(domain(P)) TT = TensorMap{T,S,N₁,N₂,A} - return BlockTensorMap{TT}(undef, P) + return issparse(t) ? SparseBlockTensorMap{TT}(undef, P) : BlockTensorMap{TT}(undef, P) +end + +function similar_tensormaptype( + T::Type{<:AbstractTensorMap}, P::TensorMapSumSpace{S} +) where {S} + if isconcretetype(T) + return tensormaptype(S, numout(P), numin(P), storagetype(T)) + else + return AbstractTensorMap{scalartype(T),S,numout(P),numin(P)} + end end # implementation in type domain function Base.similar(::Type{T}, P::TensorMapSumSpace) where {T<:AbstractBlockTensorMap} - return T(undef_blocks, P) + return T(undef, P) end diff --git a/src/tensors/sparseblocktensor.jl b/src/tensors/sparseblocktensor.jl index 4d13d88..d2a684e 100644 --- a/src/tensors/sparseblocktensor.jl +++ b/src/tensors/sparseblocktensor.jl @@ -104,18 +104,6 @@ function sprand(::Type{T}, V::VectorSpace, p::Real) where {T<:Number} return sprand(T, V ← one(V), p) end -# specific implementation for SparseBlockTensorMap with Sumspace -> returns `SparseBlockTensorMap` -function Base.similar( - ::SparseBlockTensorMap{TT}, TorA::Type, space::TensorMapSumSpace{S} -) where {TT,S} - if TorA <: AbstractTensorMap - TT′ = TorA - else - TT′ = tensormaptype(S, numout(space), numin(space), TorA) - end - return SparseBlockTensorMap{TT′}(undef_blocks, space) -end - # Properties # ---------- TK.space(t::SparseBlockTensorMap) = t.space