Skip to content

Commit

Permalink
Update gauge fixing to work with symmetries (#19)
Browse files Browse the repository at this point in the history
* Update gauge fixing to work with symmetries

* Add `eigsolve` rrule from Jutho/KylovKit.jl#56

* For some reason, this seems to work, don't touch anything

* Fix gauge-fixing for larger unit cells, decrease default tol for element-wise convergence check

* Add (incomplete) VectorInterface for CTMRGEnv to fix GMRES fixed-point solver

* Check that gauge-fixing also works for Float64 tensors

* Add important type annotation to CTMRGEnv VectorInterface

* Replace QR leftorth with tsvd to fix gauge_fix AD

* Cleaned up hacky rotate_north method that is not needed anymore

* Formatter

* Update CTMRGEnv VectorInterface implementation

* Make sure gaugefix happens with fixed spaces

* Remove debug code

---------

Co-authored-by: lkdvos <[email protected]>
Co-authored-by: Paul Brehmer <[email protected]>
  • Loading branch information
3 people authored Apr 16, 2024
1 parent a24ac36 commit 93e74a3
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 130 deletions.
74 changes: 60 additions & 14 deletions examples/test_gauge_fixing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,63 @@ using LinearAlgebra
using TensorKit, MPSKitModels, OptimKit
using PEPSKit

# Initialize PEPS and environment
χbond = 2
χenv = 20
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2)
ψ = InfinitePEPS(2, χbond; unitcell=(2, 2))
env = leading_boundary(ψ, ctmalg, CTMRGEnv(ψ; Venv=^χenv))

println("\nBefore gauge-fixing:")
env′, = PEPSKit.ctmrg_iter(ψ, env, ctmalg)
@show PEPSKit.check_elementwise_convergence(env, env′)

println("\nAfter gauge-fixing:")
envfix = PEPSKit.gauge_fix(env, env′)
@show PEPSKit.check_elementwise_convergence(env, envfix)
function test_gauge_fixing(
f, T, P::S, V::S, E::S; χenv::Int=20, unitcell::NTuple{2,Int}=(1, 1)
) where {S<:ElementarySpace}
ψ = InfinitePEPS(f, T, P, V; unitcell)
env = CTMRGEnv(ψ; Venv=E)

ctmalg = CTMRG(;
trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2
)
ctmalg_fixed = CTMRG(;
trscheme=truncdim(χenv),
tol=1e-10,
miniter=4,
maxiter=100,
verbosity=2,
fixedspace=true,
)

env = leading_boundary(ψ, ctmalg, env)

println("Testing gauge fixing for $(sectortype(P)) symmetry and $unitcell unit cell.")

println("\nBefore gauge-fixing:")
env′, = PEPSKit.ctmrg_iter(ψ, env, ctmalg_fixed)
@show PEPSKit.check_elementwise_convergence(env, env′)

println("\nAfter gauge-fixing:")
envfix = PEPSKit.gauge_fix(env, env′)
@show PEPSKit.check_elementwise_convergence(env, envfix)
return println()
end

# Trivial

P =^2 # physical space
V =^2 # PEPS virtual space
χenv = 20 # environment truncation dimension
E =^χenv # environment virtual space

test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(1, 1))
test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(2, 2))
test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(3, 4)) # check gauge-fixing for unit cells > (2, 2)

# Convergence of real CTMRG seems to be more sensitive to initial guess
test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(1, 1))
test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(2, 2))
test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(3, 4))

# Z2

P = Z2Space(0 => 1, 1 => 1) # physical space
V = Z2Space(0 => 2, 1 => 2) # PEPS virtual space
χenv = 20 # environment truncation dimension
E = Z2Space(0 => χenv / 2, 1 => χenv / 2) # environment virtual space

test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(1, 1))
test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(2, 2))

test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(1, 1))
test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(2, 2))
2 changes: 2 additions & 0 deletions examples/test_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using LinearAlgebra
using TensorKit, MPSKitModels, OptimKit
using PEPSKit, KrylovKit

using Zygote

