Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize #192

Open
wants to merge 59 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
ec7ec3b
New BP alternating update
JoeyT1994 May 6, 2024
bd05519
Working BP DMRG Solver
JoeyT1994 May 9, 2024
e116388
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 9, 2024
cd2b139
New Changes
JoeyT1994 May 14, 2024
6391bfa
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 14, 2024
fa91e7c
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 15, 2024
201882a
Small changes
JoeyT1994 May 16, 2024
7228fb5
Changes
JoeyT1994 May 31, 2024
75d0c3b
Utils additions
JoeyT1994 May 31, 2024
c90139b
More stuff
JoeyT1994 Jun 2, 2024
e87e1b3
Big Improvements
JoeyT1994 Jun 7, 2024
8d780a8
Refactor code
JoeyT1994 Jun 7, 2024
e62ae0f
Save stuff
JoeyT1994 Jun 11, 2024
371492d
Commit 1
JoeyT1994 Jun 12, 2024
5138e51
Changes
JoeyT1994 Jun 12, 2024
275191a
Changes
JoeyT1994 Jun 12, 2024
194fba3
working implementation
JoeyT1994 Jun 12, 2024
50369c1
working implementation
JoeyT1994 Jun 12, 2024
0e5e5d8
Remove old changes
JoeyT1994 Jun 12, 2024
4bc0183
Revert
JoeyT1994 Jun 12, 2024
9e14f14
Revert
JoeyT1994 Jun 12, 2024
0a7355e
Revert
JoeyT1994 Jun 12, 2024
b07b978
Revert
JoeyT1994 Jun 12, 2024
440c267
Revert
JoeyT1994 Jun 12, 2024
ed7befa
Remove files
JoeyT1994 Jun 12, 2024
322dca4
Revert
JoeyT1994 Jun 12, 2024
54f41c0
Revert
JoeyT1994 Jun 12, 2024
dc0e132
Revert
JoeyT1994 Jun 12, 2024
2af3984
revert
JoeyT1994 Jun 12, 2024
30786bc
Working version
JoeyT1994 Jun 14, 2024
f0d4fc8
Merge branch 'ITensor:main' into bp_dmrg_alt_method
JoeyT1994 Jun 14, 2024
ed5037e
Improvements
JoeyT1994 Jun 14, 2024
e61e58c
Merge remote-tracking branch 'upstream/main' into bp_dmrg_alt_method
JoeyT1994 Jun 14, 2024
6998077
merge
JoeyT1994 Jun 14, 2024
511e09f
Merge branch 'bp_dmrg_alt_method' of github.com:JoeyT1994/ITensorNetw…
JoeyT1994 Jun 14, 2024
ed0c069
Improvements
JoeyT1994 Jun 14, 2024
553a983
Simplify
JoeyT1994 Jun 15, 2024
005b0e5
Change
JoeyT1994 Jun 16, 2024
af68e63
Working first commit
JoeyT1994 Jun 16, 2024
0704609
Revert some files
JoeyT1994 Jun 16, 2024
e1344f0
Revert expect
JoeyT1994 Jun 16, 2024
66319b0
Revert some changes
JoeyT1994 Jun 16, 2024
b098d44
Update src/caches/beliefpropagationcache.jl
JoeyT1994 Jun 16, 2024
b296277
Update src/caches/beliefpropagationcache.jl
JoeyT1994 Jun 16, 2024
1c87d22
Update src/normalize.jl
JoeyT1994 Jun 16, 2024
f88b21c
Merge remote-tracking branch 'upstream/main' into normalize!
JoeyT1994 Jun 26, 2024
6a8d4b9
Renormalize messages against themselves first
JoeyT1994 Jun 26, 2024
c845947
Blah
JoeyT1994 Sep 13, 2024
90c7251
Merge remote-tracking branch 'origin/main'
JoeyT1994 Oct 17, 2024
86f3087
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Oct 17, 2024
6ff0cd5
Bug fix in current ortho. Change test
JoeyT1994 Oct 17, 2024
34e8e5e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Nov 22, 2024
d096722
Fix bug
JoeyT1994 Nov 26, 2024
70a3f7e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Dec 5, 2024
2cb7f85
Refactor and bring down upstream changes
JoeyT1994 Dec 10, 2024
73e9e1e
Merge remote-tracking branch 'origin/main' into normalize!
JoeyT1994 Dec 10, 2024
4f4e2e5
Remove erroneous file
JoeyT1994 Dec 10, 2024
620da37
Allow rescaling flat networks with bp
JoeyT1994 Dec 10, 2024
180183e
Make generic to other algorithms
JoeyT1994 Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ include("solvers/linsolve.jl")
include("solvers/sweep_plans/sweep_plans.jl")
include("apply.jl")
include("inner.jl")
include("normalize.jl")
include("expect.jl")
include("environment.jl")
include("exports.jl")
Expand Down
33 changes: 31 additions & 2 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ using SimpleTraits: SimpleTraits, Not, @traitfn
default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e]
default_messages(ptn::PartitionedGraph) = Dictionary()
default_message_norm(m::ITensor) = norm(m)
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
message_norm = norm(updated_messages)
if !iszero(message_norm)
if !iszero(message_norm) && normalize
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
updated_messages /= message_norm
end
return ITensor[updated_messages]
Expand Down Expand Up @@ -157,6 +157,15 @@ function environment(bp_cache::BeliefPropagationCache, verts::Vector)
return vcat(messages, central_tensors)
end

