Skip to content

Commit

Permalink
Adapt contract(_updater) to new interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss committed Jan 20, 2024
1 parent 86dd746 commit 59f45c1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand Down
19 changes: 19 additions & 0 deletions src/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function contract_updater(
init;

Check warning on line 2 in src/solvers/contract.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/contract.jl:2:- init; src/solvers/contract.jl:3:- state!, src/solvers/contract.jl:4:- projected_operator!, src/solvers/contract.jl:5:- outputlevel, src/solvers/contract.jl:6:- which_sweep, src/solvers/contract.jl:7:- sweep_plan, src/solvers/contract.jl:8:- which_region_update, src/solvers/contract.jl:9:- region_kwargs, src/solvers/contract.jl:10:- updater_kwargs, src/solvers/contract.jl:11:- ) src/solvers/contract.jl:12:- v = ITensor(1.0) src/solvers/contract.jl:13:- projected_operator = projected_operator![] src/solvers/contract.jl:14:- for j in sites(projected_operator) src/solvers/contract.jl:15:- v *= projected_operator.psi0[j] src/solvers/contract.jl:16:- end src/solvers/contract.jl:17:- Hpsi0 = contract(projected_operator, v) src/solvers/contract.jl:18:- return Hpsi0, (;) src/solvers/contract.jl:19:- end src/solvers/contract.jl:2:+ init; src/solvers/contract.jl:3:+ state!, src/solvers/contract.jl:4:+ projected_operator!, src/solvers/contract.jl:5:+ outputlevel, src/solvers/contract.jl:6:+ which_sweep, src/solvers/contract.jl:7:+ sweep_plan, src/solvers/contract.jl:8:+ which_region_update, src/solvers/contract.jl:9:+ region_kwargs, src/solvers/contract.jl:10:+ updater_kwargs, src/solvers/contract.jl:11:+) src/solvers/contract.jl:12:+ v = ITensor(1.0) src/solvers/contract.jl:13:+ projected_operator = projected_operator![] src/solvers/contract.jl:14:+ for j in sites(projected_operator) src/solvers/contract.jl:15:+ v *= projected_operator.psi0[j] src/solvers/contract.jl:16:+ end src/solvers/contract.jl:17:+ Hpsi0 = contract(projected_operator, v) src/solvers/contract.jl:18:+ return Hpsi0, (;) src/solvers/contract.jl:19:+end
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
v = ITensor(1.0)
projected_operator = projected_operator![]
for j in sites(projected_operator)
v *= projected_operator.psi0[j]
end
Hpsi0 = contract(projected_operator, v)
return Hpsi0, (;)
end
14 changes: 4 additions & 10 deletions src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
function contract_solver(PH, psi; normalize, region, half_sweep)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, NamedTuple()
end

function contract(
::Algorithm"fit",
tn1::AbstractTTN,
tn2::AbstractTTN;
init=random_ttn(flatten_external_indsnetwork(tn1, tn2); link_space=trivial_space(tn1)),
nsweeps=1,
nsites=2, # used to be default of call to default_sweep_regions
updater_kwargs=(;),
kwargs...,
)
n = nv(tn1)
Expand Down Expand Up @@ -42,7 +35,8 @@ function contract(
## end

PH = ProjTTNApply(tn2, tn1)
psi = alternating_update(contract_solver, PH, init; nsweeps, kwargs...)
sweep_plan = default_sweep_regions(nsites, init; kwargs...)
psi = alternating_update(contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs...)

Check warning on line 39 in src/treetensornetworks/solvers/contract.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/treetensornetworks/solvers/contract.jl:39:- psi = alternating_update(contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs...) src/treetensornetworks/solvers/contract.jl:39:+ psi = alternating_update( src/treetensornetworks/solvers/contract.jl:40:+ contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs... src/treetensornetworks/solvers/contract.jl:41:+ )

return psi
end
Expand Down
4 changes: 2 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using Test

# Test with nsite=1
Hpsi_guess = random_mps(t; internal_inds_space=32)
Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsite=1, nsweeps=4)
Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsites=1, nsweeps=4)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
end

Expand Down Expand Up @@ -84,7 +84,7 @@ end

# Test with nsite=1
Hpsi_guess = random_ttn(t; link_space=4)
Hpsi = apply(H, psi; alg="fit", nsite=1, nsweeps=4, init=Hpsi_guess)
Hpsi = apply(H, psi; alg="fit", nsites=1, nsweeps=4, init=Hpsi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
end

Expand Down

0 comments on commit 59f45c1

Please sign in to comment.