diff --git a/src/environment.jl b/src/environment.jl index 37249cb3..f3c424c0 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -1,31 +1,33 @@ +using ITensors: contract +using NamedGraphs.PartitionedGraphs: PartitionedGraph + default_environment_algorithm() = "exact" function environment( - ψ::AbstractITensorNetwork, + tn::AbstractITensorNetwork, vertices::Vector; alg=default_environment_algorithm(), kwargs..., ) - return environment(Algorithm(alg), ψ, vertices; kwargs...) + return environment(Algorithm(alg), tn, vertices; kwargs...) end function environment( - ::Algorithm"exact", ψ::AbstractITensorNetwork, verts::Vector; kwargs... + ::Algorithm"exact", tn::AbstractITensorNetwork, verts::Vector; kwargs... ) - return [contract(subgraph(ψ, setdiff(vertices(ψ), verts)); kwargs...)] + return [contract(subgraph(tn, setdiff(vertices(tn), verts)); kwargs...)] end function environment( ::Algorithm"bp", - ψ::AbstractITensorNetwork, + ptn::PartitionedGraph, vertices::Vector; (cache!)=nothing, - partitioned_vertices=default_partitioned_vertices(ψ), update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices)) + cache! = Ref(BeliefPropagationCache(ptn)) end if update_cache @@ -34,3 +36,13 @@ function environment( return environment(cache![], vertices) end + +function environment( + alg::Algorithm"bp", + tn::AbstractITensorNetwork, + vertices::Vector; + partitioned_vertices=default_partitioned_vertices(tn), + kwargs..., +) + return environment(alg, PartitionedGraph(tn, partitioned_vertices), vertices; kwargs...) +end diff --git a/src/expect.jl b/src/expect.jl index 1a63feaa..64dc78b3 100644 --- a/src/expect.jl +++ b/src/expect.jl @@ -9,7 +9,7 @@ function ITensorMPS.expect(ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=( ψIψ_v = ψIψ[operator_vertex(ψIψ, v)] s = commonind(ψIψ[ket_vertex(ψIψ, v)], ψIψ_v) operator = ITensors.op(op.which_op, s) - ∂ψIψ_∂v = environment(ψIψ, [v]; vertex_mapping_function=operator_vertices, kwargs...) + ∂ψIψ_∂v = environment(ψIψ, operator_vertices(ψIψ, [v]); kwargs...) numerator = contract(vcat(∂ψIψ_∂v, operator); contract_kwargs...)[] denominator = contract(vcat(∂ψIψ_∂v, ψIψ_v); contract_kwargs...)[] diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 200aa0f6..66776ec8 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -72,24 +72,6 @@ function operator_network(f::AbstractFormNetwork) ) end -function environment( - f::AbstractFormNetwork, - original_vertices::Vector; - vertex_mapping_function=state_vertices, - alg=default_environment_algorithm(), - kwargs..., -) - form_vertices = vertex_mapping_function(f, original_vertices) - if alg == "bp" - partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f)) - return environment( - tensornetwork(f), form_vertices; alg, partitioned_vertices, kwargs... - ) - else - return environment(tensornetwork(f), form_vertices; alg, kwargs...) - end -end - operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f)) bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) diff --git a/test/test_forms.jl b/test/test_forms.jl index e0edb597..c36ab585 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -16,6 +16,7 @@ using ITensorNetworks: operator_network, random_tensornetwork, siteinds, + state_vertices, tensornetwork, union_all_inds, update @@ -57,16 +58,16 @@ using Random: Random @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(environment(qf, [v])) + ∂qf_∂v = only(environment(qf, state_vertices(qf, [v]))) @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = environment(qf, state_vertices(qf, [v]); alg="bp", update_cache=false) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = environment(qf, state_vertices(qf, [v]); alg="bp", update_cache=true) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v