# Square lattice Heisenberg Hamiltonian
function square_lattice_heisenberg(; Jx=-1.0, Jy=1.0, Jz=-1.0)
Sx, Sy, Sz, _ = spinmatrices(1//2)
Expand Down
1 change: 1 addition & 0 deletions src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using TensorKit, KrylovKit, MPSKit, OptimKit
using ChainRulesCore, Zygote

include("utility/util.jl")
include("utility/eigsolve.jl")
include("utility/rotations.jl")

include("states/abstractpeps.jl")
Expand Down
187 changes: 94 additions & 93 deletions src/algorithms/ctmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,65 @@ function MPSKit.leading_boundary(state, alg::CTMRG, envinit=CTMRGEnv(state))
ϵold = ϵ
end

env′, = ctmrg_iter(state, env, alg)
# do one final iteration that does not change the spaces
alg_fixed = CTMRG(;
alg.trscheme, alg.tol, alg.maxiter, alg.miniter, alg.verbosity, fixedspace=true
)
env′, = ctmrg_iter(state, env, alg_fixed)
envfix = gauge_fix(env, env′)
check_elementwise_convergence(env, envfix) ||
check_elementwise_convergence(env, envfix; atol=alg.tol^(3 / 4)) ||
@warn "CTMRG did not converge elementwise."
return envfix
end

# Fix gauge of corner end edge tensors from last and second last CTMRG iteration
function gauge_fix(envprev::CTMRGEnv{C,T}, envfinal::CTMRGEnv{C,T}) where {C,T}
# Compute gauge tensors by comparing signs
# First fix physical indices to (1, 1)
Tfixprev = map(x -> convert(Array, x)[:, 1, 1, :], envprev.edges)
Tfixfinal = map(x -> convert(Array, x)[:, 1, 1, :], envfinal.edges)
# Check if spaces in envprev and envfinal are the same
same_spaces = map(Iterators.product(axes(envfinal.edges)...)) do (dir, r, c)
space(envfinal.edges[dir, r, c]) == space(envprev.edges[dir, r, c]) &&
space(envfinal.corners[dir, r, c]) == space(envprev.corners[dir, r, c])
end
@assert all(same_spaces) "Spaces of envprev and envfinal are not the same"

# Try the "general" algorithm from https://arxiv.org/abs/2311.11894
signs = map(Iterators.product(axes(envfinal.edges)...)) do (dir, r, c)
if isodd(dir)
seqprev = prod(circshift(Tfixprev[dir, r, :], 1 - c))
seqfinal = prod(circshift(Tfixfinal[dir, r, :], 1 - c))
else
seqprev = prod(circshift(Tfixprev[dir, :, c], 1 - r))
seqfinal = prod(circshift(Tfixfinal[dir, :, c], 1 - r))
# Gather edge tensors and pretend they're InfiniteMPSs
if dir == NORTH
Tsprev = circshift(envprev.edges[dir, r, :], 1 - c)
Tsfinal = circshift(envfinal.edges[dir, r, :], 1 - c)
elseif dir == EAST
Tsprev = circshift(envprev.edges[dir, :, c], 1 - r)
Tsfinal = circshift(envfinal.edges[dir, :, c], 1 - r)
elseif dir == SOUTH
Tsprev = circshift(reverse(envprev.edges[dir, r, :]), c)
Tsfinal = circshift(reverse(envfinal.edges[dir, r, :]), c)
elseif dir == WEST
Tsprev = circshift(reverse(envprev.edges[dir, :, c]), r)
Tsfinal = circshift(reverse(envfinal.edges[dir, :, c]), r)
end

# Random MPS of same bond dimension
M = map(Tsfinal) do t
TensorMap(randn, scalartype(t), codomain(t) domain(t))
end

φ = sum(diag(seqfinal) ./ diag(seqprev)) / size(seqprev, 1) # Global sequence phase
σ = sign.(seqfinal[1, :] ./ seqprev[1, :]) * φ'
Tensor(diagm(σ), space(envprev.edges[1], 1) * space(envprev.edges[1], 1)')
# Find right fixed points of mixed transfer matrices
ρinit = TensorMap(
randn,
scalartype(T),
MPSKit._lastspace(Tsfinal[end])' MPSKit._lastspace(M[end])',
)
ρprev = eigsolve(TransferMatrix(Tsprev, M), ρinit, 1, :LM)[2][1]
ρfinal = eigsolve(TransferMatrix(Tsfinal, M), ρinit, 1, :LM)[2][1]

# Decompose and multiply
Up, _, Vp = tsvd(ρprev)
Uf, _, Vf = tsvd(ρfinal)
Qprev = Up * Vp
Qfinal = Uf * Vf
σ = Qprev * Qfinal'

return σ
end

cornersfix, edgesfix = fix_relative_phases(envfinal, signs)
Expand All @@ -111,107 +145,63 @@ function gauge_fix(envprev::CTMRGEnv{C,T}, envfinal::CTMRGEnv{C,T}) where {C,T}
return envfix
end

# Explicit unrolling of for loop from previous version to fix AD
# TODO: Does not yet properly work for Lx,Ly > 2
# Explicit fixing of relative phases (doing this compactly in a loop is annoying)
function fix_relative_phases(envfinal::CTMRGEnv, signs)
e1 = envfinal
σ1 = signs
C1 = map(Iterators.product(axes(e1.corners)[2:3]...)) do (r, c)
C1 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c)
@tensor Cfix[-1; -2] :=
σ1[WEST, _prev(r, end), c][-1 1] *
e1.corners[NORTHWEST, r, c][1; 2] *
conj(σ1[NORTH, r, c][-2 2])
signs[WEST, _prev(r, end), c][-1 1] *
envfinal.corners[NORTHWEST, r, c][1; 2] *
conj(signs[NORTH, r, c][-2 2])
end
T1 = map(Iterators.product(axes(e1.edges)[2:3]...)) do (r, c)
T1 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c)
@tensor Tfix[-1 -2 -3; -4] :=
σ1[NORTH, r, c][-1 1] *
e1.edges[NORTH, r, c][1 -2 -3; 2] *
conj(σ1[NORTH, r, _next(c, end)][-4 2])
signs[NORTH, r, c][-1 1] *
envfinal.edges[NORTH, r, c][1 -2 -3; 2] *
conj(signs[NORTH, r, _next(c, end)][-4 2])
end