function factors(bp_cache::BeliefPropagationCache, vertices::Vector)
tn = tensornetwork(bp_cache)
return ITensor[tn[vertex] for vertex in vertices]
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

function factor(bp_cache::BeliefPropagationCache, vertex)
return only(factors(bp_cache, [vertex]))
end

function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex)
ptn = partitioned_tensornetwork(bp_cache)
return collect(eachtensor(subgraph(ptn, vertex)))
Expand Down Expand Up @@ -309,3 +318,23 @@ end
function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
end

function normalize_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
bp_cache = copy(bp_cache)
mts = messages(bp_cache)
for pe in pes
n = region_scalar(bp_cache, pe)
me, mer = only(mts[pe]), only(mts[reverse(pe)])
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
end
return bp_cache
end

function normalize_message(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
return normalize_messages(bp_cache, PartitionEdge[pe])
end

function normalize_messages(bp_cache::BeliefPropagationCache)
return normalize_messages(bp_cache, partitionedges(partitioned_tensornetwork(bp_cache)))
end
56 changes: 56 additions & 0 deletions src/normalize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using LinearAlgebra

function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return normalize(Algorithm(alg), tn; kwargs...)
end

function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork)
norm_tn = norm_sqr_network(tn)
log_norm = logscalar(alg, norm_tn)
tn = copy(tn)
L = length(vertices(tn))
c = exp(log_norm / L)
for v in vertices(tn)
tn[v] = tn[v] / sqrt(c)
end
return tn
end

function LinearAlgebra.normalize(
alg::Algorithm"bp",
tn::AbstractITensorNetwork;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems like a basic design question here is if normalizing should refer to treating the tn as a state that should be normalized to 1 or as something such that you want the result of contract(tn) to be 1.

It seems reasonable to define it such that tn is a state where you want contract(norm_network(tn)) to be 1 as you do here, however it may be good to write it in terms of an inner function that takes a tensor network and returns a new one where the tensors are scaled such that contracting it is 1. I can't think of a good name for that right now, but for the time being I'll refer to it as rescale(tn::AbstractITensorNetwork), so scalar(rescale(tn)) == 1 for any input tn, and the input has to be a closed network that evaluates to a scalar. Then we can just define normalize(tn) = ket_network(rescale(norm_network(tn))) or something like that.

The current implementation feels a bit too "in the weeds" dealing with quadratic forms, bras, kets, etc. and seems like something that could be abstracted and generalized.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also defining a function like rescale then would be relevant for other kinds of networks like partition functions, where if you track the normalization factors then that gives the evaluation of the partition function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relatedly, rescale(tn::AbstractITensorNetwork) could be defined in two steps, one where it computes the local scale factors (I think there is already a function for that?) and then a next step where it just divides the factors by those scale factors, so the implementation could be a bit simpler by dividing it into multiple generic steps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I see what you mean, that's a nice idea to split it apart like that. Will change it to do that

Copy link
Contributor Author

@JoeyT1994 JoeyT1994 Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I split it apart based on a rescale function.

if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(QuadraticFormNetwork(tn)))
end

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end

tn = copy(tn)
cache![] = normalize_messages(cache![])
norm_tn = tensornetwork(cache![])

vertices_states = Dictionary()
for v in vertices(tn)
v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v)
pv = only(partitionvertices(cache![], [v_ket]))
vn = region_scalar(cache![], pv)
state = (1.0 / sqrt(vn)) * tn[v]
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
state_dag = copy(dag(state))
state_dag = replaceinds(
state_dag, inds(state_dag), dual_index_map(norm_tn).(inds(state_dag))
)
set!(vertices_states, v_ket, state)
set!(vertices_states, v_bra, state_dag)
tn[v] = state
end

cache![] = update_factors(cache![], vertices_states)

return tn
end
52 changes: 52 additions & 0 deletions test/test_normalize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
@eval module $(gensym())
using ITensorNetworks:
BeliefPropagationCache,
QuadraticFormNetwork,
edge_scalars,
norm_sqr_network,
random_tensornetwork,
vertex_scalars
using ITensors: dag, inner, siteinds, scalar
using Graphs: SimpleGraph, uniform_tree
using LinearAlgebra: normalize
using NamedGraphs: NamedGraph
using NamedGraphs.NamedGraphGenerators: named_grid
using StableRNGs: StableRNG
using Test: @test, @testset
@testset "Normalize" begin

#First lets do a tree
L = 6
χ = 2
rng = StableRNG(1234)

g = NamedGraph(SimpleGraph(uniform_tree(L)))
s = siteinds("S=1/2", g)
x = random_tensornetwork(rng, s; link_space=χ)

ψ = normalize(x; alg="exact")
@test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

ψ = normalize(x; alg="bp")
@test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

#Now a loopy graph
Lx, Ly = 3, 2
χ = 2
rng = StableRNG(1234)

g = named_grid((Lx, Ly))
s = siteinds("S=1/2", g)
x = random_tensornetwork(rng, s; link_space=χ)

ψ = normalize(x; alg="exact")
@test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0

ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
ψIψ_bpc = ψIψ_bpc[]
@test all(x -> x ≈ 1.0, edge_scalars(ψIψ_bpc))
@test all(x -> x ≈ 1.0, vertex_scalars(ψIψ_bpc))
@test scalar(QuadraticFormNetwork(ψ); alg="bp") ≈ 1.0
end
end
Loading