Skip to content

Commit

Permalink
Refactor project_symmetric! into separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jul 2, 2024
1 parent 5003daa commit 7cc0d32
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,41 +357,46 @@ function TensorMap(data::DenseArray, codom::ProductSpace{S,N₁}, dom::ProductSp
return TensorMap{S,N₁,N₂,Trivial,A,Nothing,Nothing}(data2, codom, dom)
end

t = TensorMap(zeros, eltype(data), codom, dom)
t = TensorMap(undef, eltype(data), codom, dom)
project_symmetric!(t, data)

if !isapprox(data, convert(Array, t); atol=tol)
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
end

return t
end

"""
project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap
Project the data from a dense array `data` into the tensor map `t`. This function discards
any data that does not fit the symmetry structure of `t`.
"""
function project_symmetric!(t::TensorMap, data::DenseArray)
if sectortype(t) === Trivial
copy!(t.data, reshape(data, size(t.data)))
return t
end

for (f₁, f₂) in fusiontrees(t)
c = f₁.coupled
d = inv(dim(c))

# fusiontree part
F₁ = convert(Array, f₁)
F₂ = convert(Array, f₂)
sz1 = size(F₁)
sz2 = size(F₂)
d1 = TupleTools.front(sz1)
d2 = TupleTools.front(sz2)
F = reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) *
reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...))

# data part
F = convert(Array, (f₁, f₂))
b = zeros(eltype(data), dims(codomain(t), f₁.uncoupled)...,
dims(domain(t), f₂.uncoupled)...)
szbF = _interleave(size(b), size(F))
dataslice = sreshape(StridedView(data)[axes(codom, f₁.uncoupled)...,
axes(dom, f₂.uncoupled)...],
szbF)
axes(dom, f₂.uncoupled)...], szbF)
# project (can this be done in one go?)
for k in eachindex(b)
b[k] = 1
projector = _kron(b, F)
projector = _kron(b, F) # probably possible to re-use memory
t[f₁, f₂][k] = dot(projector, dataslice) * d
b[k] = 0
end
end

if !isapprox(data, convert(Array, t); atol=tol)
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
end

return t
end

Expand Down

0 comments on commit 7cc0d32

Please sign in to comment.