diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 4e168b86..f6ebe84f 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -207,6 +207,20 @@ function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I,N}) where {I,N} Ctail, ((1,), Base.tail(trivialtuple)), :N, true, false) end +# TODO: is this piracy? +function Base.convert(A::Type{<:AbstractArray}, + (f₁, f₂)::Tuple{FusionTree{I},FusionTree{I}}) where {I} + F₁ = convert(A, f₁) + F₂ = convert(A, f₂) + sz1 = size(F₁) + sz2 = size(F₂) + d1 = TupleTools.front(sz1) + d2 = TupleTools.front(sz2) + + return reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) * + reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)) +end + # Show methods function Base.show(io::IO, t::FusionTree{I,N,M,K,Nothing}) where {I<:Sector,N,M,K} return print(IOContext(io, :typeinfo => I), "FusionTree{", type_repr(I), "}(", diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 55195e4f..38997c94 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -267,14 +267,7 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap{S,N₁,N₂}) where {S dom = domain(t) local A for (f₁, f₂) in fusiontrees(t) - F₁ = convert(Array, f₁) - F₂ = convert(Array, f₂) - sz1 = size(F₁) - sz2 = size(F₂) - d1 = TupleTools.front(sz1) - d2 = TupleTools.front(sz2) - F = reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) * - reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)) + F = convert(Array, (f₁, f₂)) if !(@isdefined A) if eltype(F) <: Complex T = complex(float(scalartype(t)))