Skip to content

Commit

Permalink
Remove use of span, move some newly introduced functions into ttn_svd.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benedikt Kloss committed Apr 15, 2024
1 parent 283e9bd commit cba04e7
Showing 1 changed file with 37 additions and 56 deletions.
93 changes: 37 additions & 56 deletions src/treetensornetworks/opsum_to_ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,30 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
using ITensors.Ops: Op, OpSum
using NamedGraphs: degrees, is_leaf, vertex_path
using StaticArrays: MVector

using DataGraphs: AbstractDataGraph
using NamedGraphs: boundary_edges
# convert ITensors.OpSum to TreeTensorNetwork

#
# Utility methods
#

# determine 'support' of product operator on tree graph
function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C}
spn = Set{eltype(g)}()
nterms = length(t)
nterms == 1 && return Set([ITensors.site(t[1])])
for i in 1:nterms, j in (i + 1):nterms
path = Set(vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j])))
spn = union!(spn, path)
end
return spn
function align_edges(edges, reference_edges)
return intersect(Iterators.flatten(zip(edges, reverse.(edges))), reference_edges)
end

# determine whether an operator string crosses a given graph vertex
function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C}
return v span(t, g)
function align_and_reorder_edges(edges, reference_edges)
return intersect(reference_edges, align_edges(edges, reference_edges))
end

function align_edges(edges, reference_edges)
return intersect(reference_edges, Iterators.flatten((edges, reverse.(edges))))
end

# return a dict from vertices `w` of `g`, except for `v`, to the incident edge of `v`
# which lies in edge_path(g,w,v)
function vertices_to_incident_edges_dict(g::AbstractGraph, v, incident_edges)
#split graph into subtrees by removing vertex v
_g = copy(underlying_graph(g))
function split_at_vertex(g::AbstractGraph, v)
_g = copy(g)
rem_vertex!(_g, v)
subgraphs = Set.(connected_components(_g))

#for each incident edge, store the vertex that's not `v`
vs = vertextype(g)[]
for e in incident_edges
push!(vs, only(setdiff([src(e), dst(e)], [v])))
end

#return a Dictionary from vertices to incident_edges
_vs = vertextype(g)[]
_es = edgetype(g)[]
for (e, v) in zip(incident_edges, vs)
for i in eachindex(subgraphs)
if v in subgraphs[i]
append!(_vs, subgraphs[i])
append!(_es, fill(e, length(subgraphs[i])))
deleteat!(subgraphs, i)
break
end
end
end
@assert isempty(subgraphs)
return Dict(zip(_vs, _es))
return Set.(connected_components(_g))
end

split_at_vertex(g::AbstractDataGraph, v) = split_at_vertex(underlying_graph(g), v)

#
# Tree adaptations of functionalities in ITensors.jl/src/physics/autompo/opsum_to_mpo.jl
#
Expand Down Expand Up @@ -143,33 +108,49 @@ function ttn_svd(
site_coef_done = Prod{Op}[] # list of terms for which the coefficient has been added to a site factor
# temporary symbolic representation of TTN Hamiltonian
tempTTN = Dict(v => QNArrElem{Scaled{coefficient_type,Prod{Op}},degrees[v]}[] for v in vs)
#ToDo: precompute span of each term and store
# compute span of each term
spans = Dict{eltype(os),Set{vertextype_sites}}()
for term in os
spans[term] = span(term, sites)
end

# build compressed finite state machine representation
for v in vs
# for every vertex, find all edges that contain this vertex
edges = align_edges(incident_edges(sites, v), es)
edges = align_and_reorder_edges(incident_edges(sites, v), es)

# use the corresponding ordering as index order for tensor elements at this site
dim_in = findfirst(e -> dst(e) == v, edges)
edge_in = (isnothing(dim_in) ? [] : edges[dim_in])
dims_out = findall(e -> src(e) == v, edges)
edges_out = edges[dims_out]

which_incident_edge = vertices_to_incident_edges_dict(sites, v, edges)
# for every site w except v, determine the incident edge to v that lies
# in the edge_path(w,v)
subgraphs = split_at_vertex(sites, v)
_boundary_edges = align_edges(
[only(boundary_edges(underlying_graph(sites), subgraph)) for subgraph in subgraphs],
edges,
)
which_incident_edge = Dict(
Iterators.flatten([
subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex(subgraphs)
]),
)

# sanity check, leaves only have single incoming or outgoing edge
@assert !isempty(dims_out) || !isnothing(dim_in)
(isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v)

for term in os
# loop over OpSum and pick out terms that act on current vertex

v in spans[term] || continue
factors = ITensors.terms(term)
if v in ITensors.site.(factors)
crosses_vertex = true
else
crosses_vertex =
!isone(
length(Set([which_incident_edge[site] for site in ITensors.site.(factors)]))
)
end
#if term doesn't cross vertex, skip it
crosses_vertex || continue

# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), factors)
Expand Down Expand Up @@ -290,7 +271,7 @@ function ttn_svd(
for v in vs
# redo the whole thing like before
# ToDo: use neighborhood instead of going through all edges, see above
edges = align_edges(incident_edges(sites, v), es)
edges = align_and_reorder_edges(incident_edges(sites, v), es)
dim_in = findfirst(e -> dst(e) == v, edges)
dims_out = findall(e -> src(e) == v, edges)
# slice isometries at this vertex
Expand Down

0 comments on commit cba04e7

Please sign in to comment.