From c64a17bea9bccd339f686df9b7195acfa691a219 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 11 Dec 2024 20:35:02 -0500 Subject: [PATCH] Add tests `removeunit` --- test/spaces.jl | 5 +++++ test/tensors.jl | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/test/spaces.jl b/test/spaces.jl index b37b3b7b..4039c429 100644 --- a/test/spaces.jl +++ b/test/spaces.jl @@ -279,6 +279,7 @@ println("------------------------------------") @test @constinferred(⊗(V1, V2, V3 ⊗ V4)) == P @test @constinferred(⊗(V1, V2 ⊗ V3, V4)) == P @test @constinferred(insertunit(P, 3)) == V1 * V2 * oneunit(V1) * V3 * V4 + @test @constinferred(removeunit(V1 * V2 * oneunit(V1)' * V3 * V4, 3)) == P @test fuse(V1, V2', V3) ≅ V1 ⊗ V2' ⊗ V3 @test fuse(V1, V2', V3) ≾ V1 ⊗ V2' ⊗ V3 @test fuse(V1, V2', V3) ≿ V1 ⊗ V2' ⊗ V3 @@ -339,6 +340,7 @@ println("------------------------------------") @test @constinferred(⊗(V1, V2, V3)) == P @test @constinferred(adjoint(P)) == dual(P) == V3' ⊗ V2' ⊗ V1' @test @constinferred(insertunit(P, 3; conj=true)) == V1 * V2 * oneunit(V1)' * V3 + @test P == @constinferred(removeunit(insertunit(P, 3), 3)) @test fuse(V1, V2', V3) ≅ V1 ⊗ V2' ⊗ V3 @test fuse(V1, V2', V3) ≾ V1 ⊗ V2' ⊗ V3 ≾ fuse(V1 ⊗ V2' ⊗ V3) @test fuse(V1, V2') ⊗ V3 ≾ V1 ⊗ V2' ⊗ V3 @@ -420,11 +422,14 @@ println("------------------------------------") @test permute(W, ((2, 4, 5), (3, 1))) == (V2 ⊗ V4' ⊗ V5' ← V3 ⊗ V1') @test (V1 ⊗ V2 ← V1 ⊗ V2) == @constinferred TensorKit.compose(W, W') @test @constinferred(insertunit(W)) == (V1 ⊗ V2 ← V3 ⊗ V4 ⊗ V5 ⊗ oneunit(V5)) + @test @constinferred(removeunit(insertunit(W), numind(W) + 1)) == W @test @constinferred(insertunit(W; conj=true)) == (V1 ⊗ V2 ← V3 ⊗ V4 ⊗ V5 ⊗ oneunit(V5)') @test @constinferred(insertunit(W, 1)) == (oneunit(V1) ⊗ V1 ⊗ V2 ← V3 ⊗ V4 ⊗ V5) @test @constinferred(insertunit(W, 3)) == (V1 ⊗ V2 ⊗ oneunit(V1) ← V3 ⊗ V4 ⊗ V5) + @test @constinferred(removeunit(insertunit(W, 3))) == W @test @constinferred(insertunit(W, 3; preferdomain=true)) == (V1 ⊗ V2 ← oneunit(V1) ⊗ V3 ⊗ V4 ⊗ V5) + @test @constinferred(removeunit(insertunit(W, 3; preferdomain=true), 3)) == W end end diff --git a/test/tensors.jl b/test/tensors.jl index 18d9aba2..15ea076e 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -182,21 +182,25 @@ for V in spacelist @test space(t2) == insertunit(space(t)) @test scalartype(t2) === T @test t.data === t2.data + @test @constinferred(removeunit(t2, numind(t2))) == t t3 = @constinferred insertunit(t; copy=true) @test t.data !== t3.data for (c, b) in blocks(t) @test b == block(t3, c) end + @test @constinferred(removeunit(t3, numind(t3))) == t t4 = @constinferred insertunit(t, 4; dual=true) @test numin(t4) == numin(t) && numout(t4) == numout(t) + 1 for (c, b) in blocks(t) @test b == block(t4, c) end + @test @constinferred(removeunit(t4, 4)) == t t5 = @constinferred insertunit(t, 4; dual=true) @test numin(t5) == numin(t) + 1 && numout(t5) == numout(t) for (c, b) in blocks(t) @test b == block(t5, c) end + @test @constinferred(removeunit(t5, 4)) == t end end if hasfusiontensor(I)