From 1c027cb34cb3e3778ae1a2b6ea94321cb69a79fa Mon Sep 17 00:00:00 2001 From: Joseph Tindall Date: Tue, 14 May 2024 12:32:10 -0400 Subject: [PATCH 1/4] Account for edge case where network evaluates to 0 --- src/caches/beliefpropagationcache.jl | 5 ++++- src/contract.jl | 6 +++++- test/test_belief_propagation.jl | 11 +++++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 2ce338f3..e583acf7 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -19,7 +19,10 @@ default_message_norm(m::ITensor) = norm(m) function default_message_update(contract_list::Vector{ITensor}; kwargs...) sequence = optimal_contraction_sequence(contract_list) updated_messages = contract(contract_list; sequence, kwargs...) - updated_messages /= norm(updated_messages) + n = norm(updated_messages) + if !iszero(n) + updated_messages /= norm(updated_messages) + end return ITensor[updated_messages] end @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing diff --git a/src/contract.jl b/src/contract.jl index 0fc575a6..e6bbaef7 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -73,7 +73,11 @@ function logscalar( denominator_terms end - return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) + if any(t -> iszero(t), collect(denominator_terms)) + return -Inf + else + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) + end end function ITensors.scalar(alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...) diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index 66cf25d7..ff8aa2e9 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -2,7 +2,6 @@ using Compat: Compat using Graphs: vertices # Trigger package extension. -using GraphsFlows: GraphsFlows using ITensorNetworks: ITensorNetworks, BeliefPropagationCache, @@ -18,6 +17,7 @@ using ITensorNetworks: message, partitioned_tensornetwork, random_tensornetwork, + scalar, siteinds, split_index, tensornetwork, @@ -28,7 +28,7 @@ using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random using ITensorNetworks.ModelNetworks: ModelNetworks using ITensors.NDTensors: array using LinearAlgebra: eigvals, tr -using NamedGraphs: NamedEdge +using NamedGraphs: NamedEdge, NamedGraph using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid using NamedGraphs.PartitionedGraphs: PartitionVertex, partitionedges using Random: Random @@ -75,5 +75,12 @@ using Test: @test, @testset @test all(eig -> imag(eig) ≈ 0, eigs) @test all(eig -> real(eig) >= -eps(eltype(eig)), eigs) + + #Test edge case of network which evalutes to 0 + χ = 2 + g = named_grid((3, 1)) + ψ = random_tensornetwork(ComplexF64, g; link_space=χ) + ψ[(1, 1)] = 0.0 * ψ[(1, 1)] + @test iszero(scalar(ψ; alg="bp")) end end From 15e185fce0ce0acc169f706cf172822ab63d750e Mon Sep 17 00:00:00 2001 From: Joseph Tindall Date: Tue, 14 May 2024 14:39:12 -0400 Subject: [PATCH 2/4] Remove collect() --- src/contract.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contract.jl b/src/contract.jl index e6bbaef7..93f4f77d 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -73,7 +73,7 @@ function logscalar( denominator_terms end - if any(t -> iszero(t), collect(denominator_terms)) + if any(t -> iszero(t), denominator_terms) return -Inf else return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) From 003475a9cdd4a3ec4e8ed0ab7fbae11b7e37ef18 Mon Sep 17 00:00:00 2001 From: Joseph Tindall Date: Tue, 14 May 2024 15:46:34 -0400 Subject: [PATCH 3/4] Simplify logscalar(bp) --- src/contract.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/contract.jl b/src/contract.jl index 93f4f77d..4adb0e10 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -73,11 +73,8 @@ function logscalar( denominator_terms end - if any(t -> iszero(t), denominator_terms) - return -Inf - else - return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) - end + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end function ITensors.scalar(alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...) From 0a96fd316bee069ae635d5ae1d2f4b4481163329 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 14 May 2024 19:06:41 -0400 Subject: [PATCH 4/4] Small optimization and improved readability --- src/caches/beliefpropagationcache.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index e583acf7..b980a52f 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -19,9 +19,9 @@ default_message_norm(m::ITensor) = norm(m) function default_message_update(contract_list::Vector{ITensor}; kwargs...) sequence = optimal_contraction_sequence(contract_list) updated_messages = contract(contract_list; sequence, kwargs...) - n = norm(updated_messages) - if !iszero(n) - updated_messages /= norm(updated_messages) + message_norm = norm(updated_messages) + if !iszero(message_norm) + updated_messages /= message_norm end return ITensor[updated_messages] end