From 596b038ae9f5e3d2261f992e9397f0c494b54abb Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 1 Dec 2023 15:52:57 +0100 Subject: [PATCH] rewrite constructbraidingtensor preprocessor --- src/planar/macros.jl | 4 +- src/planar/preprocessors.jl | 234 +++++++++++++++++++--------------- src/tensors/braidingtensor.jl | 4 + test/planar.jl | 49 ++++++- 4 files changed, 180 insertions(+), 111 deletions(-) diff --git a/src/planar/macros.jl b/src/planar/macros.jl index 1d76cc48..e10308d2 100644 --- a/src/planar/macros.jl +++ b/src/planar/macros.jl @@ -23,7 +23,7 @@ function planarparser(planarexpr, kwargs...) # braiding tensors need to be instantiated before kwargs are processed push!(parser.preprocessors, _construct_braidingtensors) - + for (name, val) in kwargs if name == :order isexpr(val, :tuple) || @@ -62,7 +62,7 @@ function planarparser(planarexpr, kwargs...) throw(ArgumentError("Unknown keyword argument `name`.")) end end - + treebuilder = parser.contractiontreebuilder treesorter = parser.contractiontreesorter costcheck = parser.contractioncostcheck diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 6c284ad9..3d6abecc 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -68,129 +68,159 @@ _is_adjoint(ex) = isexpr(ex, TO.prime) _remove_adjoint(ex) = _is_adjoint(ex) ? ex.args[1] : ex _add_adjoint(ex) = Expr(TO.prime, ex) -# used by `@planar`: realize explicit braiding tensors +# used by `@planar`: identify braiding tensors (corresponding to name τ) and discover their +# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and +# insert them in the expression. function _construct_braidingtensors(ex::Expr) if TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - list = TO.gettensors(_conj_to_adjoint(rhs)) - if TO.isassignment(ex) && TO.istensor(lhs) - obj, l, r = TO.decomposetensor(lhs) - lhs_as_rhs = Expr(:typed_vcat, _add_adjoint(obj), - Expr(:tuple, r...), Expr(:tuple, l...)) - push!(list, lhs_as_rhs) - end - else + if !TO.istensorexpr(rhs) return ex end + preargs = Vector{Any}() + indexmap = Dict{Any,Any}() + if TO.isassignment(ex) && TO.istensor(lhs) + obj, leftind, rightind = TO.decomposetensor(lhs) + for (i, l) in enumerate(leftind) + indexmap[l] = Expr(:call, :space, _add_adjoint(obj), i) + end + for (i, l) in enumerate(rightind) + indexmap[l] = Expr(:call, :space, _add_adjoint(obj), length(leftind) + i) + end + end + newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap) + success || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + pre = Expr(:macrocall, Symbol("@notensor"), + LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...)) + return Expr(:block, pre, Expr(ex.head, lhs, newrhs)) elseif TO.istensorexpr(ex) - list = TO.gettensors(_conj_to_adjoint(ex)) + preargs = Vector{Any}() + indexmap = Dict{Any,Any}() + newex, success = _construct_braidingtensors!(ex, preargs, indexmap) + success || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + pre = Expr(:macrocall, Symbol("@notensor"), + LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...)) + return Expr(:block, pre, newex) else return Expr(ex.head, map(_construct_braidingtensors, ex.args)...) end +end +_construct_braidingtensors(x) = x - i = 1 - translatebraidings = Dict{Any,Any}() - while i <= length(list) - t = list[i] - if _remove_adjoint(TO.gettensorobject(t)) == :τ - translatebraidings[t] = Expr(:call, GlobalRef(TensorKit, :BraidingTensor)) - deleteat!(list, i) - else - i += 1 - end - end - - unresolved = Any[] # list of indices that we couldn't yet figure out - indexmap = Dict{Any,Any}() - # indexmap[i] contains the expression to resolve the space for index i - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - length(leftind) == length(rightind) == 2 || - throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - - obj_and_pos1a = _findindex(i1a, list) - obj_and_pos2a = _findindex(i2a, list) - obj_and_pos1b = _findindex(i1b, list) - obj_and_pos2b = _findindex(i2b, list) +function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression + if TO.istensor(ex) + obj, leftind, rightind = TO.decomposetensor(ex) + if _remove_adjoint(obj) == :τ + # try to construct a braiding tensor + length(leftind) == length(rightind) == 2 || + throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) + if _is_adjoint(obj) + i1b, i2b, = leftind + i2a, i1a, = rightind + else + i2b, i1b, = leftind + i1a, i2a, = rightind + end - if !isnothing(obj_and_pos1a) - indexmap[i1b] = Expr(:call, :space, obj_and_pos1a...) - indexmap[i1a] = Expr(:call, :space, obj_and_pos1a...) - elseif !isnothing(obj_and_pos1b) - indexmap[i1b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) - indexmap[i1a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) + foundV1, foundV2 = false, false + if haskey(indexmap, i1a) + V1 = indexmap[i1a] + foundV1 = true + elseif haskey(indexmap, i1b) + V1 = Expr(:call, :dual, indexmap[i1b]) + foundV1 = true + end + if haskey(indexmap, i2a) + V2 = indexmap[i2a] + foundV2 = true + elseif haskey(indexmap, i2b) + V2 = Expr(:call, :dual, indexmap[i2b]) + foundV2 = true + end + if foundV1 && foundV2 + s = gensym(:τ) + constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2) + push!(preargs, Expr(:(=), s, constructex)) + obj = _is_adjoint(obj) ? _add_adjoint(s) : s + success = true + else + success = false + end + newex = Expr(:typed_vcat, obj, Expr(:tuple, leftind...), + Expr(:tuple, rightind...)) else - push!(unresolved, (i1a, i1b)) + newex = ex + success = true end - - if !isnothing(obj_and_pos2a) - indexmap[i2b] = Expr(:call, :space, obj_and_pos2a...) - indexmap[i2a] = Expr(:call, :space, obj_and_pos2a...) - elseif !isnothing(obj_and_pos2b) - indexmap[i2b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - indexmap[i2a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - else - push!(unresolved, (i2a, i2b)) + # add spaces of the tensor to the indexmap + for (i, l) in enumerate(leftind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, i) + end end - end - # simple loop that tries to resolve as many indices as possible - changed = true - while changed == true - changed = false - i = 1 - while i <= length(unresolved) - (i1, i2) = unresolved[i] - if i1 in keys(indexmap) - changed = true - indexmap[i2] = indexmap[i1] - deleteat!(unresolved, i) - elseif i2 in keys(indexmap) - changed = true - indexmap[i1] = indexmap[i2] - deleteat!(unresolved, i) - else - i += 1 + for (i, l) in enumerate(rightind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) end end - end - !isempty(unresolved) && - throw(ArgumentError("cannot determine the spaces of indices " * - string(tuple(unresolved...)) * - "for the braiding tensors in $(ex)")) - - pre = Expr(:block) - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind + return newex, success + elseif TO.isgeneraltensor(ex) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + success = true + for i in 1:length(ex.args) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap) + success = success && successa end - push!(construct_expr.args, indexmap[i1b]) - push!(construct_expr.args, indexmap[i2b]) - s = gensym(:τ) - push!(pre.args, :(@notensor $s = $construct_expr)) - ex = TO.replacetensorobjects(ex) do o, l, r - if o == obj && l == leftind && r == rightind - return obj == :τ ? s : Expr(TO.prime, s) - else - return o + newex = Expr(ex.head, newargs...) + return newex, success + elseif isexpr(ex, :call) && ex.args[1] == :* + args = ex.args + newargs = Vector{Any}(undef, length(args)) + newargs[1] = args[1] + successes = map(i -> false, args) + successes[1] = true + numsuccess = 1 + while !all(successes) + for i in 2:length(ex.args) + successes[i] && continue + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, + indexmap) + successes[i] = successa + end + if numsuccess == count(successes) + break + end + numsuccess = count(successes) + end + success = numsuccess == length(successes) + newex = Expr(ex.head, newargs...) + return newex, success + elseif isexpr(ex, :call) && ex.args[1] ∈ (:+, :-) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + newargs[1] = args[1] + success = true + indices = TO.getindices(ex) + for i in 2:length(ex.args) + indexmapa = copy(indexmap) + newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa) + for l in indices[i] + if !haskey(indexmap, l) && haskey(indexmapa, l) + indexmap[l] = indexmapa[l] + end end + success = success && successa end + newex = Expr(ex.head, newargs...) + return newex, success + else + @show("huh?") + return ex, true end - return Expr(:block, pre, ex) end -_construct_braidingtensors(x) = x # used by non-planar parser of `@plansor`: remove explicit braiding tensors function _remove_braidingtensors(ex) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 785ebf49..2ee46f66 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -3,9 +3,13 @@ #====================================================================# """ struct BraidingTensor{S<:IndexSpace} <: AbstractTensorMap{S, 2, 2} + BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace} Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that braids the first input over the second input; its inverse can be obtained as the adjoint. + +It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and +`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. """ struct BraidingTensor{S<:IndexSpace,A} <: AbstractTensorMap{S,2,2} V1::S diff --git a/test/planar.jl b/test/planar.jl index 062e41b6..6254711d 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -71,29 +71,29 @@ end @testset "@planar" verbose = true begin T = ComplexF64 - + @testset "contractcheck" begin V = ℂ^2 A = TensorMap(rand, T, V ⊗ V ← V) B = TensorMap(rand, T, V ⊗ V ← V') @tensor C1[i j; k l] := A[i j; m] * B[k l; m] - @tensor contractcheck=true C2[i j; k l] := A[i j; m] * B[k l; m] + @tensor contractcheck = true C2[i j; k l] := A[i j; m] * B[k l; m] @test C1 ≈ C2 B2 = TensorMap(rand, T, V ⊗ V ← V) # wrong duality for third space @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin @tensor contractcheck = true C3[i j; k l] := A[i j; m] * B2[k l; m] end - + A = TensorMap(rand, T, V ← V ⊗ V) B = TensorMap(rand, T, V ⊗ V ← V) @planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] - @planar contractcheck=true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] + @planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] @test C1 ≈ C2 - @test_throws SpaceMismatch("incompatible spaces for l: $V ≠ $(V')") begin - @planar contractcheck=true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] + @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin + @planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] end end - + @testset "MPS networks" begin P = ℂ^2 Vmps = ℂ^12 @@ -140,6 +140,23 @@ end conj(x′[1 3; -1]) @test force_planar(ρ2) ≈ ρ2′ @test ρ2 ≈ ρ3 + + # Periodic boundary conditions + # ---------------------------- + f1 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f2 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f1′ = force_planar(f1) + f2′ = force_planar(f2) + @tensor O_periodic1[-1 -2; -3 -4] := O[1 -2; -3 2] * f1[-1; 1 3 4] * + conj(f2[-4; 2 3 4]) + @plansor O_periodic2[-1 -2; -3 -4] := O[1 2; -3 6] * f1[-1; 1 3 5] * + conj(f2[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] * + conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @test O_periodic1 ≈ O_periodic2 + @test force_planar(O_periodic1) ≈ O_periodic′ end @testset "MERA networks" begin @@ -171,4 +188,22 @@ end end @test C ≈ C′ end + + @testset "Issue 93" begin + T = Float64 + V1 = ℂ^2 + V2 = ℂ^3 + t1 = TensorMap(rand, T, V1 ← V2) + t2 = TensorMap(rand, T, V2 ← V1) + + tr1 = @planar opt = true t1[a; b] * t2[b; a] + tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] + tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b] + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + end end