Skip to content

Commit

Permalink
Add support for diag/diagm
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jun 11, 2024
1 parent e57e6f5 commit 1640da3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ function isometry(::Type{A},
return t
end

# Diagonal tensors
# ----------------
# TODO: consider adding a specialised DiagonalTensorMap type
function LinearAlgebra.diag(t::AbstractTensorMap)
if sectortype(t) === Trivial
return LinearAlgebra.diag(block(t, Trivial()))
else
return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t))
end
end

function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::AbstractVector)
sp = codom dom
@assert sectortype(sp) === Trivial
return TensorMap(LinearAlgebra.diagm(dim(codom), dim(dom), v), sp)
end
function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict)
return TensorMap(SectorDict(c => LinearAlgebra.diagm(blockdim(codom, c),
blockdim(dom, c), b)
for (c, b) in v), codom dom)
end

LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t)))

# In-place methods
#------------------
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
Expand Down
8 changes: 8 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ for V in spacelist
@test Base.promote_typeof(t, tc) == typeof(tc)
@test Base.promote_typeof(tc, t) == typeof(tc + t)
end
@timedtestset "diag/diagm" begin
W = V1 V2 V3 V4 V5
t = TensorMap(randn, ComplexF64, W)
d = LinearAlgebra.diag(t)
D = LinearAlgebra.diagm(codomain(t), domain(t), d)
@test LinearAlgebra.isdiag(D)
@test LinearAlgebra.diag(D) == d
end
@timedtestset "Permutations: test via inner product invariance" begin
W = V1 V2 V3 V4 V5
t = Tensor(rand, ComplexF64, W)
Expand Down

0 comments on commit 1640da3

Please sign in to comment.