diff --git a/src/planar/macros.jl b/src/planar/macros.jl index 1d76cc48..e10308d2 100644 --- a/src/planar/macros.jl +++ b/src/planar/macros.jl @@ -23,7 +23,7 @@ function planarparser(planarexpr, kwargs...) # braiding tensors need to be instantiated before kwargs are processed push!(parser.preprocessors, _construct_braidingtensors) - + for (name, val) in kwargs if name == :order isexpr(val, :tuple) || @@ -62,7 +62,7 @@ function planarparser(planarexpr, kwargs...) throw(ArgumentError("Unknown keyword argument `name`.")) end end - + treebuilder = parser.contractiontreebuilder treesorter = parser.contractiontreesorter costcheck = parser.contractioncostcheck diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 6c284ad9..2381c1f6 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -70,6 +70,7 @@ _add_adjoint(ex) = Expr(TO.prime, ex) # used by `@planar`: realize explicit braiding tensors function _construct_braidingtensors(ex::Expr) + # create a list of all tensors if TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) if TO.istensorexpr(rhs) @@ -89,11 +90,17 @@ function _construct_braidingtensors(ex::Expr) return Expr(ex.head, map(_construct_braidingtensors, ex.args)...) end + # create a mapping of tensor_expr => braiding_constructor_expr i = 1 translatebraidings = Dict{Any,Any}() while i <= length(list) t = list[i] if _remove_adjoint(TO.gettensorobject(t)) == :τ + # verify that the braiding tensor has the correct form + obj, leftind, rightind = TO.decomposetensor(t) + length(leftind) == length(rightind) == 2 || + throw(ArgumentError("Invalid expression $t. τ is reserved for the braiding, and should have two input and two output indices.")) + translatebraidings[t] = Expr(:call, GlobalRef(TensorKit, :BraidingTensor)) deleteat!(list, i) else @@ -101,93 +108,68 @@ function _construct_braidingtensors(ex::Expr) end end - unresolved = Any[] # list of indices that we couldn't yet figure out - indexmap = Dict{Any,Any}() - # indexmap[i] contains the expression to resolve the space for index i - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - length(leftind) == length(rightind) == 2 || - throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - - obj_and_pos1a = _findindex(i1a, list) - obj_and_pos2a = _findindex(i2a, list) - obj_and_pos1b = _findindex(i1b, list) - obj_and_pos2b = _findindex(i2b, list) - - if !isnothing(obj_and_pos1a) - indexmap[i1b] = Expr(:call, :space, obj_and_pos1a...) - indexmap[i1a] = Expr(:call, :space, obj_and_pos1a...) - elseif !isnothing(obj_and_pos1b) - indexmap[i1b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) - indexmap[i1a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) - else - push!(unresolved, (i1a, i1b)) - end + # loop over the braiding tensors and try to figure out the correct spaces + ischanged = true + pre = Expr(:block) + while ischanged + ischanged = false + for (t, construct_expr) in translatebraidings + obj, leftind, rightind = TO.decomposetensor(t) + if _is_adjoint(obj) + i1b, i2b, = leftind + i2a, i1a, = rightind + else + i2b, i1b, = leftind + i1a, i2a, = rightind + end - if !isnothing(obj_and_pos2a) - indexmap[i2b] = Expr(:call, :space, obj_and_pos2a...) - indexmap[i2a] = Expr(:call, :space, obj_and_pos2a...) - elseif !isnothing(obj_and_pos2b) - indexmap[i2b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - indexmap[i2a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) - else - push!(unresolved, (i2a, i2b)) - end - end - # simple loop that tries to resolve as many indices as possible - changed = true - while changed == true - changed = false - i = 1 - while i <= length(unresolved) - (i1, i2) = unresolved[i] - if i1 in keys(indexmap) - changed = true - indexmap[i2] = indexmap[i1] - deleteat!(unresolved, i) - elseif i2 in keys(indexmap) - changed = true - indexmap[i1] = indexmap[i2] - deleteat!(unresolved, i) + # attempt to deduce first space + obj_and_pos1a = _findindex(i1a, list) + obj_and_pos1b = _findindex(i1b, list) + if !isnothing(obj_and_pos1a) + V_i1b = Expr(:call, :space, obj_and_pos1a...) + elseif !isnothing(obj_and_pos1b) + V_i1b = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)) else - i += 1 + continue end - end - end - !isempty(unresolved) && - throw(ArgumentError("cannot determine the spaces of indices " * - string(tuple(unresolved...)) * - "for the braiding tensors in $(ex)")) - pre = Expr(:block) - for (t, construct_expr) in translatebraidings - obj, leftind, rightind = TO.decomposetensor(t) - if _is_adjoint(obj) - i1b, i2b, = leftind - i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - push!(construct_expr.args, indexmap[i1b]) - push!(construct_expr.args, indexmap[i2b]) - s = gensym(:τ) - push!(pre.args, :(@notensor $s = $construct_expr)) - ex = TO.replacetensorobjects(ex) do o, l, r - if o == obj && l == leftind && r == rightind - return obj == :τ ? s : Expr(TO.prime, s) + # attempt to deduce second space + obj_and_pos2a = _findindex(i2a, list) + obj_and_pos2b = _findindex(i2b, list) + if !isnothing(obj_and_pos2a) + V_i2b = Expr(:call, :space, obj_and_pos2a...) + elseif !isnothing(obj_and_pos2b) + V_i2b = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) else - return o + continue + end + + # both spaces deduced, construct braidingtensor + push!(construct_expr.args, V_i1b) + push!(construct_expr.args, V_i2b) + s = gensym(:τ) + push!(pre.args, :(@notensor $s = $construct_expr)) + + # and insert into tensor expression and tensorlist + ex = TO.replacetensorobjects(ex) do o, l, r + if o == obj && l == leftind && r == rightind + return obj == :τ ? s : Expr(TO.prime, s) + else + return o + end end + push!(list, + Expr(:typed_vcat, s, Expr(:tuple, leftind...), + Expr(:tuple, rightind...))) + + delete!(translatebraidings, t) + ischanged = true end end + + @assert isempty(translatebraidings) "could not figure out all spaces \n $ex" + return Expr(:block, pre, ex) end _construct_braidingtensors(x) = x diff --git a/test/planar.jl b/test/planar.jl index 062e41b6..755f6a46 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -71,29 +71,29 @@ end @testset "@planar" verbose = true begin T = ComplexF64 - + @testset "contractcheck" begin V = ℂ^2 A = TensorMap(rand, T, V ⊗ V ← V) B = TensorMap(rand, T, V ⊗ V ← V') @tensor C1[i j; k l] := A[i j; m] * B[k l; m] - @tensor contractcheck=true C2[i j; k l] := A[i j; m] * B[k l; m] + @tensor contractcheck = true C2[i j; k l] := A[i j; m] * B[k l; m] @test C1 ≈ C2 B2 = TensorMap(rand, T, V ⊗ V ← V) # wrong duality for third space @test_throws SpaceMismatch("incompatible spaces for m: $V ≠ $(V')") begin @tensor contractcheck = true C3[i j; k l] := A[i j; m] * B2[k l; m] end - + A = TensorMap(rand, T, V ← V ⊗ V) B = TensorMap(rand, T, V ⊗ V ← V) @planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] - @planar contractcheck=true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] + @planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j] @test C1 ≈ C2 @test_throws SpaceMismatch("incompatible spaces for l: $V ≠ $(V')") begin - @planar contractcheck=true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] + @planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m] end end - + @testset "MPS networks" begin P = ℂ^2 Vmps = ℂ^12 @@ -140,6 +140,23 @@ end conj(x′[1 3; -1]) @test force_planar(ρ2) ≈ ρ2′ @test ρ2 ≈ ρ3 + + # Periodic boundary conditions + # ---------------------------- + f1 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f2 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo ⊗ Vmpo' ⊗ Vmpo) + f1′ = force_planar(f1) + f2′ = force_planar(f2) + @tensor O_periodic1[-1 -2; -3 -4] := O[1 -2; -3 2] * f1[-1; 1 3 4] * + conj(f2[-4; 2 3 4]) + @plansor O_periodic2[-1 -2; -3 -4] := O[1 2; -3 6] * f1[-1; 1 3 5] * + conj(f2[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] * + conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] * + τ[4 5; 8 -2] + @test O_periodic1 ≈ O_periodic2 + @test force_planar(O_periodic1) ≈ O_periodic′ end @testset "MERA networks" begin @@ -171,4 +188,22 @@ end end @test C ≈ C′ end + + @testset "Issue 93" begin + T = Float64 + V1 = ℂ^2 + V2 = ℂ^3 + t1 = TensorMap(rand, T, V1 ← V2) + t2 = TensorMap(rand, T, V2 ← V1) + + tr1 = @planar opt = true t1[a; b] * t2[b; a] + tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] + tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b] + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + end end