diff --git a/src/ITensorsExtensions/ITensorsExtensions.jl b/src/ITensorsExtensions/ITensorsExtensions.jl index 939426bc..185c87e5 100644 --- a/src/ITensorsExtensions/ITensorsExtensions.jl +++ b/src/ITensorsExtensions/ITensorsExtensions.jl @@ -54,26 +54,26 @@ invsqrt_diag(it::ITensor) = map_diag(inv ∘ sqrt, it) pinv_diag(it::ITensor) = map_diag(pinv, it) pinvsqrt_diag(it::ITensor) = map_diag(pinv ∘ sqrt, it) -function map_eigenvalues( - f::Function, A::ITensor, linds=Index[first(inds(A))]; ishermitian=false, kwargs... -) - if isdiag(A) - return map_diag(s -> f(s), A) - end - +#TODO: Make this work for non-hermitian A +function eigendecomp(A::ITensor, linds, rinds; ishermitian=false, kwargs...) @assert ishermitian - rinds = setdiff(inds(A), linds) - D, U = eigen(A, linds, rinds; ishermitian, kwargs...) ul, ur = noncommonind(D, U), commonind(D, U) ulnew = sim(ul) - Ul = replaceinds(U, (rinds..., ur), (linds..., ulnew)) + Ul = replaceinds(U, vcat(rinds, ur), vcat(linds, ulnew)) + D = replaceind(D, ul, ulnew) + return Ul, D, dag(U) +end + +function map_eigvals(f::Function, A::ITensor, inds...; ishermitian=false, kwargs...) + if isdiag(A) + return map_diag(f, A) + end - D_mapped = map_diag(f, D) - D_mapped = replaceind(D_mapped, ul, ulnew) + Ul, D, Ur = eigendecomp(A, inds...; ishermitian, kwargs...) - return Ul * D_mapped * dag(U) + return Ul * map_diag(f, D) * Ur end # Analagous to `denseblocks`. diff --git a/src/apply.jl b/src/apply.jl index 3f576a38..1f70bee9 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -86,20 +86,24 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) sqrt_envs_v1 = [ - ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for - env in envs_v1 + ITensorsExtensions.map_eigvals( + sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v1 ] sqrt_envs_v2 = [ - ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for - env in envs_v2 + ITensorsExtensions.map_eigvals( + sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v2 ] inv_sqrt_envs_v1 = [ - ITensorsExtensions.map_eigenvalues(inv ∘ sqrt, env; cutoff, ishermitian=true) for - env in envs_v1 + ITensorsExtensions.map_eigvals( + inv ∘ sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v1 ] inv_sqrt_envs_v2 = [ - ITensorsExtensions.map_eigenvalues(inv ∘ sqrt, env; cutoff, ishermitian=true) for - env in envs_v2 + ITensorsExtensions.map_eigvals( + inv ∘ sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v2 ] ψᵥ₁ᵥ₂_tn = [ψ[v⃗[1]]; ψ[v⃗[2]]; sqrt_envs_v1; sqrt_envs_v2] ψᵥ₁ᵥ₂ = contract(ψᵥ₁ᵥ₂_tn; sequence=contraction_sequence(ψᵥ₁ᵥ₂_tn; alg="optimal")) @@ -131,20 +135,24 @@ function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_k envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs) envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs) sqrt_envs_v1 = [ - ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for - env in envs_v1 + ITensorsExtensions.map_eigvals( + sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v1 ] sqrt_envs_v2 = [ - ITensorsExtensions.map_eigenvalues(sqrt, env; cutoff, ishermitian=true) for - env in envs_v2 + ITensorsExtensions.map_eigvals( + sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v2 ] inv_sqrt_envs_v1 = [ - ITensorsExtensions.map_eigenvalues(inv ∘ sqrt, env; cutoff, ishermitian=true) for - env in envs_v1 + ITensorsExtensions.map_eigvals( + inv ∘ sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v1 ] inv_sqrt_envs_v2 = [ - ITensorsExtensions.map_eigenvalues(inv ∘ sqrt, env; cutoff, ishermitian=true) for - env in envs_v2 + ITensorsExtensions.map_eigvals( + inv ∘ sqrt, env, first(inds(env)), last(inds(env)); cutoff, ishermitian=true + ) for env in envs_v2 ] ψᵥ₁ = contract([ψ[v⃗[1]]; sqrt_envs_v1]) ψᵥ₂ = contract([ψ[v⃗[2]]; sqrt_envs_v2]) diff --git a/test/test_itensorsextensions.jl b/test/test_itensorsextensions.jl index 8a7edf93..bf5db060 100644 --- a/test/test_itensorsextensions.jl +++ b/test/test_itensorsextensions.jl @@ -14,7 +14,7 @@ using ITensors: replaceinds, dir, array -using ITensorNetworks.ITensorsExtensions: map_eigenvalues +using ITensorNetworks.ITensorsExtensions: map_eigvals using ITensorNetworks: siteinds, random_tensornetwork using NamedGraphs: named_grid using Random @@ -26,12 +26,13 @@ Random.seed!(1234) for eltype in [Float64, ComplexF64] for n in [2, 3, 5, 10] i, j = Index(n, "i"), Index(n, "j") + linds, rinds = Index[i], Index[j] A = randn(eltype, (n, n)) A = A * A' P = ITensor(A, i, j) - sqrtP = map_eigenvalues(sqrt, P) - inv_P = dag(map_eigenvalues(inv, P)) - inv_sqrtP = dag(map_eigenvalues(inv ∘ sqrt, P)) + sqrtP = map_eigvals(sqrt, P, linds, rinds; ishermitian=true) + inv_P = dag(map_eigvals(inv, P, linds, rinds; ishermitian=true)) + inv_sqrtP = dag(map_eigvals(inv ∘ sqrt, P, linds, rinds; ishermitian=true)) sqrtPdag = replaceind(dag(sqrtP), i, i') P2 = replaceind(sqrtP * sqrtPdag, i', j)