From 1640da313500e69f786d10c4bdb3321307b9e3e9 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 10 Jun 2024 15:56:54 +0200 Subject: [PATCH] Add support for diag/diagm --- src/tensors/linalg.jl | 24 ++++++++++++++++++++++++ test/tensors.jl | 8 ++++++++ 2 files changed, 32 insertions(+) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 46bbb285..81adfa3a 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -153,6 +153,30 @@ function isometry(::Type{A}, return t end +# Diagonal tensors +# ---------------- +# TODO: consider adding a specialised DiagonalTensorMap type +function LinearAlgebra.diag(t::AbstractTensorMap) + if sectortype(t) === Trivial + return LinearAlgebra.diag(block(t, Trivial())) + else + return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t)) + end +end + +function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::AbstractVector) + sp = codom ← dom + @assert sectortype(sp) === Trivial + return TensorMap(LinearAlgebra.diagm(dim(codom), dim(dom), v), sp) +end +function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict) + return TensorMap(SectorDict(c => LinearAlgebra.diagm(blockdim(codom, c), + blockdim(dom, c), b) + for (c, b) in v), codom ← dom) +end + +LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t))) + # In-place methods #------------------ # Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1 diff --git a/test/tensors.jl b/test/tensors.jl index 10559b85..dcfa9c59 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -188,6 +188,14 @@ for V in spacelist @test Base.promote_typeof(t, tc) == typeof(tc) @test Base.promote_typeof(tc, t) == typeof(tc + t) end + @timedtestset "diag/diagm" begin + W = V1 ⊗ V2 ⊗ V3 ← V4 ⊗ V5 + t = TensorMap(randn, ComplexF64, W) + d = LinearAlgebra.diag(t) + D = LinearAlgebra.diagm(codomain(t), domain(t), d) + @test LinearAlgebra.isdiag(D) + @test LinearAlgebra.diag(D) == d + end @timedtestset "Permutations: test via inner product invariance" begin W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 t = Tensor(rand, ComplexF64, W)