diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 5f6d62a1..c1c4e6af 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -124,7 +124,13 @@ function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁,N₂}) where {S, return cod ← dom end -function Base.:*(W::HomSpace{S}, V::HomSpace{S}) where {S} +""" + compose(W::HomSpace, V::HomSpace) + +Obtain the HomSpace that is obtained from composing the morphisms in `W` and `V`. For this +to be possible, the domain of `W` must match the codomain of `V`. +""" +function compose(W::HomSpace{S}, V::HomSpace{S}) where {S} domain(W) == codomain(V) || throw(SpaceMismatch("$(domain(W)) ≠ $(codomain(V))")) return HomSpace(codomain(W), domain(V)) end diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 7c75a932..e00b1ee7 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -20,7 +20,7 @@ LinearAlgebra.normalize(t::AbstractTensorMap, p::Real=2) = scale(t, inv(norm(t, function Base.:*(t1::AbstractTensorMap, t2::AbstractTensorMap) return mul!(similar(t1, promote_type(scalartype(t1), scalartype(t2)), - space(t1) * space(t2)), t1, t2) + compose(space(t1), space(t2))), t1, t2) end Base.exp(t::AbstractTensorMap) = exp!(copy(t)) function Base.:^(t::AbstractTensorMap, p::Integer) @@ -242,7 +242,7 @@ end function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap, α=true, β=false) - space(tA) * space(tB) == space(tC) || + compose(space(tA), space(tB)) == space(tC) || throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) for c in blocksectors(tC) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index ed084a4d..a2b6f171 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -128,7 +128,7 @@ function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂}, conjB) where {S,N₁,N₂} sA = TO.tensoradd_structure(pA, A, conjA) sB = TO.tensoradd_structure(pB, B, conjB) - return permute(sA * sB, pC) + return permute(compose(sA, sB), pC) end function TO.checkcontractible(tA::AbstractTensorMap{S}, iA::Int, conjA::Symbol, diff --git a/test/spaces.jl b/test/spaces.jl index 48b9e51d..42f4b922 100644 --- a/test/spaces.jl +++ b/test/spaces.jl @@ -412,6 +412,6 @@ println("------------------------------------") @test W == deepcopy(W) @test W == @constinferred permute(W, ((1, 2), (3, 4, 5))) @test permute(W, ((2, 4, 5), (3, 1))) == (V2 ⊗ V4' ⊗ V5' ← V3 ⊗ V1') - @test (V1 ⊗ V2 ← V1 ⊗ V2) == @constinferred W * W' + @test (V1 ⊗ V2 ← V1 ⊗ V2) == @constinferred compose(W, W') end end