From 1e6ec2b0f3d0d19ed68687ba3831ee8438dd78ef Mon Sep 17 00:00:00 2001 From: Joseph Tindall Date: Thu, 26 Oct 2023 16:11:01 -0400 Subject: [PATCH] Better specification of update sequence for BP --- examples/belief_propagation/bpsequences.jl | 79 +++++++++--- src/beliefpropagation/beliefpropagation.jl | 114 ++++++++++++++---- .../sqrt_beliefpropagation.jl | 94 +++++++++------ src/gauging.jl | 14 ++- src/utils.jl | 6 +- test/test_belief_propagation.jl | 26 ++-- 6 files changed, 227 insertions(+), 106 deletions(-) diff --git a/examples/belief_propagation/bpsequences.jl b/examples/belief_propagation/bpsequences.jl index 6a9797e0..5691f6ae 100644 --- a/examples/belief_propagation/bpsequences.jl +++ b/examples/belief_propagation/bpsequences.jl @@ -14,8 +14,7 @@ using ITensorNetworks: nested_graph_leaf_vertices function main() - - g = named_comb_tree((6,6)) + g = named_comb_tree((6, 6)) s = siteinds("S=1/2", g) χ = 4 @@ -29,17 +28,33 @@ function main() ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))) ) - println("First testing out a comb tree. Random network with bond dim $χ") + println("\nFirst testing out a comb tree. Random network with bond dim $χ") #Now test out various sequences print("Parallel updates (sequence is irrelevant): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[[e] for e in edges(mts_init)], + ) print("Sequential updates (sequence is default edge list of the message tensors): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init)) + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[e for e in edges(mts_init)], + ) print("Sequential updates (sequence is our custom sequence finder): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential") + belief_propagation( + ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ) - g = named_grid((6,6)) + g = named_grid((6, 6)) s = siteinds("S=1/2", g) χ = 2 @@ -53,17 +68,33 @@ function main() ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))) ) - println("Now testing out a 6x6 grid. Random network with bond dim $χ") + println("\nNow testing out a 6x6 grid. Random network with bond dim $χ") #Now test out various sequences print("Parallel updates (sequence is irrelevant): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[[e] for e in edges(mts_init)], + ) print("Sequential updates (sequence is default edge list of the message tensors): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init)) + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[e for e in edges(mts_init)], + ) print("Sequential updates (sequence is our custom sequence finder): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential") + belief_propagation( + ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ) - g = NamedGraphs.hexagonal_lattice_graph(4,4) + g = NamedGraphs.hexagonal_lattice_graph(4, 4) s = siteinds("S=1/2", g) χ = 3 @@ -77,15 +108,31 @@ function main() ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))) ) - println("Now testing out a 4 x 4 hexagonal lattice. Random network with bond dim $χ") + println("\nNow testing out a 4 x 4 hexagonal lattice. Random network with bond dim $χ") #Now test out various sequences print("Parallel updates (sequence is irrelevant): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[[e] for e in edges(mts_init)], + ) print("Sequential updates (sequence is default edge list of the message tensors): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init)) + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[e for e in edges(mts_init)], + ) print("Sequential updates (sequence is our custom sequence finder): ") - belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential") + return belief_propagation( + ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ) end main() diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index bb7f376f..b4558993 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -75,31 +75,24 @@ function update_message_tensor( end """ -Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs +Do a sequential update of message tensors on `edges` for a given ITensornetwork and its partition into sub graphs """ function belief_propagation_iteration( tn::ITensorNetwork, - mts::DataGraph; + mts::DataGraph, + edges::Vector{E}; contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), compute_norm=false, - update_sequence::String="sequential", - edges = edge_update_order(undirected_graph(underlying_graph(mts))), -) +) where {E<:NamedEdge} new_mts = copy(mts) - if update_sequence != "parallel" && update_sequence != "sequential" - error( - "Specified update order is not currently implemented. Choose parallel or sequential." - ) - end - incoming_mts = update_sequence == "parallel" ? mts : new_mts c = 0 for e in edges environment_tensornetworks = ITensorNetwork[ - incoming_mts[e_in] for - e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)]) + new_mts[e_in] for + e_in in setdiff(boundary_edges(new_mts, [src(e)]; dir=:in), [reverse(e)]) ] new_mts[src(e) => dst(e)] = update_message_tensor( - tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs + tn, new_mts[src(e)], environment_tensornetworks; contract_kwargs ) if compute_norm @@ -113,24 +106,96 @@ function belief_propagation_iteration( return new_mts, c / (length(edges)) end +""" +Do parallel updates between groups of edges of all message tensors for a given ITensornetwork and its partition into sub graphs +""" +function belief_propagation_iteration( + tn::ITensorNetwork, + mts::DataGraph, + edge_groups::Vector{Vector{E}}; + contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), + compute_norm=false, +) where {E<:NamedEdge} + new_mts = copy(mts) + c = 0 + for edges in edge_groups + updated_mts, ct = belief_propagation_iteration( + tn, mts, edges; contract_kwargs, compute_norm + ) + for e in edges + new_mts[e] = updated_mts[e] + end + c += ct + end + return new_mts, c / (length(edge_groups)) +end + +function belief_propagation_iteration( + tn::ITensorNetwork, + mts::DataGraph; + contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), + compute_norm=false, + edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + undirected_graph(underlying_graph(mts)) + ), +) where {E<:NamedEdge} + return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm) +end + +# """ +# Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs +# """ +# function belief_propagation_iteration( +# tn::ITensorNetwork, +# mts::DataGraph; +# contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), +# compute_norm=false, +# update_sequence::String="sequential", +# edges::Vector{Vector{}} = edge_update_order(undirected_graph(underlying_graph(mts))), +# ) +# new_mts = copy(mts) +# if update_sequence != "parallel" && update_sequence != "sequential" +# error( +# "Specified update order is not currently implemented. Choose parallel or sequential." +# ) +# end +# incoming_mts = update_sequence == "parallel" ? mts : new_mts +# c = 0 +# for e in edges +# environment_tensornetworks = ITensorNetwork[ +# incoming_mts[e_in] for +# e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)]) +# ] +# new_mts[src(e) => dst(e)] = update_message_tensor( +# tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs +# ) + +# if compute_norm +# LHS, RHS = ITensors.contract(ITensor(mts[src(e) => dst(e)])), +# ITensors.contract(ITensor(new_mts[src(e) => dst(e)])) +# LHS /= sum(diag(LHS)) +# RHS /= sum(diag(RHS)) +# c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS)) +# end +# end +# return new_mts, c / (length(edges)) +# end + function belief_propagation( tn::ITensorNetwork, mts::DataGraph; contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), niters=20, target_precision::Union{Float64,Nothing}=nothing, - update_sequence::String="sequential", - edges = edge_update_order(undirected_graph(underlying_graph(mts))) -) + edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + undirected_graph(underlying_graph(mts)) + ), +) where {E<:NamedEdge} compute_norm = target_precision == nothing ? false : true for i in 1:niters - mts, c = belief_propagation_iteration( - tn, mts; contract_kwargs, compute_norm, update_sequence, edges - ) + mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm) if compute_norm && c <= target_precision - println( - "BP converged to desired precision after $i iterations.", - ) + println("BP converged to desired precision after $i iterations.") break end end @@ -144,11 +209,10 @@ function belief_propagation( npartitions=nothing, subgraph_vertices=nothing, niters=20, - update_sequence::String="sequential", target_precision::Union{Float64,Nothing}=nothing, ) mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices) - return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, update_sequence) + return belief_propagation(tn, mts; contract_kwargs, niters, target_precision) end """ diff --git a/src/beliefpropagation/sqrt_beliefpropagation.jl b/src/beliefpropagation/sqrt_beliefpropagation.jl index c9ba053e..c4eff09e 100644 --- a/src/beliefpropagation/sqrt_beliefpropagation.jl +++ b/src/beliefpropagation/sqrt_beliefpropagation.jl @@ -1,53 +1,19 @@ # using ITensors: scalartype # using ITensorNetworks: find_subgraph, map_diag, sqrt_diag, boundary_edges -function sqrt_belief_propagation( - tn::ITensorNetwork, - mts::DataGraph; - niters=20, - update_sequence::String="sequential", - # target_precision::Union{Float64,Nothing}=nothing, -) - # compute_norm = target_precision == nothing ? false : true - sqrt_mts = sqrt_message_tensors(tn, mts) - for i in 1:niters - sqrt_mts, c = sqrt_belief_propagation_iteration(tn, sqrt_mts; update_sequence) #; compute_norm) - # if compute_norm && c <= target_precision - # println( - # "Belief Propagation finished. Reached a canonicalness of " * - # string(c) * - # " after $i iterations. ", - # ) - # break - # end - end - return sqr_message_tensors(sqrt_mts) -end - function sqrt_belief_propagation_iteration( - tn::ITensorNetwork, - sqrt_mts::DataGraph; - update_sequence::String="sequential", - edges=edge_update_order(undirected_graph(underlying_graph(mts))), - - # compute_norm=false, -) + tn::ITensorNetwork, sqrt_mts::DataGraph, edges::Vector{E} +) where {E<:NamedEdge} new_sqrt_mts = copy(sqrt_mts) - if update_sequence != "parallel" && update_sequence != "sequential" - error( - "Specified update order is not currently implemented. Choose parallel or sequential." - ) - end - incoming_sqrt_mts = update_sequence == "parallel" ? sqrt_mts : new_sqrt_mts c = 0.0 for e in edges environment_tensornetworks = ITensorNetwork[ - incoming_sqrt_mts[e_in] for - e_in in setdiff(boundary_edges(incoming_sqrt_mts, [src(e)]; dir=:in), [reverse(e)]) + new_sqrt_mts[e_in] for + e_in in setdiff(boundary_edges(new_sqrt_mts, [src(e)]; dir=:in), [reverse(e)]) ] new_sqrt_mts[src(e) => dst(e)] = update_sqrt_message_tensor( - tn, incoming_sqrt_mts[src(e)], environment_tensornetworks; + tn, new_sqrt_mts[src(e)], environment_tensornetworks; ) # if compute_norm @@ -61,6 +27,56 @@ function sqrt_belief_propagation_iteration( return new_sqrt_mts, c / (length(edges)) end +function sqrt_belief_propagation_iteration( + tn::ITensorNetwork, sqrt_mts::DataGraph, edges::Vector{Vector{E}} +) where {E<:NamedEdge} + new_sqrt_mts = copy(sqrt_mts) + c = 0.0 + for e_group in edges + updated_sqrt_mts, ct = sqrt_belief_propagation_iteration(tn, sqr_mts, e_group) + for e in e_group + new_sqrt_mts[e] = updated_sqrt_mts[e] + end + c += ct + end + return new_sqrt_mts, c / (length(edges)) +end + +function sqrt_belief_propagation_iteration( + tn::ITensorNetwork, + sqrt_mts::DataGraph; + edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + undirected_graph(underlying_graph(mts)) + ), +) where {E<:NamedEdge} + return sqrt_belief_propagation_iteration(tn, sqrt_mts, edges) +end + +function sqrt_belief_propagation( + tn::ITensorNetwork, + mts::DataGraph; + niters=20, + edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + undirected_graph(underlying_graph(mts)) + ), + # target_precision::Union{Float64,Nothing}=nothing, +) where {E<:NamedEdge} + # compute_norm = target_precision == nothing ? false : true + sqrt_mts = sqrt_message_tensors(tn, mts) + for i in 1:niters + sqrt_mts, c = sqrt_belief_propagation_iteration(tn, sqrt_mts, edges) #; compute_norm) + # if compute_norm && c <= target_precision + # println( + # "Belief Propagation finished. Reached a canonicalness of " * + # string(c) * + # " after $i iterations. ", + # ) + # break + # end + end + return sqr_message_tensors(sqrt_mts) +end + function update_sqrt_message_tensor( tn::ITensorNetwork, subgraph_vertices::Vector, sqrt_mts::Vector{ITensorNetwork}; ) diff --git a/src/gauging.jl b/src/gauging.jl index 79aca813..4086198c 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -17,7 +17,7 @@ function vidal_gauge( bond_tensors::DataGraph; eigen_message_tensor_cutoff=10 * eps(real(scalartype(ψ))), regularization=10 * eps(real(scalartype(ψ))), - edges = NamedGraphs.edges(ψ), + edges=NamedGraphs.edges(ψ), svd_kwargs..., ) ψ_vidal = copy(ψ) @@ -80,11 +80,12 @@ function vidal_gauge( mts::DataGraph; eigen_message_tensor_cutoff=10 * eps(real(scalartype(ψ))), regularization=10 * eps(real(scalartype(ψ))), + edges=NamedGraphs.edges(ψ), svd_kwargs..., ) bond_tensors = initialize_bond_tensors(ψ) return vidal_gauge( - ψ, mts, bond_tensors; eigen_message_tensor_cutoff, regularization, svd_kwargs... + ψ, mts, bond_tensors; eigen_message_tensor_cutoff, regularization, edges, svd_kwargs... ) end @@ -95,7 +96,6 @@ function vidal_gauge( regularization=10 * eps(real(scalartype(ψ))), niters=30, target_canonicalness::Union{Nothing,Float64}=nothing, - update_sequence = "sequential", svd_kwargs..., ) ψψ = norm_network(ψ) @@ -103,7 +103,7 @@ function vidal_gauge( mts = message_tensors(Z) mts = belief_propagation( - ψψ, mts; contract_kwargs=(; alg="exact"), niters, target_precision=target_canonicalness, update_sequence + ψψ, mts; contract_kwargs=(; alg="exact"), niters, target_precision=target_canonicalness ) return vidal_gauge( ψ, mts; eigen_message_tensor_cutoff, regularization, niters, svd_kwargs... @@ -175,7 +175,11 @@ function symmetric_to_vidal_gauge( end """Function to measure the 'isometries' of a state in the Vidal Gauge""" -function vidal_itn_isometries(ψ::ITensorNetwork, bond_tensors::DataGraph; edges = vcat(NamedGraphs.edges(ψ), reverse.(NamedGraphs.edges(ψ)))) +function vidal_itn_isometries( + ψ::ITensorNetwork, + bond_tensors::DataGraph; + edges=vcat(NamedGraphs.edges(ψ), reverse.(NamedGraphs.edges(ψ))), +) isometries = DataGraph{vertextype(ψ),ITensor,ITensor}(directed_graph(underlying_graph(ψ))) for e in edges diff --git a/src/utils.jl b/src/utils.jl index 1d821dc8..e3e2bc82 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -40,8 +40,10 @@ function edge_update_order(g) end #Find an optimal ordering of the edges in a tree -function tree_edge_update_order(g::AbstractNamedGraph; root_vertex = first(keys(sort(eccentricities(g); rev=true)))) +function tree_edge_update_order( + g::AbstractNamedGraph; root_vertex=first(keys(sort(eccentricities(g); rev=true))) +) @assert is_tree(g) es = post_order_dfs_edges(g, root_vertex) return vcat(es, reverse(reverse.(es))) -end \ No newline at end of file +end diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index dff03220..82a2db66 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -40,13 +40,7 @@ ITensors.disable_warn_order() Z = partition(ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))) mts = message_tensors(Z) - mts = belief_propagation( - ψψ, - mts; - contract_kwargs=(; alg="exact"), - niters = 1, - update_sequence="sequential", - ) + mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"), niters=1) numerator_network = approx_network_region( ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork(ITensor[apply(op("Sz", s[v]), ψ[v])]) @@ -57,7 +51,7 @@ ITensors.disable_warn_order() @test abs.(bp_sz - exact_sz) <= 1e-14 #Now test on a tree, should also be exact - g = named_comb_tree((6,6)) + g = named_comb_tree((6, 6)) s = siteinds("S=1/2", g) χ = 2 Random.seed!(1564) @@ -73,13 +67,7 @@ ITensors.disable_warn_order() Z = partition(ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))) mts = message_tensors(Z) - mts = belief_propagation( - ψψ, - mts; - contract_kwargs=(; alg="exact"), - niters = 1, - update_sequence="sequential", - ) + mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"), niters=1) numerator_network = approx_network_region( ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork(ITensor[apply(op("Sz", s[v]), ψ[v])]) @@ -112,7 +100,8 @@ ITensors.disable_warn_order() ) denominator_network = approx_network_region(ψψ, mts, vs) - bp_szsz = ITensors.contract(numerator_network)[] / ITensors.contract(denominator_network)[] + bp_szsz = + ITensors.contract(numerator_network)[] / ITensors.contract(denominator_network)[] @test abs.(bp_szsz - actual_szsz) <= 0.05 @@ -176,8 +165,7 @@ ITensors.disable_warn_order() contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, cutoff=1e-16, maxdim ), - niters = 1, - update_sequence="sequential", + niters=1, ) numerator_network = approx_network_region(ψψ, mts, [v]; verts_tn=ITensorNetwork(ψOψ[v])) @@ -189,4 +177,4 @@ ITensors.disable_warn_order() contract_boundary_mps(ψOψ; cutoff=1e-16) / contract_boundary_mps(ψψ; cutoff=1e-16) @test abs.(bp_sz - exact_sz) <= 1e-5 -end \ No newline at end of file +end