Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor expect (single site) #162

Merged
merged 15 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021 Matthew Fishman <[email protected]> and contributors
Copyright (c) 2021 Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]> and contributors"]
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.10.0"

[deps]
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DocMeta.setdocmeta!(

makedocs(;
modules=[ITensorNetworks],
authors="Matthew Fishman <[email protected]> and contributors",
authors="Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors",
repo="https://github.com/mtfishman/ITensorNetworks.jl/blob/{commit}{path}#{line}",
sitename="ITensorNetworks.jl",
format=Documenter.HTML(;
Expand Down
98 changes: 54 additions & 44 deletions src/expect.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,67 @@
using ITensors.ITensorMPS: ITensorMPS, expect, promote_itensor_eltype, OpSum
using Dictionaries: Dictionary, set!
using ITensors: Op, op, contract, siteinds, which_op
using ITensors.ITensorMPS: ITensorMPS, expect

default_expect_alg() = "bp"

function ITensorMPS.expect(
op::String,
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=false,
sequence=nothing,
vertices=vertices(ψ),
ψ::AbstractITensorNetwork, args...; alg=default_expect_alg(), kwargs...
)
s = siteinds(ψ)
ElT = promote_itensor_eltype(ψ)
# ElT = ishermitian(ITensors.op(op, s[vertices[1]])) ? real(ElT) : ElT
res = Dictionary(vertices, Vector{ElT}(undef, length(vertices)))
if isnothing(sequence)
sequence = contraction_sequence(inner_network(ψ, ψ))
end
normψ² = norm_sqr(ψ; alg="exact", sequence)
for v in vertices
O = ITensor(Op(op, v), s)
Oψ = apply(O, ψ; cutoff, maxdim, ortho)
res[v] = inner(ψ, Oψ; alg="exact", sequence) / normψ²
end
return res
return expect(Algorithm(alg), ψ, args...; kwargs...)
end

function expect_internal(ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=(;), kwargs...)
v = only(op.sites)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
ψIψ_v = ψIψ[operator_vertex(ψIψ, v)]
s = commonind(ψIψ[ket_vertex(ψIψ, v)], ψIψ_v)
operator = ITensors.op(op.which_op, s)
∂ψIψ_∂v = environment(ψIψ, [v]; vertex_mapping_function=operator_vertices, kwargs...)
numerator = contract(vcat(∂ψIψ_∂v, operator); contract_kwargs...)[]
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
denominator = contract(vcat(∂ψIψ_∂v, ψIψ_v); contract_kwargs...)[]

return numerator / denominator
end

function ITensorMPS.expect(
ℋ::OpSum,
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=false,
sequence=nothing,
alg::Algorithm,
ψ::AbstractITensorNetwork,
ops;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
cache_construction_function=tn ->
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
kwargs...,
)
s = siteinds(ψ)
# h⃗ = Vector{ITensor}(ℋ, s)
if isnothing(sequence)
sequence = contraction_sequence(inner_network(ψ, ψ))
ψIψ = inner_network(ψ, ψ)
if isnothing(cache!)
cache! = Ref(cache_construction_function(ψIψ))
end

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end
h⃗ψ = [apply(hᵢ, ψ; cutoff, maxdim, ortho) for hᵢ in ITensors.terms(ℋ)]
ψhᵢψ = [inner(ψ, hᵢψ; alg="exact", sequence) for hᵢψ in h⃗ψ]
ψh⃗ψ = sum(ψhᵢψ)
ψψ = norm_sqr(ψ; alg="exact", sequence)
return ψh⃗ψ / ψψ

return map(
op -> expect_internal(ψIψ, op; alg, cache!, update_cache=false, kwargs...), ops
)
end

function ITensorMPS.expect(alg::Algorithm"exact", ψ::AbstractITensorNetwork, ops; kwargs...)
ψIψ = inner_network(ψ, ψ)
return map(op -> expect_internal(ψIψ, op; alg, kwargs...), ops)
end

function ITensorMPS.expect(alg::Algorithm, ψ::AbstractITensorNetwork, op::Op; kwargs...)
return expect(alg, ψ, [op]; kwargs...)
end

function ITensorMPS.expect(
opsum_sum::Sum{<:OpSum},
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=true,
sequence=nothing,
alg::Algorithm, ψ::AbstractITensorNetwork, op::String, vertices; kwargs...
)
return expect(sum(Ops.terms(opsum_sum)), ψ; cutoff, maxdim, ortho, sequence)
return expect(alg, ψ, [Op(op, vertex) for vertex in vertices]; kwargs...)
end

function ITensorMPS.expect(alg::Algorithm, ψ::AbstractITensorNetwork, op::String; kwargs...)
return expect(alg, ψ, op, vertices(ψ); kwargs...)
end
10 changes: 8 additions & 2 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
function operator_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
end

function bra_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == bra_vertex_suffix(f), vertices(f))
end
Expand All @@ -31,6 +32,10 @@ function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function operator_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [operator_vertex_map(f)(osv) for osv in original_state_vertices]
end

function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [bra_vertex_map(f)(osv) for osv in original_state_vertices]
end
Expand Down Expand Up @@ -69,11 +74,12 @@ end

function environment(
f::AbstractFormNetwork,
original_state_vertices::Vector;
original_vertices::Vector;
vertex_mapping_function=state_vertices,
alg=default_environment_algorithm(),
kwargs...,
)
form_vertices = state_vertices(f, original_state_vertices)
form_vertices = vertex_mapping_function(f, original_vertices)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
if alg == "bp"
partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f))
return environment(
Expand Down
129 changes: 16 additions & 113 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,125 +6,62 @@ using GraphsFlows: GraphsFlows
using ITensorNetworks:
ITensorNetworks,
BeliefPropagationCache,
IndsNetwork,
ITensorNetwork,
⊗,
apply,
combine_linkinds,
contract,
contract_boundary_mps,
contraction_sequence,
eachtensor,
environment,
flatten_networks,
inner_network,
linkinds_combiners,
message,
partitioned_tensornetwork,
random_tensornetwork,
siteinds,
split_index,
tensornetwork,
update,
update_factor
update_factor,
update_message
using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, randomITensor
using ITensorNetworks.ModelNetworks: ModelNetworks
using ITensors.NDTensors: array
using LinearAlgebra: eigvals, tr
using NamedGraphs: NamedEdge
using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid
using NamedGraphs.PartitionedGraphs: PartitionVertex
using NamedGraphs.PartitionedGraphs: PartitionVertex, partitionedges
using Random: Random
using SplitApplyCombine: group
using Test: @test, @testset

@testset "belief_propagation" begin
ITensors.disable_warn_order()

#First test on an MPS, should be exact
g_dims = (1, 6)
g = named_grid(g_dims)
g = named_grid((3, 3))
s = siteinds("S=1/2", g)
χ = 4
χ = 2
Random.seed!(1234)
ψ = random_tensornetwork(s; link_space=χ)

ψψ = ψ ⊗ prime(dag(ψ); sites=[])

v = (1, 3)

Oψ = copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)
bpc = BeliefPropagationCache(ψψ)
bpc = update(bpc; maxiter=50, tol=1e-10)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

