Skip to content

Commit

Permalink
further svd ad implementation and test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 17, 2023
1 parent c568e17 commit 21e8f6f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 59 deletions.
2 changes: 1 addition & 1 deletion ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
Uc, Σc, Vc = block(U, c), block(Σ, c), block(V, c)
ΔUc, ΔΣc, ΔVc = block(ΔU, c), block(ΔΣ, c), block(ΔV, c)
Σdc = view(Σc, diagind(Σc))
ΔΣdc = view(ΔΣc, diagind(ΔΣc))
ΔΣdc = (ΔΣdc isa AbstractZero) ? ΔΣdc : view(ΔΣdc, diagind(ΔΣdc))
copyto!(b, svd_pullback(Uc, Σdc, Vc, ΔUc, ΔΣdc, ΔVc))
end
return NoTangent(), Δt
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ function _truncate!(V::SectorVectorDict, trunc::TruncationSpace, p=2)
end
return V, truncerr
end
########################
########################

function _truncate!(V::SectorVectorDict, trunc::TruncationCutoff, p=2)
I = keytype(V)
S = real(eltype(valtype(V)))
Expand Down
71 changes: 15 additions & 56 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,14 @@ end
# complex-valued svd?
# -------------------

# function _gaugefix!(U, V)
# s = LinearAlgebra.Diagonal(TensorKit._safesign.(diag(U)))
# rmul!(U, s)
# lmul!(s', V)
# return U, V
# end

# function _tsvd(t::AbstractTensorMap)
# U, S, V, ϵ = tsvd(t)
# for (c, b) in blocks(U)
# _gaugefix!(b, block(V, c))
# end
# return U, S, V, ϵ
# end

# svd_rev = Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt).svd_rev

# function ChainRulesCore.rrule(::typeof(_tsvd), t::AbstractTensorMap)
# U, S, V, ϵ = _tsvd(t)
# function _tsvd_pullback((ΔU, ΔS, ΔV, Δϵ))
# ∂t = similar(t)
# for (c, b) in blocks(∂t)
# copyto!(b,
# svd_rev(block(U, c), block(S, c), block(V, c),
# block(ΔU, c), block(ΔS, c), block(ΔV, c)))
# end
# return NoTangent(), ∂t
# end
# return (U, S, V, ϵ), _tsvd_pullback
# end
function remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
# simple implementation, assumes no degeneracies or zeros in singular values
gaugepart = U' * ΔU + V * ΔV'
for (c, b) in blocks(gaugepart)
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
end
return ΔU, ΔV
end

# Tests
# -----
Expand Down Expand Up @@ -275,12 +253,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
ΔV = TensorMap(randn, scalartype(V), space(V))
if T <: Complex # remove gauge dependent components
gaugepart = U' * ΔU + V * ΔV'
for (c, b) in blocks(gaugepart)
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
end
end
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))

Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2)
Expand All @@ -290,12 +263,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
ΔV = TensorMap(randn, scalartype(V), space(V))
if T <: Complex # remove gauge dependent components
gaugepart = U' * ΔU + V * ΔV'
for (c, b) in blocks(gaugepart)
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
end
end
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncspace(Vtrunc)))
end
Expand All @@ -304,26 +272,17 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
ΔV = TensorMap(randn, scalartype(V), space(V))
if T <: Complex # remove gauge dependent components
gaugepart = U' * ΔU + V * ΔV'
for (c, b) in blocks(gaugepart)
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
end
end
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))

U, S, V, ϵ = tsvd(C; trunc=truncdim(2))
c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S))
U, S, V, ϵ = tsvd(C; trunc=truncdim(2 * dim(c)))
ΔU = TensorMap(randn, scalartype(U), space(U))
ΔS = TensorMap(randn, scalartype(S), space(S))
ΔV = TensorMap(randn, scalartype(V), space(V))
if T <: Complex # remove gauge dependent components
gaugepart = U' * ΔU + V * ΔV'
for (c, b) in blocks(gaugepart)
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
end
end
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncdim(2)))
fkwargs=(; trunc=truncdim(2 * dim(c))))
end
end
end

0 comments on commit 21e8f6f

Please sign in to comment.