e2 = rotate_north(envfinal, EAST)
σ2 = rotate_north(signs, EAST)
C2 = map(Iterators.product(axes(e2.corners)[2:3]...)) do (r, c)
C2 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c)
@tensor Cfix[-1; -2] :=
σ2[WEST, _prev(r, end), c][-1 1] *
e2.corners[NORTHWEST, r, c][1; 2] *
conj(σ2[NORTH, r, c][-2 2])
signs[NORTH, r, _next(c, end)][-1 1] *
envfinal.corners[NORTHEAST, r, c][1; 2] *
conj(signs[EAST, r, c][-2 2])
end
C2 = rotate_north(C2, WEST)
T2 = map(Iterators.product(axes(e2.edges)[2:3]...)) do (r, c)
T2 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c)
@tensor Tfix[-1 -2 -3; -4] :=
σ2[NORTH, r, c][-1 1] *
e2.edges[NORTH, r, c][1 -2 -3; 2] *
conj(σ2[NORTH, r, _next(c, end)][-4 2])
signs[EAST, r, c][-1 1] *
envfinal.edges[EAST, r, c][1 -2 -3; 2] *
conj(signs[EAST, _next(r, end), c][-4 2])
end
T2 = rotate_north(T2, WEST)

e3 = rotate_north(envfinal, SOUTH)
σ3 = rotate_north(signs, SOUTH)
C3 = map(Iterators.product(axes(e3.corners)[2:3]...)) do (r, c)
C3 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c)
@tensor Cfix[-1; -2] :=
σ3[WEST, _prev(r, end), c][-1 1] *
e3.corners[NORTHWEST, r, c][1; 2] *
conj(σ3[NORTH, r, c][-2 2])
signs[EAST, _next(r, end), c][-1 1] *
envfinal.corners[SOUTHEAST, r, c][1; 2] *
conj(signs[SOUTH, r, c][-2 2])
end
C3 = rotate_north(C3, SOUTH)
T3 = map(Iterators.product(axes(e3.edges)[2:3]...)) do (r, c)
T3 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c)
@tensor Tfix[-1 -2 -3; -4] :=
σ3[NORTH, r, c][-1 1] *
e3.edges[NORTH, r, c][1 -2 -3; 2] *
conj(σ3[NORTH, r, _next(c, end)][-4 2])
signs[SOUTH, r, c][-1 1] *
envfinal.edges[SOUTH, r, c][1 -2 -3; 2] *
conj(signs[SOUTH, r, _prev(c, end)][-4 2])
end
T3 = rotate_north(T3, SOUTH)