@test abs.((numerator / denominator) - exact_sz) <= 1e-14
#Test messages are converged
for pe in partitionedges(partitioned_tensornetwork(bpc))
@test update_message(bpc, pe) ≈ message(bpc, pe) atol = 1e-8
end

#Test updating the underlying tensornetwork in the cache
v = first(vertices(ψψ))
new_tensor = randomITensor(inds(ψψ[v]))
bpc = update_factor(bpc, v, new_tensor)
ψψ_updated = tensornetwork(bpc)
bpc_updated = update_factor(bpc, v, new_tensor)
ψψ_updated = tensornetwork(bpc_updated)
@test ψψ_updated[v] == new_tensor

#Now test on a tree, should also be exact
g = named_comb_tree((4, 4))
s = siteinds("S=1/2", g)
χ = 2
Random.seed!(1564)
ψ = random_tensornetwork(s; link_space=χ)

ψψ = ψ ⊗ prime(dag(ψ); sites=[])

v = (1, 3)

Oψ = copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

@test abs.((numerator / denominator) - exact_sz) <= 1e-14

#Now test two-site expec taking on the partition function of the Ising model. Not exact, but close
g_dims = (3, 4)
g = named_grid(g_dims)
s = IndsNetwork(g; link_space=2)
beta, h = 0.3, 0.5
vs = [(2, 3), (3, 3)]
ψψ = ModelNetworks.ising_network(s, beta; h)
ψOψ = ModelNetworks.ising_network(s, beta; h, szverts=vs)

contract_seq = contraction_sequence(ψψ)
actual_szsz =
contract(ψOψ; sequence=contract_seq)[] / contract(ψψ; sequence=contract_seq)[]

bpc = BeliefPropagationCache(ψψ, group(v -> v, vertices(ψψ)))
bpc = update(bpc; maxiter=20, verbose=true, tol=1e-5)

env_tensors = environment(bpc, vs)
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[]

@test abs.((numerator / denominator) - actual_szsz) <= 0.05

#Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
g_dims = (3, 3)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
vs = [(2, 2), (2, 3)]
χ = 3
ψ = random_tensornetwork(s; link_space=χ)
ψψ = ψ ⊗ prime(dag(ψ); sites=[])

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = environment(bpc, [(v, 2) for v in vs])
Expand All @@ -136,39 +73,5 @@ using Test: @test, @testset
eigs = eigvals(rdm)
@test size(rdm) == (2^length(vs), 2^length(vs))
@test all(>=(0), real(eigs)) && all(==(0), imag(eigs))

#Test more advanced block BP with MPS message tensors on a grid
g_dims = (4, 3)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
χ = 2
ψ = random_tensornetwork(s; link_space=χ)
v = (2, 2)

ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
Oψ = copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

combiners = linkinds_combiners(ψψ)
ψψ = combine_linkinds(ψψ, combiners)
ψOψ = combine_linkinds(ψOψ, combiners)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
message_update_func(tns; kwargs...) = collect(
eachtensor(first(contract(ITensorNetwork(tns); alg="density_matrix", kwargs...)))
)
bpc = update(
bpc; message_update=message_update_func, message_update_kwargs=(; cutoff=1e-6, maxdim=4)
)

env_tensors = environment(bpc, [v])
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

exact_sz =
contract_boundary_mps(ψOψ; cutoff=1e-16) / contract_boundary_mps(ψψ; cutoff=1e-16)

@test abs.((numerator / denominator) - exact_sz) <= 1e-5
end
end
Loading
Loading