Correct adjoint for truncated SVD #15

Merged Jul 10, 2024 (29 commits)
Commits
c390bec
Add truncated SVD adjoint with wrapper for KrylovKit iterative SVD, a…
pbrehmer Mar 4, 2024
51507ce
Use KrylovKit.linsolve for truncation linear problem, make loss funct…
pbrehmer Mar 5, 2024
d3a31fb
Improve loss function, compare SVD gradient with TensorKit.tsvd gradient
pbrehmer Mar 5, 2024
8ac307e
Update SVD adjoint linear problem to use Tuple and remove reshapes
pbrehmer Mar 26, 2024
80db1c0
Fix ZeroTangent case for linear problem
pbrehmer Apr 11, 2024
ae96298
Add SVD wrapper structs and function, utilize tsvd machinery, convert…
pbrehmer Jun 20, 2024
78ab152
Copy ctmrg.jl from master, add svdalg field to CTMRG, use svdwrap in …
pbrehmer Jun 20, 2024
dc6ad2b
Merge branch 'master' into svd_adjoint
lkdvos Jun 28, 2024
fc6600e
Use KrylovKit implementation of eigsolve instead of eigsolve.jl, dele…
pbrehmer Jul 4, 2024
823f52e
Add IterSVD _tsvd! method and adjoint using KrylovKit.svdsolve adjoint
pbrehmer Jul 5, 2024
a615b4a
Add PEPSKit.tsvd wrapper, fix IterSVD adjoint
pbrehmer Jul 5, 2024
a815e41
Add TensorKit compat entry for softened tsvd type restrictions
pbrehmer Jul 5, 2024
447bfd8
Add ProjectorAlg and refactor all tests and examples
pbrehmer Jul 5, 2024
bbd132d
Update MPSKit compat
lkdvos Jul 5, 2024
0c03cd4
Replace tsvd with tsvd!, add views to IterSVD adjoint
pbrehmer Jul 8, 2024
14f0065
Improve IterSVD allocation, implement CTMRG convenience constructor, …
pbrehmer Jul 8, 2024
9b2c4b7
Fix tests
pbrehmer Jul 8, 2024
0c13d47
Add block-wise dense fallback option
pbrehmer Jul 8, 2024
538652d
Add SVDrrule wrapper, add separate adjoint structs and rrules, update…
pbrehmer Jul 8, 2024
7032ed0
Add IterSVD test for symmetric tensor with fallback
pbrehmer Jul 9, 2024
66a827e
Merge branch 'master' into svd_adjoint
pbrehmer Jul 9, 2024
d09561f
Formatting
pbrehmer Jul 9, 2024
b3a0726
Fix missing cnext in ctmrg, update README example
pbrehmer Jul 9, 2024
89ae0a4
Rename DenseSVDAdjoint, update svd_wrapper test
pbrehmer Jul 9, 2024
25d198c
Make CRCExt extension backwards compatible with v1.8
pbrehmer Jul 9, 2024
6b818e7
Replace SVDrrule with SVDAdjoint, clean up adjoint algorithms
pbrehmer Jul 9, 2024
bb96664
Small cleanup
lkdvos Jul 9, 2024
fa7a56a
Update minimal julia version 1.9
lkdvos Jul 10, 2024
6df8efc
Remove duplicate line in left_move
pbrehmer Jul 10, 2024
27 changes: 18 additions & 9 deletions src/utility/svd.jl
@@ -26,18 +26,19 @@
# Wrapper around Krylov Kit's GKL iterative SVD solver
@kwdef struct IterSVD
alg::KrylovKit.GKL = KrylovKit.GKL(; tol=1e-14, krylovdim=25)
fallback_threshold::Float64 = Inf
lorentz_broad::Float64 = 0.0
alg_rrule::Union{GMRES,BiCGStab,Arnoldi} = GMRES(; tol=1e-14)
Member

I would keep this alg_rrule separate, maybe re-using the hook_pullback interface, to decouple the implementations of the forward and backward rules.

Collaborator Author

But in terms of usability, I think the SVD algorithm struct would be the best way to set the SVD adjoint parameters. So if I were to use hook_pullback in this case, I would need to store the alg_rrule and lorentz_broad settings in the CTMRG struct, which I also don't like that much.

Member

Ah I see, that's a very good point.

How about a wrapper then?

struct SVDrrule{F,T,R,E<:Real}
    svd_alg::F
    truncation_alg::T
    rrule_alg::R
    lorentz_broadening::E
end

In this case, we still at least semantically separate the rrule from the forward pass (which I really like), while also grouping them together, which I also quite like. It also looks to me to be quite flexible.

I think we should probably investigate (not just here, but throughout PEPSKit) whether we want to include the truncation algorithm here or as a separate field. My guess would be that it's actually not a bad idea to keep them separate, considering we presumably might want to have varying truncspace settings throughout the unit cell, without changing the SVD algorithm.

Collaborator Author

I have now added such a wrapper. A nice consequence is that you can choose the rrule simply by passing the corresponding algorithm struct, which decouples the forward SVD algorithm from its reverse rule. I decided to hide the lorentz_broadening away inside the rrule_alg, since Lorentzian broadening might not be needed in every SVD implementation.

The only thing I'm unsure about is the naming; maybe something like SVDWithRrule, SVDWithAdjoint, DiffSVD or DifferentiableSVD would be more descriptive?
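The wrapper idea discussed in this thread can be sketched in a few lines of plain Julia. This is only an illustration under stated assumptions, not the code that was merged: the type name is borrowed from one of the naming candidates above, and every field, stand-in algorithm struct, and helper function below is hypothetical.

```julia
# Hypothetical sketch; none of these definitions are taken from PEPSKit.
struct DenseForward end                 # stand-in forward SVD algorithm

struct IterativeForward                 # stand-in iterative forward algorithm
    krylovdim::Int
end

struct LinearSolverRrule                # stand-in reverse-rule algorithm
    tol::Float64
    broadening::Float64                 # broadening is stored with the reverse rule
end

# Wrapper that groups a forward algorithm with its reverse rule
struct SVDWithAdjoint{F,R}
    fwd_alg::F
    rrule_alg::R
end

# The forward pass dispatches on fwd_alg only, the reverse pass on rrule_alg only
forward_name(alg::SVDWithAdjoint) = forward_name(alg.fwd_alg)
forward_name(::DenseForward) = "dense"
forward_name(::IterativeForward) = "iterative"
reverse_tol(alg::SVDWithAdjoint) = alg.rrule_alg.tol

alg = SVDWithAdjoint(IterativeForward(25), LinearSolverRrule(1e-14, 0.0))
alg_dense = SVDWithAdjoint(DenseForward(), alg.rrule_alg)   # swap forward pass, keep rrule
```

Because the two fields are independent type parameters, the forward algorithm can be exchanged without touching the reverse rule, which is the decoupling argued for above.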

end

# Compute SVD data block-wise using KrylovKit algorithm
function TensorKit._tsvd!(
t, alg::IterSVD, trunc::Union{NoTruncation,TruncationSpace}, p::Real=2
t, alg::Union{IterSVD}, trunc::Union{NoTruncation,TruncationSpace}, p::Real=2
)
# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
return _empty_svdtensors(t)..., truncerr

end

Udata, Σdata, Vdata, dims = _compute_svddata!(t, alg, trunc)
@@ -55,22 +56,30 @@
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
local Sdata

for (c, b) in blocks(t)
x₀ = randn(eltype(b), size(b, 1))
howmany = trunc isa NoTruncation ? minimum(size(b)) : blockdim(trunc.space, c)
S, lvecs, rvecs, info = KrylovKit.svdsolve(b, x₀, howmany, :LR, alg.alg)
if info.converged < howmany # Fall back to dense SVD if not properly converged

if howmany / minimum(size(b)) > alg.fallback_threshold # Use dense SVD for small blocks
U, S, V = TensorKit.MatrixAlgebra.svd!(b, TensorKit.SVD())
Udata[c] = U
Vdata[c] = V
else # Slice in case more values were converged than requested
Udata[c] = stack(view(lvecs, 1:howmany))
Vdata[c] = stack(conj, view(rvecs, 1:howmany); dims=1)
S = @view S[1:howmany]
Udata[c] = @view U[:, 1:howmany]
Vdata[c] = @view V[1:howmany, :]

else
S, lvecs, rvecs, info = KrylovKit.svdsolve(b, x₀, howmany, :LR, alg.alg)
if info.converged < howmany # Fall back to dense SVD if not properly converged
U, S, V = TensorKit.MatrixAlgebra.svd!(b, TensorKit.SVD())
Udata[c] = @view U[:, 1:howmany]
Vdata[c] = @view V[1:howmany, :]

pbrehmer marked this conversation as resolved.
else # Slice in case more values were converged than requested
Udata[c] = stack(view(lvecs, 1:howmany))
Vdata[c] = stack(conj, view(rvecs, 1:howmany); dims=1)
end
end

S = @view S[1:howmany]
if @isdefined Sdata # cannot easily infer the type of Σ, so use this construction
Sdata[c] = S

else
Sdata = SectorDict(c => S)
end
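The block-wise fallback in this hunk can be illustrated without KrylovKit or TensorKit. The sketch below is an assumption-laden stand-in: `use_dense` and `dense_truncated_svd` are hypothetical helpers that only mirror the `howmany / minimum(size(b)) > fallback_threshold` condition and the slicing to `howmany` values, using the standard library's `LinearAlgebra.svd` in place of the calls in the diff.

```julia
using LinearAlgebra

# Hypothetical helper mirroring the fallback condition from the diff:
# use the dense SVD when a large fraction of the block's spectrum is requested.
use_dense(howmany, blocksize, threshold) = howmany / minimum(blocksize) > threshold

# Dense SVD truncated to the leading `howmany` singular triplets
function dense_truncated_svd(b::AbstractMatrix, howmany::Integer)
    F = svd(b)                                    # thin SVD, values sorted descending
    return F.U[:, 1:howmany], F.S[1:howmany], F.Vt[1:howmany, :]
end

# 4x3 example block with singular values 3, 2, 1
b = [3.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 1.0; 0.0 0.0 0.0]
U, S, Vt = dense_truncated_svd(b, 2)
```

With the default `fallback_threshold = Inf` the condition is never true, so the dense path is taken only when the iterative solver fails to converge enough values.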
@@ -103,8 +112,8 @@
minimal_info = KrylovKit.ConvergenceInfo(length(Sdc), nothing, nothing, -1, -1) # Just supply converged to SVD pullback

if ΔUc isa AbstractZero && ΔVc isa AbstractZero # Handle ZeroTangent singular vectors
Δlvecs = fill(ZeroTangent(), n_vals)
Δrvecs = fill(ZeroTangent(), n_vals)

else
Δlvecs = Vector{Vector{scalartype(t)}}(eachcol(ΔUc))
Δrvecs = Vector{Vector{scalartype(t)}}(eachcol(ΔVc'))
@@ -130,7 +139,7 @@
end
return NoTangent(), Δt, NoTangent()
end
function tsvd_itersvd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})

return NoTangent(), ZeroTangent(), NoTangent()
end

@@ -169,7 +178,7 @@
end
return NoTangent(), Δt, NoTangent()
end
function tsvd_oldsvd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})

return NoTangent(), ZeroTangent(), NoTangent()
end

@@ -235,7 +244,7 @@
end

# Lorentzian broadening for SVD adjoint F-singularities
function _lorentz_broaden(x::Real, ε=1e-12)
x′ = 1 / x
return x′ / (x′^2 + ε)

end
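To see what the broadening does numerically, the formula can be evaluated on a few inputs. The sketch below copies the expression from the diff under a different, hypothetical name. Reading the argument `x` as an already-inverted gap (so that `x′ = 1 / x` is the gap itself) is an assumption based on the comment about F-singularities, not something the diff states.

```julia
# Same expression as in the diff, under a hypothetical name
function lorentz_broaden_sketch(x::Real, ε=1e-12)
    x′ = 1 / x                      # assumed to recover the gap from its inverse
    return x′ / (x′^2 + ε)          # regularised inverse of the gap
end

well_separated = lorentz_broaden_sketch(2.0)   # gap 0.5: result stays close to 2.0
near_degenerate = lorentz_broaden_sketch(1e8)  # gap 1e-8: result is damped well below 1e8
degenerate = lorentz_broaden_sketch(Inf)       # gap 0.0: result is 0.0 rather than Inf
```

For gaps much larger than `sqrt(ε)` the input is returned almost unchanged, while exactly degenerate values contribute zero instead of a divergence.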