Skip to content

Commit

Permalink
Separate logic out in map_eigvals and use default eigen inds setting
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 15, 2024
1 parent 513921b commit 8e15f0b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 33 deletions.
26 changes: 13 additions & 13 deletions src/ITensorsExtensions/ITensorsExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,26 @@ invsqrt_diag(it::ITensor) = map_diag(inv ∘ sqrt, it)
pinv_diag(it::ITensor) = map_diag(pinv, it)
pinvsqrt_diag(it::ITensor) = map_diag(pinv sqrt, it)

function map_eigenvalues(
f::Function, A::ITensor, linds=Index[first(inds(A))]; ishermitian=false, kwargs...
)
if isdiag(A)
return map_diag(s -> f(s), A)
end

#TODO: Make this work for non-hermitian A
function eigendecomp(A::ITensor, linds, rinds; ishermitian=false, kwargs...)
@assert ishermitian
rinds = setdiff(inds(A), linds)

D, U = eigen(A, linds, rinds; ishermitian, kwargs...)
ul, ur = noncommonind(D, U), commonind(D, U)
ulnew = sim(ul)

Ul = replaceinds(U, (rinds..., ur), (linds..., ulnew))
Ul = replaceinds(U, vcat(rinds, ur), vcat(linds, ulnew))
D = replaceind(D, ul, ulnew)
return Ul, D, dag(U)
end

function map_eigvals(f::Function, A::ITensor, inds...; ishermitian=false, kwargs...)
if isdiag(A)
return map_diag(f, A)
end

D_mapped = map_diag(f, D)
D_mapped = replaceind(D_mapped, ul, ulnew)
Ul, D, Ur = eigendecomp(A, inds...; ishermitian, kwargs...)

return Ul * D_mapped * dag(U)
return Ul * map_diag(f, D) * Ur
end

# Analagous to `denseblocks`.
Expand Down
40 changes: 24 additions & 16 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,24 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap
envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
sqrt_envs_v1 = [
ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for
env in envs_v1
ITensorsExtensions.map_eigvals(
sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v1
]
sqrt_envs_v2 = [
ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for
env in envs_v2
ITensorsExtensions.map_eigvals(
sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v2
]
inv_sqrt_envs_v1 = [
ITensorsExtensions.map_eigenvalues(inv sqrt, env; cutoff, ishermitian=true) for
env in envs_v1
ITensorsExtensions.map_eigvals(
inv sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v1
]
inv_sqrt_envs_v2 = [
ITensorsExtensions.map_eigenvalues(inv sqrt, env; cutoff, ishermitian=true) for
env in envs_v2
ITensorsExtensions.map_eigvals(
inv sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v2
]
ψᵥ₁ᵥ₂_tn = [ψ[v⃗[1]]; ψ[v⃗[2]]; sqrt_envs_v1; sqrt_envs_v2]
ψᵥ₁ᵥ₂ = contract(ψᵥ₁ᵥ₂_tn; sequence=contraction_sequence(ψᵥ₁ᵥ₂_tn; alg="optimal"))
Expand Down Expand Up @@ -131,20 +135,24 @@ function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_k
envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
sqrt_envs_v1 = [
ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for
env in envs_v1
ITensorsExtensions.map_eigvals(
sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v1
]
sqrt_envs_v2 = [
ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for
env in envs_v2
ITensorsExtensions.map_eigvals(
sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v2
]
inv_sqrt_envs_v1 = [
ITensorsExtensions.map_eigenvalues(inv sqrt, env; cutoff, ishermitian=true) for
env in envs_v1
ITensorsExtensions.map_eigvals(
inv sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v1
]
inv_sqrt_envs_v2 = [
ITensorsExtensions.map_eigenvalues(inv sqrt, env; cutoff, ishermitian=true) for
env in envs_v2
ITensorsExtensions.map_eigvals(
inv sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true
) for env in envs_v2
]
ψᵥ₁ = contract([ψ[v⃗[1]]; sqrt_envs_v1])
ψᵥ₂ = contract([ψ[v⃗[2]]; sqrt_envs_v2])
Expand Down
9 changes: 5 additions & 4 deletions test/test_itensorsextensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using ITensors:
replaceinds,
dir,
array
using ITensorNetworks.ITensorsExtensions: map_eigenvalues
using ITensorNetworks.ITensorsExtensions: map_eigvals
using ITensorNetworks: siteinds, random_tensornetwork
using NamedGraphs: named_grid
using Random
Expand All @@ -26,12 +26,13 @@ Random.seed!(1234)
for eltype in [Float64, ComplexF64]
for n in [2, 3, 5, 10]
i, j = Index(n, "i"), Index(n, "j")
linds, rinds = Index[i], Index[j]
A = randn(eltype, (n, n))
A = A * A'
P = ITensor(A, i, j)
sqrtP = map_eigenvalues(sqrt, P)
inv_P = dag(map_eigenvalues(inv, P))
inv_sqrtP = dag(map_eigenvalues(inv sqrt, P))
sqrtP = map_eigvals(sqrt, P, linds, rinds; ishermitian=true)
inv_P = dag(map_eigvals(inv, P, linds, rinds; ishermitian=true))
inv_sqrtP = dag(map_eigvals(inv sqrt, P, linds, rinds; ishermitian=true))

sqrtPdag = replaceind(dag(sqrtP), i, i')
P2 = replaceind(sqrtP * sqrtPdag, i', j)
Expand Down

0 comments on commit 8e15f0b

Please sign in to comment.