From 6ff0cd572c947e9b1ed3642e690b43233277beb0 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 17 Oct 2024 14:56:22 +0100 Subject: [PATCH] Bug fix in current ortho. Change test --- .../alternating_update/region_update.jl | 45 ++++++++----------- .../test_solvers/test_dmrg.jl | 12 ++--- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index b92adc8c..97241c20 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -7,36 +7,27 @@ function current_ortho(sweep_plan, which_region_update) if !isa(region, AbstractEdge) && length(region) == 1 return only(current_verts) end - if which_region_update == length(regions) - # look back by one should be sufficient, but may be brittle? - overlapping_vertex = only( - intersect(current_verts, support(regions[which_region_update - 1])) - ) - return overlapping_vertex - else - # look forward - other_regions = filter( - x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) + # look forward + other_regions = filter( + x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) + ) + # find the first region that has overlapping support with current region + ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) + if isnothing(ind) + # look backward + other_regions = reverse( + filter( + x -> !(issetequal(x, current_verts)), support.(regions[1:(which_region_update - 1)]) + ), ) - # find the first region that has overlapping support with current region ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - if isnothing(ind) - # look backward - other_regions = reverse( - filter( - x -> !(issetequal(x, current_verts)), - support.(regions[1:(which_region_update - 1)]), - ), - ) - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - end - @assert !isnothing(ind) - future_verts = union(support(other_regions[ind])) - # return ortho_ceter as the vertex in current region that does not overlap with following one - overlapping_vertex = intersect(current_verts, future_verts) - nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) - return nonoverlapping_vertex end + @assert !isnothing(ind) + future_verts = union(support(other_regions[ind])) + # return ortho_ceter as the vertex in current region that does not overlap with following one + overlapping_vertex = intersect(current_verts, future_verts) + nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) + return nonoverlapping_vertex end function region_update( diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index cf8a1caf..004ec561 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -1,7 +1,7 @@ @eval module $(gensym()) using DataGraphs: edge_data, vertex_data using Dictionaries: Dictionary -using Graphs: nv, vertices +using Graphs: nv, vertices, uniform_tree using ITensorMPS: ITensorMPS using ITensorNetworks: ITensorNetworks, @@ -19,6 +19,7 @@ using ITensorNetworks.ITensorsExtensions: replace_vertices using ITensorNetworks.ModelHamiltonians: ModelHamiltonians using ITensors: ITensors using KrylovKit: eigsolve +using NamedGraphs: NamedGraph, rename_vertices using NamedGraphs.NamedGraphGenerators: named_comb_tree using Observers: observer using StableRNGs: StableRNG @@ -313,11 +314,12 @@ end nsites = 2 nsweeps = 10 - c = named_comb_tree((3, 2)) - s = siteinds("S=1/2", c) - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) rng = StableRNG(1234) + g = NamedGraph(uniform_tree(10)) + g = rename_vertices(v -> (v, 1), g) + s = siteinds("S=1/2", g) + os = ModelHamiltonians.heisenberg(g) + H = ttn(os, s) psi = random_ttn(rng, s; link_space=5) e, psi = dmrg(H, psi; nsweeps, maxdim, nsites)