Skip to content

Commit

Permalink
Better sparse constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Nov 21, 2024
1 parent 033a523 commit 289f78a
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/tensors/sparseblocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,15 @@ end
Construct a sparse blocktensor with entries compatible with type `T` and space `W`.
By default, the tensor will be empty, but nonzero entries can be specified by passing a tuple of indices `nonzero_inds`.
"""
spzeros(W::TensorMapSumSpace) = spzeros(Float64, W)
spzeros(W::TensorMapSumSpace, args...) = spzeros(Float64, W, args...)
function spzeros(T::Type, cod::TensorSumSpace, dom::TensorSumSpace=one(cod), args...)
return spzeros(T, cod dom, args...)
end
function spzeros(
T::Type, cod::TensorSumSpace, nonzero_inds::AbstractVector{<:CartesianIndex}
)
return spzeros(T, cod, one(cod), nonzero_inds)
end
function spzeros(
::Type{T}, W::TensorMapSumSpace, nonzero_inds=CartesianIndex{length(W)}[]
) where {T}
Expand All @@ -116,9 +124,17 @@ end
Construct a sparse blocktensor with entries compatible with type `T` and space `W`.
Each entry is nonzero with probability `p`.
"""
sprand(V::VectorSpace, p::Real) = sprand(Random.default_rng(), Float64, V, p)
sprand(rng::Random.AbstractRNG, V::VectorSpace, p::Real) = sprand(rng, Float64, V, p)
sprand(T::Type, V::VectorSpace, p::Real) = sprand(Random.default_rng(), T, V, p)
sprand(V::TensorMapSumSpace, p::Real) = sprand(Random.default_rng(), Float64, V, p)
sprand(rng::Random.AbstractRNG, V::TensorMapSumSpace, p::Real) = sprand(rng, Float64, V, p)
sprand(T::Type, V::TensorMapSumSpace, p::Real) = sprand(Random.default_rng(), T, V, p)
sprand(V::TensorSumSpace, p::Real) = sprand(Random.default_rng(), Float64, V one(V), p)
function sprand(rng::Random.AbstractRNG, V::TensorSumSpace, p::Real)
return sprand(rng, Float64, V one(V), p)
end
sprand(T::Type, V::TensorSumSpace, p::Real) = sprand(Random.default_rng(), T, V one(V), p)
function sprand(rng::Random.AbstractRNG, T::Type, V::TensorSumSpace, p::Real)
return sprand(rng, T, V one(V), p)
end
function sprand(
rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace, p::Real
) where {T<:Number}
Expand Down

0 comments on commit 289f78a

Please sign in to comment.