e4 = rotate_north(envfinal, WEST)
σ4 = rotate_north(signs, WEST)
C4 = map(Iterators.product(axes(e4.corners)[2:3]...)) do (r, c)
C4 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c)
@tensor Cfix[-1; -2] :=
σ4[WEST, _prev(r, end), c][-1 1] *
e4.corners[NORTHWEST, r, c][1; 2] *
conj(σ4[NORTH, r, c][-2 2])
signs[SOUTH, r, _prev(c, end)][-1 1] *
envfinal.corners[SOUTHWEST, r, c][1; 2] *
conj(signs[WEST, r, c][-2 2])
end
C4 = rotate_north(C4, EAST)
T4 = map(Iterators.product(axes(e4.edges)[2:3]...)) do (r, c)
T4 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c)
@tensor Tfix[-1 -2 -3; -4] :=
σ4[NORTH, r, c][-1 1] *
e4.edges[NORTH, r, c][1 -2 -3; 2] *
conj(σ4[NORTH, r, _next(c, end)][-4 2])
signs[WEST, r, c][-1 1] *
envfinal.edges[WEST, r, c][1 -2 -3; 2] *
conj(signs[WEST, _prev(r, end), c][-4 2])
end
T4 = rotate_north(T4, EAST)

return stack([C1, C2, C3, C4]; dims=1), stack([T1, T2, T3, T4]; dims=1)
end

# Semi-working version analogous to left_move with rotations
# function fix_relative_phases(envfinal::CTMRGEnv, signs)
# cornersfix = deepcopy(envfinal.corners)
# edgesfix = deepcopy(envfinal.edges)
# for _ in 1:4
# corners = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c)
# @tensor Cfix[-1; -2] :=
# signs[WEST, _prev(r, end), c][-1 1] *
# envfinal.corners[NORTHWEST, r, c][1; 2] *
# conj(signs[NORTH, r, c][-2 2])
# end
# @diffset cornersfix[NORTHWEST, :, :] .= corners
# edges = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c)
# @tensor Tfix[-1 -2 -3; -4] :=
# signs[NORTH, r, c][-1 1] *
# envfinal.edges[NORTH, r, c][1 -2 -3; 2] *
# conj(signs[NORTH, r, _next(c, end)][-4 2])
# end
# @diffset edgesfix[NORTH, :, :] .= edges

# # Rotate east-wards
# envfinal = rotate_north(envfinal, EAST)
# cornersfix = rotate_north(cornersfix, EAST)
# edgesfix = rotate_north(edgesfix, EAST)
# signs = rotate_north(signs, EAST) # TODO: Fix AD problem here
# end
# return cornersfix, edgesfix
# end

"""
check_elementwise_convergence(envfinal, envfix; atol=1e-6)
Expand All @@ -221,7 +211,6 @@ CTMRG environments are below some tolerance.
function check_elementwise_convergence(
envfinal::CTMRGEnv, envfix::CTMRGEnv; atol::Real=1e-6
)
# TODO: do we need both max and mean?
ΔC = envfinal.corners .- envfix.corners
ΔCmax = norm(ΔC, Inf)
ΔCmean = norm(ΔC)
Expand All @@ -232,6 +221,18 @@ function check_elementwise_convergence(
ΔTmean = norm(ΔT)
@debug "maxᵢⱼ|Tⁿ⁺¹ - Tⁿ|ᵢⱼ = $ΔTmax mean |Tⁿ⁺¹ - Tⁿ|ᵢⱼ = $ΔTmean"

# Check differences for all tensors in unit cell to debug properly
for (dir, r, c) in Iterators.product(axes(envfinal.edges)...)
@debug(
"$((dir, r, c)): all |Cⁿ⁺¹ - Cⁿ|ᵢⱼ < ϵ: ",
all(x -> abs(x) < atol, ΔC[dir, r, c].data),
)
@debug(
"$((dir, r, c)): all |Tⁿ⁺¹ - Tⁿ|ᵢⱼ < ϵ: ",
all(x -> abs(x) < atol, ΔT[dir, r, c].data),
)
end

return isapprox(ΔCmax, 0; atol) && isapprox(ΔTmax, 0; atol)
end

Expand Down
1 change: 1 addition & 0 deletions src/algorithms/peps_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Evaluating the gradient of the cost function for CTMRG:
- With AD, the gradient is computed by differentiating the cost function with respect to the PEPS tensors, including computing the environment tensors.
- With explicit evaluation of the geometric sum, the gradient is computed by differentiating the cost function with the environment kept fixed, and then manually adding the gradient contributions from the environments.
=#
using Zygote: @showgrad

function ctmrg_gradient((peps, envs), H, alg::PEPSOptimize{NaiveAD})
E, g = withgradient(peps) do ψ
Expand Down
Loading

0 comments on commit 93e74a3

Please sign in to comment.