diff --git a/examples/test_gauge_fixing.jl b/examples/test_gauge_fixing.jl index ab05b5b5..6594cf23 100644 --- a/examples/test_gauge_fixing.jl +++ b/examples/test_gauge_fixing.jl @@ -2,17 +2,63 @@ using LinearAlgebra using TensorKit, MPSKitModels, OptimKit using PEPSKit -# Initialize PEPS and environment -χbond = 2 -χenv = 20 -ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2) -ψ = InfinitePEPS(2, χbond; unitcell=(2, 2)) -env = leading_boundary(ψ, ctmalg, CTMRGEnv(ψ; Venv=ℂ^χenv)) - -println("\nBefore gauge-fixing:") -env′, = PEPSKit.ctmrg_iter(ψ, env, ctmalg) -@show PEPSKit.check_elementwise_convergence(env, env′) - -println("\nAfter gauge-fixing:") -envfix = PEPSKit.gauge_fix(env, env′) -@show PEPSKit.check_elementwise_convergence(env, envfix) +function test_gauge_fixing( + f, T, P::S, V::S, E::S; χenv::Int=20, unitcell::NTuple{2,Int}=(1, 1) +) where {S<:ElementarySpace} + ψ = InfinitePEPS(f, T, P, V; unitcell) + env = CTMRGEnv(ψ; Venv=E) + + ctmalg = CTMRG(; + trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2 + ) + ctmalg_fixed = CTMRG(; + trscheme=truncdim(χenv), + tol=1e-10, + miniter=4, + maxiter=100, + verbosity=2, + fixedspace=true, + ) + + env = leading_boundary(ψ, ctmalg, env) + + println("Testing gauge fixing for $(sectortype(P)) symmetry and $unitcell unit cell.") + + println("\nBefore gauge-fixing:") + env′, = PEPSKit.ctmrg_iter(ψ, env, ctmalg_fixed) + @show PEPSKit.check_elementwise_convergence(env, env′) + + println("\nAfter gauge-fixing:") + envfix = PEPSKit.gauge_fix(env, env′) + @show PEPSKit.check_elementwise_convergence(env, envfix) + return println() +end + +# Trivial + +P = ℂ^2 # physical space +V = ℂ^2 # PEPS virtual space +χenv = 20 # environment truncation dimension +E = ℂ^χenv # environment virtual space + +test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(1, 1)) +test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(2, 2)) +test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(3, 4)) # check gauge-fixing for unit cells > (2, 2) + +# Convergence of real CTMRG seems to be more sensitive to initial guess +test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(1, 1)) +test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(2, 2)) +test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(3, 4)) + +# Z2 + +P = Z2Space(0 => 1, 1 => 1) # physical space +V = Z2Space(0 => 2, 1 => 2) # PEPS virtual space +χenv = 20 # environment truncation dimension +E = Z2Space(0 => χenv / 2, 1 => χenv / 2) # environment virtual space + +test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(1, 1)) +test_gauge_fixing(randn, ComplexF64, P, V, E; χenv, unitcell=(2, 2)) + +test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(1, 1)) +test_gauge_fixing(randn, Float64, P, V, E; χenv, unitcell=(2, 2)) diff --git a/examples/test_gradients.jl b/examples/test_gradients.jl index ad504a34..561a933a 100644 --- a/examples/test_gradients.jl +++ b/examples/test_gradients.jl @@ -2,6 +2,8 @@ using LinearAlgebra using TensorKit, MPSKitModels, OptimKit using PEPSKit, KrylovKit +using Zygote + # Square lattice Heisenberg Hamiltonian function square_lattice_heisenberg(; Jx=-1.0, Jy=1.0, Jz=-1.0) Sx, Sy, Sz, _ = spinmatrices(1//2) diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index c187cca9..9bc31ae6 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -9,6 +9,7 @@ using TensorKit, KrylovKit, MPSKit, OptimKit using ChainRulesCore, Zygote include("utility/util.jl") +include("utility/eigsolve.jl") include("utility/rotations.jl") include("states/abstractpeps.jl") diff --git a/src/algorithms/ctmrg.jl b/src/algorithms/ctmrg.jl index 46bb93ed..ceeee3f6 100644 --- a/src/algorithms/ctmrg.jl +++ b/src/algorithms/ctmrg.jl @@ -68,31 +68,65 @@ function MPSKit.leading_boundary(state, alg::CTMRG, envinit=CTMRGEnv(state)) ϵold = ϵ end - env′, = ctmrg_iter(state, env, alg) + # do one final iteration that does not change the spaces + alg_fixed = CTMRG(; + alg.trscheme, alg.tol, alg.maxiter, alg.miniter, alg.verbosity, fixedspace=true + ) + env′, = ctmrg_iter(state, env, alg_fixed) envfix = gauge_fix(env, env′) - check_elementwise_convergence(env, envfix) || + check_elementwise_convergence(env, envfix; atol=alg.tol^(3 / 4)) || @warn "CTMRG did not converge elementwise." return envfix end # Fix gauge of corner end edge tensors from last and second last CTMRG iteration function gauge_fix(envprev::CTMRGEnv{C,T}, envfinal::CTMRGEnv{C,T}) where {C,T} - # Compute gauge tensors by comparing signs - # First fix physical indices to (1, 1) - Tfixprev = map(x -> convert(Array, x)[:, 1, 1, :], envprev.edges) - Tfixfinal = map(x -> convert(Array, x)[:, 1, 1, :], envfinal.edges) + # Check if spaces in envprev and envfinal are the same + same_spaces = map(Iterators.product(axes(envfinal.edges)...)) do (dir, r, c) + space(envfinal.edges[dir, r, c]) == space(envprev.edges[dir, r, c]) && + space(envfinal.corners[dir, r, c]) == space(envprev.corners[dir, r, c]) + end + @assert all(same_spaces) "Spaces of envprev and envfinal are not the same" + + # Try the "general" algorithm from https://arxiv.org/abs/2311.11894 signs = map(Iterators.product(axes(envfinal.edges)...)) do (dir, r, c) - if isodd(dir) - seqprev = prod(circshift(Tfixprev[dir, r, :], 1 - c)) - seqfinal = prod(circshift(Tfixfinal[dir, r, :], 1 - c)) - else - seqprev = prod(circshift(Tfixprev[dir, :, c], 1 - r)) - seqfinal = prod(circshift(Tfixfinal[dir, :, c], 1 - r)) + # Gather edge tensors and pretend they're InfiniteMPSs + if dir == NORTH + Tsprev = circshift(envprev.edges[dir, r, :], 1 - c) + Tsfinal = circshift(envfinal.edges[dir, r, :], 1 - c) + elseif dir == EAST + Tsprev = circshift(envprev.edges[dir, :, c], 1 - r) + Tsfinal = circshift(envfinal.edges[dir, :, c], 1 - r) + elseif dir == SOUTH + Tsprev = circshift(reverse(envprev.edges[dir, r, :]), c) + Tsfinal = circshift(reverse(envfinal.edges[dir, r, :]), c) + elseif dir == WEST + Tsprev = circshift(reverse(envprev.edges[dir, :, c]), r) + Tsfinal = circshift(reverse(envfinal.edges[dir, :, c]), r) + end + + # Random MPS of same bond dimension + M = map(Tsfinal) do t + TensorMap(randn, scalartype(t), codomain(t) ← domain(t)) end - φ = sum(diag(seqfinal) ./ diag(seqprev)) / size(seqprev, 1) # Global sequence phase - σ = sign.(seqfinal[1, :] ./ seqprev[1, :]) * φ' - Tensor(diagm(σ), space(envprev.edges[1], 1) * space(envprev.edges[1], 1)') + # Find right fixed points of mixed transfer matrices + ρinit = TensorMap( + randn, + scalartype(T), + MPSKit._lastspace(Tsfinal[end])' ← MPSKit._lastspace(M[end])', + ) + ρprev = eigsolve(TransferMatrix(Tsprev, M), ρinit, 1, :LM)[2][1] + ρfinal = eigsolve(TransferMatrix(Tsfinal, M), ρinit, 1, :LM)[2][1] + + # Decompose and multiply + Up, _, Vp = tsvd(ρprev) + Uf, _, Vf = tsvd(ρfinal) + Qprev = Up * Vp + Qfinal = Uf * Vf + σ = Qprev * Qfinal' + + return σ end cornersfix, edgesfix = fix_relative_phases(envfinal, signs) @@ -111,107 +145,63 @@ function gauge_fix(envprev::CTMRGEnv{C,T}, envfinal::CTMRGEnv{C,T}) where {C,T} return envfix end -# Explicit unrolling of for loop from previous version to fix AD -# TODO: Does not yet properly work for Lx,Ly > 2 +# Explicit fixing of relative phases (doing this compactly in a loop is annoying) function fix_relative_phases(envfinal::CTMRGEnv, signs) - e1 = envfinal - σ1 = signs - C1 = map(Iterators.product(axes(e1.corners)[2:3]...)) do (r, c) + C1 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c) @tensor Cfix[-1; -2] := - σ1[WEST, _prev(r, end), c][-1 1] * - e1.corners[NORTHWEST, r, c][1; 2] * - conj(σ1[NORTH, r, c][-2 2]) + signs[WEST, _prev(r, end), c][-1 1] * + envfinal.corners[NORTHWEST, r, c][1; 2] * + conj(signs[NORTH, r, c][-2 2]) end - T1 = map(Iterators.product(axes(e1.edges)[2:3]...)) do (r, c) + T1 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c) @tensor Tfix[-1 -2 -3; -4] := - σ1[NORTH, r, c][-1 1] * - e1.edges[NORTH, r, c][1 -2 -3; 2] * - conj(σ1[NORTH, r, _next(c, end)][-4 2]) + signs[NORTH, r, c][-1 1] * + envfinal.edges[NORTH, r, c][1 -2 -3; 2] * + conj(signs[NORTH, r, _next(c, end)][-4 2]) end - e2 = rotate_north(envfinal, EAST) - σ2 = rotate_north(signs, EAST) - C2 = map(Iterators.product(axes(e2.corners)[2:3]...)) do (r, c) + C2 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c) @tensor Cfix[-1; -2] := - σ2[WEST, _prev(r, end), c][-1 1] * - e2.corners[NORTHWEST, r, c][1; 2] * - conj(σ2[NORTH, r, c][-2 2]) + signs[NORTH, r, _next(c, end)][-1 1] * + envfinal.corners[NORTHEAST, r, c][1; 2] * + conj(signs[EAST, r, c][-2 2]) end - C2 = rotate_north(C2, WEST) - T2 = map(Iterators.product(axes(e2.edges)[2:3]...)) do (r, c) + T2 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c) @tensor Tfix[-1 -2 -3; -4] := - σ2[NORTH, r, c][-1 1] * - e2.edges[NORTH, r, c][1 -2 -3; 2] * - conj(σ2[NORTH, r, _next(c, end)][-4 2]) + signs[EAST, r, c][-1 1] * + envfinal.edges[EAST, r, c][1 -2 -3; 2] * + conj(signs[EAST, _next(r, end), c][-4 2]) end - T2 = rotate_north(T2, WEST) - e3 = rotate_north(envfinal, SOUTH) - σ3 = rotate_north(signs, SOUTH) - C3 = map(Iterators.product(axes(e3.corners)[2:3]...)) do (r, c) + C3 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c) @tensor Cfix[-1; -2] := - σ3[WEST, _prev(r, end), c][-1 1] * - e3.corners[NORTHWEST, r, c][1; 2] * - conj(σ3[NORTH, r, c][-2 2]) + signs[EAST, _next(r, end), c][-1 1] * + envfinal.corners[SOUTHEAST, r, c][1; 2] * + conj(signs[SOUTH, r, c][-2 2]) end - C3 = rotate_north(C3, SOUTH) - T3 = map(Iterators.product(axes(e3.edges)[2:3]...)) do (r, c) + T3 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c) @tensor Tfix[-1 -2 -3; -4] := - σ3[NORTH, r, c][-1 1] * - e3.edges[NORTH, r, c][1 -2 -3; 2] * - conj(σ3[NORTH, r, _next(c, end)][-4 2]) + signs[SOUTH, r, c][-1 1] * + envfinal.edges[SOUTH, r, c][1 -2 -3; 2] * + conj(signs[SOUTH, r, _prev(c, end)][-4 2]) end - T3 = rotate_north(T3, SOUTH) - e4 = rotate_north(envfinal, WEST) - σ4 = rotate_north(signs, WEST) - C4 = map(Iterators.product(axes(e4.corners)[2:3]...)) do (r, c) + C4 = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c) @tensor Cfix[-1; -2] := - σ4[WEST, _prev(r, end), c][-1 1] * - e4.corners[NORTHWEST, r, c][1; 2] * - conj(σ4[NORTH, r, c][-2 2]) + signs[SOUTH, r, _prev(c, end)][-1 1] * + envfinal.corners[SOUTHWEST, r, c][1; 2] * + conj(signs[WEST, r, c][-2 2]) end - C4 = rotate_north(C4, EAST) - T4 = map(Iterators.product(axes(e4.edges)[2:3]...)) do (r, c) + T4 = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c) @tensor Tfix[-1 -2 -3; -4] := - σ4[NORTH, r, c][-1 1] * - e4.edges[NORTH, r, c][1 -2 -3; 2] * - conj(σ4[NORTH, r, _next(c, end)][-4 2]) + signs[WEST, r, c][-1 1] * + envfinal.edges[WEST, r, c][1 -2 -3; 2] * + conj(signs[WEST, _prev(r, end), c][-4 2]) end - T4 = rotate_north(T4, EAST) return stack([C1, C2, C3, C4]; dims=1), stack([T1, T2, T3, T4]; dims=1) end -# Semi-working version analogous to left_move with rotations -# function fix_relative_phases(envfinal::CTMRGEnv, signs) -# cornersfix = deepcopy(envfinal.corners) -# edgesfix = deepcopy(envfinal.edges) -# for _ in 1:4 -# corners = map(Iterators.product(axes(envfinal.corners)[2:3]...)) do (r, c) -# @tensor Cfix[-1; -2] := -# signs[WEST, _prev(r, end), c][-1 1] * -# envfinal.corners[NORTHWEST, r, c][1; 2] * -# conj(signs[NORTH, r, c][-2 2]) -# end -# @diffset cornersfix[NORTHWEST, :, :] .= corners -# edges = map(Iterators.product(axes(envfinal.edges)[2:3]...)) do (r, c) -# @tensor Tfix[-1 -2 -3; -4] := -# signs[NORTH, r, c][-1 1] * -# envfinal.edges[NORTH, r, c][1 -2 -3; 2] * -# conj(signs[NORTH, r, _next(c, end)][-4 2]) -# end -# @diffset edgesfix[NORTH, :, :] .= edges - -# # Rotate east-wards -# envfinal = rotate_north(envfinal, EAST) -# cornersfix = rotate_north(cornersfix, EAST) -# edgesfix = rotate_north(edgesfix, EAST) -# signs = rotate_north(signs, EAST) # TODO: Fix AD problem here -# end -# return cornersfix, edgesfix -# end - """ check_elementwise_convergence(envfinal, envfix; atol=1e-6) @@ -221,7 +211,6 @@ CTMRG environments are below some tolerance. function check_elementwise_convergence( envfinal::CTMRGEnv, envfix::CTMRGEnv; atol::Real=1e-6 ) - # TODO: do we need both max and mean? ΔC = envfinal.corners .- envfix.corners ΔCmax = norm(ΔC, Inf) ΔCmean = norm(ΔC) @@ -232,6 +221,18 @@ function check_elementwise_convergence( ΔTmean = norm(ΔT) @debug "maxᵢⱼ|Tⁿ⁺¹ - Tⁿ|ᵢⱼ = $ΔTmax mean |Tⁿ⁺¹ - Tⁿ|ᵢⱼ = $ΔTmean" + # Check differences for all tensors in unit cell to debug properly + for (dir, r, c) in Iterators.product(axes(envfinal.edges)...) + @debug( + "$((dir, r, c)): all |Cⁿ⁺¹ - Cⁿ|ᵢⱼ < ϵ: ", + all(x -> abs(x) < atol, ΔC[dir, r, c].data), + ) + @debug( + "$((dir, r, c)): all |Tⁿ⁺¹ - Tⁿ|ᵢⱼ < ϵ: ", + all(x -> abs(x) < atol, ΔT[dir, r, c].data), + ) + end + return isapprox(ΔCmax, 0; atol) && isapprox(ΔTmax, 0; atol) end diff --git a/src/algorithms/peps_opt.jl b/src/algorithms/peps_opt.jl index edc99932..db30a5c8 100644 --- a/src/algorithms/peps_opt.jl +++ b/src/algorithms/peps_opt.jl @@ -72,6 +72,7 @@ Evaluating the gradient of the cost function for CTMRG: - With AD, the gradient is computed by differentiating the cost function with respect to the PEPS tensors, including computing the environment tensors. - With explicit evaluation of the geometric sum, the gradient is computed by differentiating the cost function with the environment kept fixed, and then manually adding the gradient contributions from the environments. =# +using Zygote: @showgrad function ctmrg_gradient((peps, envs), H, alg::PEPSOptimize{NaiveAD}) E, g = withgradient(peps) do ψ diff --git a/src/environments/ctmrgenv.jl b/src/environments/ctmrgenv.jl index c4788cef..9beb95a0 100644 --- a/src/environments/ctmrgenv.jl +++ b/src/environments/ctmrgenv.jl @@ -84,7 +84,7 @@ end function LinearAlgebra.axpby!(α::Number, e₁::CTMRGEnv, β::Number, e₂::CTMRGEnv) e₂.corners .= α * e₁.corners + β * e₂.corners - e₂.edges .+= α * e₁.edges + β * e₂.edges + e₂.edges .= α * e₁.edges + β * e₂.edges return e₂ end @@ -92,4 +92,67 @@ function LinearAlgebra.dot(e₁::CTMRGEnv, e₂::CTMRGEnv) return dot(e₁.corners, e₂.corners) + dot(e₁.edges, e₂.edges) end -LinearAlgebra.norm(e::CTMRGEnv) = norm(e.corners) + norm(e.edges) +# VectorInterface +# --------------- + +# Note: the following methods consider the environment tensors as separate components of one +# big vector. In other words, the associated vector space is not the natural one associated +# to the original (physical) system, and addition, scaling, etc. are performed element-wise. + +import VectorInterface as VI + +function VI.scalartype(::Type{CTMRGEnv{C,T}}) where {C,T} + S₁ = scalartype(C) + S₂ = scalartype(T) + return promote_type(S₁, S₂) +end + +function VI.zerovector(env::CTMRGEnv, ::Type{S}) where {S<:Number} + _zerovector = Base.Fix2(zerovector, S) + return CTMRGEnv(map(_zerovector, env.corners), map(_zerovector, env.edges)) +end +function VI.zerovector!(env::CTMRGEnv) + foreach(zerovector!, env.corners) + foreach(zerovector!, env.edges) + return env +end +VI.zerovector!!(env::CTMRGEnv) = zerovector!(env) + +function VI.scale(env::CTMRGEnv, α::Number) + _scale = Base.Fix2(scale, α) + return CTMRGEnv(map(_scale, env.corners), map(_scale, env.edges)) +end +function VI.scale!(env::CTMRGEnv, α::Number) + _scale! = Base.Fix2(scale!, α) + foreach(_scale!, env.corners) + foreach(_scale!, env.edges) + return env +end +function VI.scale!(env₁::CTMRGEnv, env₂::CTMRGEnv, α::Number) + _scale!(x, y) = scale!(x, y, α) + foreach(_scale!, env₁.corners, env₂.corners) + foreach(_scale!, env₁.edges, env₂.edges) + return env₁ +end +VI.scale!!(env::CTMRGEnv, α::Number) = scale!(env, α) +VI.scale!!(env₁::CTMRGEnv, env₂::CTMRGEnv, α::Number) = scale!(env₁, env₂, α) + +function VI.add(env₁::CTMRGEnv, env₂::CTMRGEnv, α::Number, β::Number) + _add(x, y) = add(x, y, α, β) + return CTMRGEnv( + map(_add, env₁.corners, env₂.corners), map(_add, env₁.corners, env₂.corners) + ) +end +function VI.add!(env₁::CTMRGEnv, env₂::CTMRGEnv, α::Number, β::Number) + _add!(x, y) = add!(x, y, α, β) + foreach(_add!, env₁.corners, env₂.corners) + foreach(_add!, env₁.edges, env₂.edges) + return env₁ +end +VI.add!!(env₁::CTMRGEnv, env₂::CTMRGEnv, α::Number, β::Number) = add!(env₁, env₂, α, β) + +# exploiting the fact that vectorinterface works for tuples: +function VI.inner(env₁::CTMRGEnv, env₂::CTMRGEnv) + return inner((env₁.corners, env₁.edges), (env₂.corners, env₂.edges)) +end +VI.norm(env::CTMRGEnv) = norm((env.corners, env.edges)) diff --git a/src/states/infinitepeps.jl b/src/states/infinitepeps.jl index 3d4c7ffb..f5594062 100644 --- a/src/states/infinitepeps.jl +++ b/src/states/infinitepeps.jl @@ -64,7 +64,7 @@ function InfinitePEPS(A::T; unitcell::Tuple{Int,Int}=(1, 1)) where {T<:PEPSTenso end """ - InfinitePEPS(Pspace, Nspace, [Espace]; unitcell=(1,1)) + InfinitePEPS(f=randn, T=ComplexF64, Pspace, Nspace, [Espace]; unitcell=(1,1)) Create an InfinitePEPS by specifying its spaces and unit cell. Spaces can be specified either via `Int` or via `ElementarySpace`. @@ -73,7 +73,18 @@ function InfinitePEPS( Pspace::S, Nspace::S, Espace::S=Nspace; unitcell::Tuple{Int,Int}=(1, 1) ) where {S<:Union{ElementarySpace,Int}} return InfinitePEPS( - fill(Pspace, unitcell), fill(Nspace, unitcell), fill(Espace, unitcell) + randn, + ComplexF64, + fill(Pspace, unitcell), + fill(Nspace, unitcell), + fill(Espace, unitcell), + ) +end +function InfinitePEPS( + f, T, Pspace::S, Nspace::S, Espace::S=Nspace; unitcell::Tuple{Int,Int}=(1, 1) +) where {S<:Union{ElementarySpace,Int}} + return InfinitePEPS( + f, T, fill(Pspace, unitcell), fill(Nspace, unitcell), fill(Espace, unitcell) ) end diff --git a/src/utility/eigsolve.jl b/src/utility/eigsolve.jl new file mode 100644 index 00000000..975adf49 --- /dev/null +++ b/src/utility/eigsolve.jl @@ -0,0 +1,251 @@ +# Copied from Jutho/KrylovKit.jl/pull/56, with minor tweaks + +function ChainRulesCore.rrule( + ::typeof(eigsolve), A::AbstractMatrix, x₀, howmany, which, alg +) + (vals, vecs, info) = eigsolve(A, x₀, howmany, which, alg) + project_A = ProjectTo(A) + T = eltype(vecs[1]) # will be real for real symmetric problems and complex otherwise + + function eigsolve_pullback(ΔX) + _Δvals = unthunk(ΔX[1]) + _Δvecs = unthunk(ΔX[2]) + + ∂self = NoTangent() + ∂x₀ = ZeroTangent() + ∂howmany = NoTangent() + ∂which = NoTangent() + ∂alg = NoTangent() + if _Δvals isa AbstractZero && _Δvecs isa AbstractZero + ∂A = ZeroTangent() + return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg + end + + if _Δvals isa AbstractZero + Δvals = fill(NoTangent(), length(Δvecs)) + else + Δvals = _Δvals + end + if _Δvecs isa AbstractZero + Δvecs = fill(NoTangent(), length(Δvals)) + else + Δvecs = _Δvecs + end + + @assert length(Δvals) == length(Δvecs) + @assert length(Δvals) <= length(vals) + + # Determine algorithm to solve linear problem + # TODO: Is there a better choice? Should we make this user configurable? + linalg = GMRES(; + tol=alg.tol, krylovdim=alg.krylovdim, maxiter=alg.maxiter, orth=alg.orth + ) + + ws = similar(vecs, length(Δvecs)) + for i in 1:length(Δvecs) + Δλ = Δvals[i] + Δv = Δvecs[i] + λ = vals[i] + v = vecs[i] + + # First threat special cases + if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution + ws[i] = Δv # some kind of zero + continue + end + if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution + ws[i] = Δλ * v + continue + end + + # General case : + if isa(Δv, AbstractZero) + b = RecursiveVec(zero(T) * v, T[Δλ]) + else + @assert isa(Δv, typeof(v)) + b = RecursiveVec(Δv, T[Δλ]) + end + + if i > 1 && + eltype(A) <: Real && + vals[i] == conj(vals[i - 1]) && + Δvals[i] == conj(Δvals[i - 1]) && + vecs[i] == conj(vecs[i - 1]) && + Δvecs[i] == conj(Δvecs[i - 1]) + ws[i] = conj(ws[i - 1]) + continue + end + + w, reverse_info = let λ = λ, v = v, Aᴴ = A' + linsolve(b, zero(T) * b, linalg) do x + x1, x2 = x + γ = 1 + # γ can be chosen freely and does not affect the solution theoretically + # The current choice guarantees that the extended matrix is Hermitian if A is + # TODO: is this the best choice in all cases? + y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, A' * x1)) + y2 = T[-dot(v, x1)] + return RecursiveVec(y1, y2) + end + end + if info.converged >= i && reverse_info.converged == 0 + @warn "The cotangent linear problem did not converge, whereas the primal eigenvalue problem did." + end + ws[i] = w[1] + end + + if A isa StridedMatrix + ∂A = InplaceableThunk( + Ā -> _buildĀ!(Ā, ws, vecs), @thunk(_buildĀ!(zero(A), ws, vecs)) + ) + else + ∂A = @thunk(project_A(_buildĀ!(zero(A), ws, vecs))) + end + return ∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg + end + return (vals, vecs, info), eigsolve_pullback +end + +function _buildĀ!(Ā, ws, vs) + for i in 1:length(ws) + w = ws[i] + v = vs[i] + if !(w isa AbstractZero) + if eltype(Ā) <: Real && eltype(w) <: Complex + mul!(Ā, _realview(w), _realview(v)', -1, 1) + mul!(Ā, _imagview(w), _imagview(v)', -1, 1) + else + mul!(Ā, w, v', -1, 1) + end + end + end + return Ā +end +function _realview(v::AbstractVector{Complex{T}}) where {T} + return view(reinterpret(T, v), 2 * (1:length(v)) .- 1) +end +function _imagview(v::AbstractVector{Complex{T}}) where {T} + return view(reinterpret(T, v), 2 * (1:length(v))) +end + +function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, + ::typeof(eigsolve), + A::AbstractMatrix, + x₀, + howmany, + which, + alg, +) + return ChainRulesCore.rrule(eigsolve, A, x₀, howmany, which, alg) +end + +function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, ::typeof(eigsolve), f, x₀, howmany, which, alg +) + (vals, vecs, info) = eigsolve(f, x₀, howmany, which, alg) + resize!(vecs, howmany) + resize!(vals, howmany) + T = typeof(dot(vecs[1], vecs[1])) + f_pullbacks = map(x -> rrule_via_ad(config, f, x)[2], vecs) + function eigsolve_pullback(ΔX) + _Δvals = unthunk(ΔX[1]) + _Δvecs = unthunk(ΔX[2]) + + ∂self = NoTangent() + ∂x₀ = ZeroTangent() + ∂howmany = NoTangent() + ∂which = NoTangent() + ∂alg = NoTangent() + if _Δvals isa AbstractZero && _Δvecs isa AbstractZero + ∂A = ZeroTangent() + return (∂self, ∂A, ∂x₀, ∂howmany, ∂which, ∂alg) + end + + if _Δvals isa AbstractZero + Δvals = fill(NoTangent(), howmany) + else + Δvals = _Δvals + end + if _Δvecs isa AbstractZero + Δvecs = fill(NoTangent(), howmany) + else + Δvecs = _Δvecs + end + + # filter ZeroTangents, added compared to Jutho/KrylovKit.jl/pull/56 + Δvecs = filter(x -> !(x isa AbstractZero), Δvecs) + @assert length(Δvals) == length(Δvecs) + + # Determine algorithm to solve linear problem + # TODO: Is there a better choice? Should we make this user configurable? + linalg = GMRES(; + tol=alg.tol, + krylovdim=alg.krylovdim + 10, + maxiter=alg.maxiter * 10, + orth=alg.orth, + ) + # linalg = BiCGStab(; + # tol = alg.tol, + # maxiter = alg.maxiter*alg.krylovdim, + # ) + + ws = similar(Δvecs) + for i in 1:length(Δvecs) + Δλ = Δvals[i] + Δv = Δvecs[i] + λ = vals[i] + v = vecs[i] + + # First threat special cases + if isa(Δv, AbstractZero) && isa(Δλ, AbstractZero) # no contribution + ws[i] = 0 * v # some kind of zero + continue + end + if isa(Δv, AbstractZero) && isa(alg, Lanczos) # simple contribution + ws[i] = Δλ * v + continue + end + + # General case : + if isa(Δv, AbstractZero) + b = RecursiveVec(zero(T) * v, T[-Δλ]) + else + @assert isa(Δv, typeof(v)) + b = RecursiveVec(-Δv, T[-Δλ]) + end + + # TODO: is there any analogy to this for general vector-like user types + # if i > 1 && eltype(A) <: Real && + # vals[i] == conj(vals[i-1]) && Δvals[i] == conj(Δvals[i-1]) && + # vecs[i] == conj(vecs[i-1]) && Δvecs[i] == conj(Δvecs[i-1]) + # + # ws[i] = conj(ws[i-1]) + # continue + # end + + w, reverse_info = let λ = λ, v = v, fᴴ = x -> f_pullbacks[i](x)[2] + linsolve(b, zero(T) * b, linalg) do x + x1, x2 = x + γ = 1 + # γ can be chosen freely and does not affect the solution theoretically + # The current choice guarantees that the extended matrix is Hermitian if A is + # TODO: is this the best choice in all cases? + y1 = axpy!(-γ * x2[], v, axpy!(-conj(λ), x1, fᴴ(x1))) + y2 = T[-dot(v, x1)] + return RecursiveVec(y1, y2) + end + end + if info.converged >= i && reverse_info.converged == 0 + @warn "The cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did." reverse_info b + end + ws[i] = w[1] + end + ∂f = f_pullbacks[1](ws[1])[1] + for i in 2:length(ws) + ∂f = VectorInterface.add!!(∂f, f_pullbacks[i](ws[i])[1]) + end + return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg + end + return (vals, vecs, info), eigsolve_pullback +end diff --git a/src/utility/rotations.jl b/src/utility/rotations.jl index 5b202bf9..983a8292 100644 --- a/src/utility/rotations.jl +++ b/src/utility/rotations.jl @@ -10,22 +10,3 @@ const SOUTHWEST = 4 # Rotate tensor to any direction by successive application of Base.rotl90 rotate_north(t, dir) = mod1(dir, 4) == NORTH ? t : rotate_north(rotl90(t), dir - 1) - -# Hacked version for AbstractArray{T,3} which doesn't need to overload rotl90 to avoid type piracy -function rotate_north(A::AbstractArray{T,3}, dir) where {T} - for _ in 1:(mod1(dir, size(A, 1)) - 1) - # Initialize copy with rotated sizes - A′ = Zygote.Buffer(Array{T,3}(undef, size(A, 1), size(A, 3), size(A, 2))) - for dir in 1:size(A, 1) - # A′[_prev(dir, size(A, 1)), :, :] = rotl90(A[dir, :, :]) - # throws setindex! error for non-symmetric unit cells - rA = rotl90(A[dir, :, :]) - for r in 1:size(A, 3), c in 1:size(A, 2) - A′[_prev(dir, size(A, 1)), r, c] = rA[r, c] - end - end - A = A′ - end - - return copy(A) -end