diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index cf4f3d7b..55c07ec0 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -357,41 +357,46 @@ function TensorMap(data::DenseArray, codom::ProductSpace{S,N₁}, dom::ProductSp return TensorMap{S,N₁,N₂,Trivial,A,Nothing,Nothing}(data2, codom, dom) end - t = TensorMap(zeros, eltype(data), codom, dom) + t = TensorMap(undef, eltype(data), codom, dom) + project_symmetric!(t, data) + + if !isapprox(data, convert(Array, t); atol=tol) + throw(ArgumentError("Data has non-zero elements at incompatible positions")) + end + + return t +end + +""" + project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap + +Project the data from a dense array `data` into the tensor map `t`. This function discards +any data that does not fit the symmetry structure of `t`. +""" +function project_symmetric!(t::TensorMap, data::DenseArray) + if sectortype(t) === Trivial + copy!(t.data, reshape(data, size(t.data))) + return t + end + for (f₁, f₂) in fusiontrees(t) c = f₁.coupled d = inv(dim(c)) - - # fusiontree part - F₁ = convert(Array, f₁) - F₂ = convert(Array, f₂) - sz1 = size(F₁) - sz2 = size(F₂) - d1 = TupleTools.front(sz1) - d2 = TupleTools.front(sz2) - F = reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) * - reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)) - - # data part + F = convert(Array, (f₁, f₂)) b = zeros(eltype(data), dims(codomain(t), f₁.uncoupled)..., dims(domain(t), f₂.uncoupled)...) szbF = _interleave(size(b), size(F)) dataslice = sreshape(StridedView(data)[axes(codom, f₁.uncoupled)..., - axes(dom, f₂.uncoupled)...], - szbF) + axes(dom, f₂.uncoupled)...], szbF) # project (can this be done in one go?) for k in eachindex(b) b[k] = 1 - projector = _kron(b, F) + projector = _kron(b, F) # probably possible to re-use memory t[f₁, f₂][k] = dot(projector, dataslice) * d b[k] = 0 end end - if !isapprox(data, convert(Array, t); atol=tol) - throw(ArgumentError("Data has non-zero elements at incompatible positions")) - end - return t end