Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor interfaces built around alternating update #121

Merged
merged 32 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ac76123
Change required solver function signature. Passes psi and projected H…
b-kloss Jan 11, 2024
b23ab06
Move solver_funcs to separate directory. Adapt dmrg and dmrg-x to new…
b-kloss Jan 16, 2024
c2c5455
Format.
b-kloss Jan 16, 2024
ed0df91
Modify alternating update kwarg naming, structure, and also solver_in…
b-kloss Jan 17, 2024
db524f7
Adapt eigsolve and dmrg to new interfaces, fix dmrg tests (and remove…
b-kloss Jan 17, 2024
d258ccc
Remove default_sweep_regions from dmrg.
b-kloss Jan 17, 2024
6ff116a
Format.
b-kloss Jan 17, 2024
0ac2f2d
rename region_[update]_printer and nsite[s]
Jan 19, 2024
c02662c
Remove applyexp.jl
Jan 19, 2024
1a8ff33
Fix imports/namespaces.
Jan 19, 2024
80620b8
Add ToDo regarding testing applyexp in test_tdvp.jl
Jan 19, 2024
4c33d65
Define sweep_params outside of function call.
Jan 19, 2024
114ef3e
Change NamedTuple access pattern.
Jan 19, 2024
b817587
Change NamedTuple(;) to (;).
Jan 19, 2024
723fae9
(;) also for tdvp.
Jan 19, 2024
1e40a79
Remove second applyexp.jl file.
Jan 19, 2024
5587391
Remove applyexp from ITensorNetworks.jl
Jan 19, 2024
0620943
Start renaming. One tdvp test not passing, observer related.
Jan 19, 2024
7865f81
Cleanup solvers.
b-kloss Jan 20, 2024
e08ba96
Rename to sweep_plan.
b-kloss Jan 20, 2024
30455a4
Fix renaming to tdvp_sweep and kwarg handling in updaters.
b-kloss Jan 20, 2024
86dd746
Fix dmrg-x.
b-kloss Jan 20, 2024
59f45c1
Adapt contract(_updater) to new interface.
b-kloss Jan 20, 2024
8bc62fb
Format.
b-kloss Jan 20, 2024
ac1a923
Fix tdvp_time_dependent tests.
b-kloss Jan 21, 2024
2748927
Format tests.
b-kloss Jan 21, 2024
1fe92b3
Remove obsolete tests from test_tdvp (mostly those with an alternativ…
b-kloss Jan 21, 2024
9e8e635
Remove exponentiate from imports from KrylovKit
b-kloss Jan 21, 2024
23cc545
Apply review suggestions.
b-kloss Jan 21, 2024
3d5ade8
Fix test_tdvp.jl
b-kloss Jan 21, 2024
64ef223
Format.
b-kloss Jan 21, 2024
812de92
Adapt linsolve.
b-kloss Jan 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand All @@ -114,7 +119,6 @@ include(joinpath("treetensornetworks", "projttns", "projttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
Expand Down
19 changes: 19 additions & 0 deletions src/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function contract_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
v = ITensor(true)
projected_operator = projected_operator![]
for j in sites(projected_operator)
v *= projected_operator.psi0[j]
end
vp = contract(projected_operator, v)
return vp, (;)
end
22 changes: 22 additions & 0 deletions src/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function dmrg_x_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
# this updater does not seem to accept any kwargs?
default_updater_kwargs = (;)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
H = contract(projected_operator![], ITensor(true))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, (;)
end
33 changes: 33 additions & 0 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
function eigsolve_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think region_update_kwargs sounds better to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though the same discussion on exponentiate_updater applies here, I think we should discuss how these are being passed and maybe merge them with updater_kwargs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this will be obsolete soon, I am in favor of sticking with region_kwargs.

updater_kwargs,
)
default_updater_kwargs = (;
which_eigval=:SR,
ishermitian=true,
tol=1e-14,
krylovdim=3,
maxiter=1,
verbosity=0,
eager=false,
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
howmany = 1
(; which_eigval) = updater_kwargs
updater_kwargs = Base.structdiff(updater_kwargs, (; which_eigval=nothing))
vals, vecs, info = eigsolve(
projected_operator![], init, howmany, which_eigval; updater_kwargs...
)
return vecs[1], (; info, eigvals=vals)
end

function _pop_which_eigenvalue(; which_eigenvalue, kwargs...)
return which_eigenvalue, NamedTuple(kwargs)
end
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
27 changes: 27 additions & 0 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function exponentiate_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
Comment on lines +9 to +10
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan to combine these into updater_kwargs?

Copy link
Contributor Author

@b-kloss b-kloss Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updater_kwargs are updater specific kwargs, while region_args are among the things that we expose to all updaters. in principle, we can nest the region_args into updater_kwargs in the call to region_update but I am not sure if that's preferable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm ok, I just have found it hard to keep track of the logic of why certain keyword arguments are bundled in certain ways, how they will be used, etc.

For example, from the perspective of this function, the only argument I can see that is being used here from region_kwargs is time_step, which I don't think is really any different conceptually from the arguments being passed in updater_kwargs (it's just another thing being used by the solver/updater). So it makes sense to me to just bundle those together in one flat NamedTuple called updater_kwargs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, these will eventually be bundled together in an upcoming PR.

)
default_updater_kwargs = (;
krylovdim=30,
maxiter=100,
verbosity=0,
tol=1E-12,
ishermitian=true,
issymmetric=true,
eager=true,
)

updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
result, exp_info = exponentiate(
projected_operator![], region_kwargs.time_step, init; updater_kwargs...
)
return result, (; info=exp_info)
end
22 changes: 22 additions & 0 deletions src/solvers/linsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function linsolve_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
ishermitian=false, tol=1E-14, krylovdim=30, maxiter=100, verbosity=0, a₀, a₁
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
P = projected_operator![]
(; a₀, a₁) = updater_kwargs
updater_kwargs = Base.structdiff(updater_kwargs, (; a₀=nothing, a₁=nothing))
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, init, a₀, a₁; updater_kwargs...)
return x, (;)
end
87 changes: 47 additions & 40 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,83 +26,88 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
function sweep_printer(; outputlevel, state, which_sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print("After sweep ", which_sweep, ":")
print(" maxlinkdim=", maxlinkdim(state))
print(" cpu_time=", round(sw_time; digits=3))
println()
flush(stdout)
end
end

function alternating_update(
solver,
PH,
psi0::AbstractTTN;
updater,
projected_operator,
init_state::AbstractTTN;
checkdone=(; kws...) -> false,
outputlevel::Integer=0,
nsweeps::Integer=1,
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
updater_kwargs,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

psi = copy(psi0)
state = copy(init_state)

insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
for which_sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) &&
maxdim[which_sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
projected_operator = disk(projected_operator)
end

sweep_params = (;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep],
)
sw_time = @elapsed begin
psi, PH = update_step(
solver,
PH,
psi;
state, projected_operator = sweep_update(
updater,
projected_operator,
state;
outputlevel,
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
which_sweep,
sweep_params,
updater_kwargs,
kwargs...,
)
end

update!(sweep_observer!; psi, sweep, sw_time, outputlevel)
update!(sweep_observer!; state, which_sweep, sw_time, outputlevel)

checkdone(; psi, sweep, outputlevel, kwargs...) && break
checkdone(; state, which_sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
return psi
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer"))
return state
end

function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
function alternating_update(updater, H::AbstractTTN, init_state::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, init_state)
check_hascommoninds(siteinds, H, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return alternating_update(solver, PH, psi0; kwargs...)
projected_operator = ProjTTN(H)
return alternating_update(updater, projected_operator, init_state; kwargs...)
end

"""
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number, sweeps::Sweeps; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
to compute `exp(t*H)*init_state` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.

Expand All @@ -114,14 +119,16 @@ the set of MPOs [H1,H2,H3,..] is efficiently looped over at
each step of the algorithm when optimizing the MPS.

Returns:
* `psi::MPS` - time-evolved MPS
* `state::MPS` - time-evolved MPS
"""
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
function alternating_update(
updater, Hs::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
check_hascommoninds(siteinds, H, init_state)
check_hascommoninds(siteinds, H, init_state')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return alternating_update(solver, PHs, psi0; kwargs...)
projected_operators = ProjTTNSum(Hs)
return alternating_update(updater, projected_operators, init_state; kwargs...)
end
Loading
Loading