From fff1008a1ae42575dbdc52eb1bf7733c18d68e72 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 4 Jan 2024 00:15:59 +0100 Subject: [PATCH] fix construct and remove braidingtensor preprocessors --- src/planar/macros.jl | 2 +- src/planar/planaroperations.jl | 2 + src/planar/preprocessors.jl | 284 +++++++++++++++++---------------- src/tensors/braidingtensor.jl | 12 -- test/planar.jl | 24 ++- 5 files changed, 167 insertions(+), 157 deletions(-) diff --git a/src/planar/macros.jl b/src/planar/macros.jl index e10308d2..2682d864 100644 --- a/src/planar/macros.jl +++ b/src/planar/macros.jl @@ -102,7 +102,7 @@ function _plansor(expr, kwargs...) tparser = TO.tensorparser(expr, kwargs...) pparser = planarparser(expr, kwargs...) - insert!(tparser.preprocessors, 5, _remove_braidingtensors) + insert!(tparser.preprocessors, 4, _remove_braidingtensors) tensorex = tparser(expr) planarex = pparser(expr) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index ee23bcd7..1d04b764 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -61,6 +61,7 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, codB, domB = codomainind(B), domainind(B) oindA, cindA = pA cindB, oindB = pB + # @show codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, pAB...) @@ -91,6 +92,7 @@ _cyclicpermute(t::Tuple{}) = () function reorder_indices(codA, domA, codB, domB, oindA, oindB, p1, p2) N₁ = length(oindA) N₂ = length(oindB) + # @show codA, domA, codB, domB, oindA, oindB, p1, p2 @assert length(p1) == N₁ && all(in(p1), 1:N₁) @assert length(p2) == N₂ && all(in(p2), N₁ .+ (1:N₂)) oindA2 = TupleTools.getindices(oindA, p1) diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 9fbd51a5..00b8ec5b 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -71,8 +71,11 @@ _add_adjoint(ex) = Expr(TO.prime, ex) # used by `@planar`: identify braiding tensors (corresponding to name τ) and discover their # spaces from the rest of the expression. Construct the explicit BraidingTensor objects and # insert them in the expression. -function _construct_braidingtensors(ex::Expr) - if TO.isdefinition(ex) || TO.isassignment(ex) +function _construct_braidingtensors(ex) + ex isa Expr || return ex + if ex.head == :macrocall && ex.args[1] == Symbol("@notensor") + return ex + elseif TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) if !TO.istensorexpr(rhs) return ex @@ -82,10 +85,11 @@ function _construct_braidingtensors(ex::Expr) if TO.isassignment(ex) && TO.istensor(lhs) obj, leftind, rightind = TO.decomposetensor(lhs) for (i, l) in enumerate(leftind) - indexmap[l] = Expr(:call, :space, _add_adjoint(obj), i) + indexmap[l] = Expr(:call, :dual, Expr(:call, :space, obj, i)) end for (i, l) in enumerate(rightind) - indexmap[l] = Expr(:call, :space, _add_adjoint(obj), length(leftind) + i) + indexmap[l] = Expr(:call, :dual, + Expr(:call, :space, obj, length(leftind) + i)) end end newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap) @@ -107,11 +111,11 @@ function _construct_braidingtensors(ex::Expr) return Expr(ex.head, map(_construct_braidingtensors, ex.args)...) end end -_construct_braidingtensors(x) = x function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression if TO.isscalarexpr(ex) - return ex, true + # ex could be tensorscalar call with more braiding tensors + return _construct_braidingtensors(ex), true elseif TO.istensor(ex) obj, leftind, rightind = TO.decomposetensor(ex) if _remove_adjoint(obj) == :τ @@ -156,15 +160,17 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t newex = ex success = true end - # add spaces of the tensor to the indexmap - for (i, l) in enumerate(leftind) - if !haskey(indexmap, l) - indexmap[l] = Expr(:call, :space, obj, i) + if success == true + # add spaces of the tensor to the indexmap + for (i, l) in enumerate(leftind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, i) + end end - end - for (i, l) in enumerate(rightind) - if !haskey(indexmap, l) - indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) + for (i, l) in enumerate(rightind) + if !haskey(indexmap, l) + indexmap[l] = Expr(:call, :space, obj, length(leftind) + i) + end end end return newex, success @@ -232,132 +238,144 @@ end # used by non-planar parser of `@plansor`: remove explicit braiding tensors function _remove_braidingtensors(ex) ex isa Expr || return ex - outgoing = [] - - if TO.isdefinition(ex) || TO.isassignment(ex) + if ex.head == :macrocall && ex.args[1] == Symbol("@notensor") + return ex + elseif TO.isdefinition(ex) || TO.isassignment(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - list = TO.gettensors(_conj_to_adjoint(rhs)) - if TO.istensor(lhs) - obj, l, r = TO.decomposetensor(lhs) - outgoing = [l; r] - end - else + if !TO.istensorexpr(rhs) return ex end + indexmap = Dict{Any,Any}() + if TO.istensor(lhs) + obj, leftind, rightind = TO.decomposetensor(lhs) + end + newrhs, unchanged = _remove_braidingtensors!(rhs, indexmap) + isempty(indexmap) || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + return Expr(ex.head, lhs, newrhs) elseif TO.istensorexpr(ex) - list = TO.gettensors(_conj_to_adjoint(ex)) + indexmap = Dict{Any,Any}() + + newex, unchanged = _remove_braidingtensors!(ex, indexmap) + isempty(indexmap) || + throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex")) + return newex else return Expr(ex.head, map(_remove_braidingtensors, ex.args)...) end +end - τs = Any[] - i = 1 - while i <= length(list) - t = list[i] - if _remove_adjoint(TO.gettensorobject(t)) == :τ - push!(τs, t) - deleteat!(list, i) - else - i += 1 - end - end +function _remove_braidingtensors!(ex, indexmap) # ex is guaranteed to be a single tensor expression + if TO.isscalarexpr(ex) + return _remove_braidingtensors(ex), true + elseif TO.istensor(ex) + obj, leftind, rightind = TO.decomposetensor(ex) + if _remove_adjoint(obj) == :τ + # remove braiding tensor and add labels to indexmap + length(leftind) == length(rightind) == 2 || + throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - indexmap = Dict{Any,Any}() - # to remove the braidingtensors, we need to map certain indices to other indices - for t in τs - obj, leftind, rightind = TO.decomposetensor(t) - length(leftind) == length(rightind) == 2 || - throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) - if _is_adjoint(obj) i1b, i2b, = leftind i2a, i1a, = rightind - else - i2b, i1b, = leftind - i1a, i2a, = rightind - end - - i1a = get(indexmap, i1a, i1a) - i1b = get(indexmap, i1b, i1b) - i2a = get(indexmap, i2a, i2a) - i2b = get(indexmap, i2b, i2b) - - obj_and_pos1a = _findindex(i1a, list) - obj_and_pos2a = _findindex(i2a, list) - obj_and_pos1b = _findindex(i1b, list) - obj_and_pos2b = _findindex(i2b, list) - - if i1a in outgoing - indexmap[i1a] = i1a - indexmap[i1b] = i1a - elseif i1b in outgoing - indexmap[i1a] = i1b - indexmap[i1b] = i1b - else - if i1a isa Int && i1b isa Int - indexmap[i1a] = max(i1a, i1b) - indexmap[i1b] = max(i1a, i1b) + if i1a == i1b || (haskey(indexmap, i1a) && haskey(indexmap, i1b)) + throw(IndexError("Cannot resolve indices $i1a and $i1b that occur only on braidings.")) + elseif haskey(indexmap, i1a) + i1c = indexmap[i1a] + indexmap[i1c] = i1b + indexmap[i1b] = i1c + delete!(indexmap, i1a) + elseif haskey(indexmap, i1b) + i1c = indexmap[i1b] + indexmap[i1c] = i1a + indexmap[i1a] = i1c + delete!(indexmap, i1b) else - indexmap[i1a] = i1a + indexmap[i1a] = i1b indexmap[i1b] = i1a end - end - - if i2a in outgoing - indexmap[i2a] = i2a - indexmap[i2b] = i2a - elseif i2b in outgoing - indexmap[i2a] = i2b - indexmap[i2b] = i2b - else - if i2a isa Int && i2b isa Int - indexmap[i2a] = max(i2a, i2b) - indexmap[i2b] = max(i2a, i2b) + if i2a == i2b || (haskey(indexmap, i2a) && haskey(indexmap, i2b)) + throw(IndexError("Cannot resolve indices $i2a and $i2b that occur only on braidings.")) + elseif haskey(indexmap, i2a) + i2c = indexmap[i2a] + indexmap[i2c] = i2b + indexmap[i2b] = i2c + delete!(indexmap, i2a) + elseif haskey(indexmap, i2b) + i2c = indexmap[i2b] + indexmap[i2c] = i2a + indexmap[i2a] = i2c + delete!(indexmap, i2b) else - indexmap[i2a] = i2a + indexmap[i2a] = i2b indexmap[i2b] = i2a end + return One(), false # when there are still braiding tensors, we haven't finished + else + unchanged = true + for (i, l) in enumerate(leftind) + if haskey(indexmap, l) + unchanged = false + l′ = indexmap[l] + leftind[i] = l′ + delete!(indexmap, l) + delete!(indexmap, l′) + end + end + for (i, l) in enumerate(rightind) + if haskey(indexmap, l) + unchanged = false + l′ = indexmap[l] + rightind[i] = l′ + delete!(indexmap, l) + delete!(indexmap, l′) + end + end + return Expr(:typed_vcat, obj, Expr(:tuple, leftind...), + Expr(:tuple, rightind...)), unchanged end - end - - # simple loop that tries to simplify the indicemaps (a=>b,b=>c -> a=>c,b=>c) - changed = true - while changed == true - changed = false - i = 1 - for (k, v) in indexmap - if v in keys(indexmap) && indexmap[v] != v - changed = true - indexmap[k] = indexmap[v] + elseif TO.isgeneraltensor(ex) + args = ex.args + newargs = Vector{Any}(undef, length(args)) + unchanged = true + for i in 1:length(ex.args) + newargs[i], unchangeda = _remove_braidingtensors!(args[i], indexmap) + unchanged = unchanged && unchangeda + end + newex = Expr(ex.head, newargs...) + return newex, unchanged + elseif isexpr(ex, :call) && ex.args[1] == :* + args = ex.args + newargs = copy(args) + unchanged = map(i -> false, args) + unchanged[1] = true + for i in 2:length(ex.args) + newargs[i], unchanged[i] = _remove_braidingtensors!(newargs[i], indexmap) + end + all(unchanged) && return ex, true + while !all(unchanged) + for i in 2:length(ex.args) + newargs[i], unchanged[i] = _remove_braidingtensors!(newargs[i], indexmap) end end - end - - ex = TO.replaceindices(i -> get(indexmap, i, i), ex) - return _purge_braidingtensors(ex) -end - -function _purge_braidingtensors(ex) # actually remove the braidingtensors - ex isa Expr || return ex - args = collect(filter(ex.args) do a - if isexpr(a, :call) && a.args[1] == :conj - a = a.args[2] - end - if a isa Expr && TO.istensor(a) && - _remove_adjoint(TO.gettensorobject(a)) == :τ - _, leftind, rightind = TO.decomposetensor(a) - (leftind[1] == rightind[2] && leftind[2] == rightind[1]) || - throw(ArgumentError("unable to remove braiding tensor $a")) - return false - end - return true - end) - - # multiplication with only a single argument is (rightfully) seen as invalid syntax - if isexpr(ex, :call) && args[1] == :* && length(args) == 2 - return _purge_braidingtensors(args[2]) + return Expr(ex.head, newargs...), false + elseif isexpr(ex, :call) && ex.args[1] ∈ (:+, :-) + newargs = copy(ex.args) + indexmaps = [copy(indexmap) for _ in 1:(length(newargs) - 1)] + unchanged = true + for i in 2:length(ex.args) + newargs[i], unchangeda = _remove_braidingtensors!(ex.args[i], indexmaps[i - 1]) + unchanged = unchanged && unchangeda + end + newex = Expr(ex.head, newargs...) + return newex, unchanged + elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3 + newarg, unchanged = _remove_braidingtensors!(ex.args[2], indexmap) + return Expr(:call, :/, newarg, ex.args[3]), unchanged + elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3 + newarg, unchanged = _remove_braidingtensors!(ex.args[3], indexmap) + return Expr(:call, :\, ex.args[2], newarg), unchanged else - return Expr(ex.head, map(_purge_braidingtensors, args)...) + error("unexpected expression $ex") end end @@ -396,24 +414,15 @@ function _decompose_planar_contractions(ex::Expr, temporaries) end if TO.isassignment(ex) || TO.isdefinition(ex) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) - if TO.istensorexpr(rhs) - pre = Vector{Any}() - if TO.istensor(lhs) - rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries) - return Expr(:block, pre..., Expr(ex.head, lhs, rhs)) - else - lhssym = gensym(string(lhs)) - lhstensor = Expr(:typed_vcat, lhssym, Expr(:tuple), Expr(:tuple)) - rhs = _extract_contraction_pairs(rhs, lhstensor, pre, temporaries) - push!(temporaries, lhssym) - return Expr(:block, pre..., Expr(:(:=), lhstensor, rhs), - Expr(:(=), lhs, lhstensor)) - end + pre = Vector{Any}() + if TO.istensor(lhs) + rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries) else - return ex + rhs = _extract_contraction_pairs(rhs, (Any[], Any[]), pre, temporaries) end + return Expr(:block, pre..., Expr(ex.head, lhs, rhs)) end - if TO.istensorexpr(ex) + if TO.istensorexpr(ex) || (isexpr(ex, :call) && ex.args[1] == :tensorscalar) pre = Vector{Any}() rhs = _extract_contraction_pairs(ex, (Any[], Any[]), pre, temporaries) return Expr(:block, pre..., rhs) @@ -429,7 +438,10 @@ end # if lhs is an expression, it contains the existing lhs and thus the index order # if lhs is a tuple, the result is a temporary object and the tuple (lind, rind) gives a suggestion for the preferred index order function _extract_contraction_pairs(rhs, lhs, pre, temporaries) - if TO.isscalarexpr(rhs) + if isexpr(rhs, :call) && rhs.args[1] == :tensorscalar + newarg = _extract_contraction_pairs(rhs.args[2], lhs, pre, temporaries) + return Expr(:call, :tensorscalar, newarg) + elseif TO.isscalarexpr(rhs) return rhs elseif TO.isgeneraltensor(rhs) if TO.hastraceindices(rhs) && lhs isa Tuple @@ -479,7 +491,6 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) a2 = _extract_contraction_pairs(rhs.args[3], (cind2, reverse(oind2)), pre, temporaries) end - # @show a1, a2, oind1, oind2 if TO.isscalarexpr(a1) || TO.isscalarexpr(a2) rhs = Expr(:call, :*, a1, a2) @@ -499,7 +510,6 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) ind1, ind2 = ind2, ind1 oind1, oind2 = oind2, oind1 end - # @show a1, a2, oind1, oind2 if lhs isa Tuple rhs = Expr(:call, :*, a1, a2) s = gensym() diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index ab945e81..95028587 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -271,18 +271,6 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end -function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, - A::BraidingTensor{S}, - (oindA, cindA)::Index2Tuple{2,2}, - B::BraidingTensor{S}, - (cindB, oindB)::Index2Tuple{2,2}, - (p1, p2)::Index2Tuple{N₁,N₂}, - α::Number, β::Number, - backend::Backend...) where {S,N₁,N₂} - return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, - backend...) -end - function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, (oindA, cindA)::Index2Tuple{N₃,2}, diff --git a/test/planar.jl b/test/planar.jl index 5cbe1689..d7edf146 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -196,13 +196,23 @@ end t1 = TensorMap(rand, T, V1 ← V2) t2 = TensorMap(rand, T, V2 ← V1) - tr1 = @planar opt = true t1[a; b] * t2[b; a]/2 - tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1/2 * τ[c b; a d] - tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] /2 - tr4 = @planar opt = true t1[f; a] * 1/2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] - tr5 = @planar opt = true t1[f; a] * t2[c; d]/2 * τ[d b; c e] * τ[a e; f b] - tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b]/2 * τ[e b; a f] - tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] /2) + tr1 = @planar opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @planar opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) + + @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 + + tr1 = @plansor opt = true t1[a; b] * t2[b; a] / 2 + tr2 = @plansor opt = true t1[d; a] * t2[b; c] * 1 / 2 * τ[c b; a d] + tr3 = @plansor opt = true t1[d; a] * t2[b; c] * τ[a c; d b] / 2 + tr4 = @plansor opt = true t1[f; a] * 1 / 2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @plansor opt = true t1[f; a] * t2[c; d] / 2 * τ[d b; c e] * τ[a e; f b] + tr6 = @plansor opt = true t1[f; a] * t2[c; d] * τ[c d; e b] / 2 * τ[e b; a f] + tr7 = @plansor opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] / 2) @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 end