Skip to content

Commit

Permalink
rewrite constructbraidingtensor preprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Dec 1, 2023
1 parent 6fb9d67 commit 596b038
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 111 deletions.
4 changes: 2 additions & 2 deletions src/planar/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function planarparser(planarexpr, kwargs...)

# braiding tensors need to be instantiated before kwargs are processed
push!(parser.preprocessors, _construct_braidingtensors)

for (name, val) in kwargs
if name == :order
isexpr(val, :tuple) ||
Expand Down Expand Up @@ -62,7 +62,7 @@ function planarparser(planarexpr, kwargs...)
throw(ArgumentError("Unknown keyword argument `name`."))
end
end

treebuilder = parser.contractiontreebuilder
treesorter = parser.contractiontreesorter
costcheck = parser.contractioncostcheck
Expand Down
234 changes: 132 additions & 102 deletions src/planar/preprocessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,129 +68,159 @@ _is_adjoint(ex) = isexpr(ex, TO.prime)
_remove_adjoint(ex) = _is_adjoint(ex) ? ex.args[1] : ex
_add_adjoint(ex) = Expr(TO.prime, ex)

# used by `@planar`: realize explicit braiding tensors
# used by `@planar`: identify braiding tensors (corresponding to name τ) and discover their
# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
# insert them in the expression.
function _construct_braidingtensors(ex::Expr)
if TO.isdefinition(ex) || TO.isassignment(ex)
lhs, rhs = TO.getlhs(ex), TO.getrhs(ex)
if TO.istensorexpr(rhs)
list = TO.gettensors(_conj_to_adjoint(rhs))
if TO.isassignment(ex) && TO.istensor(lhs)
obj, l, r = TO.decomposetensor(lhs)
lhs_as_rhs = Expr(:typed_vcat, _add_adjoint(obj),
Expr(:tuple, r...), Expr(:tuple, l...))
push!(list, lhs_as_rhs)
end
else
if !TO.istensorexpr(rhs)
return ex
end
preargs = Vector{Any}()
indexmap = Dict{Any,Any}()
if TO.isassignment(ex) && TO.istensor(lhs)
obj, leftind, rightind = TO.decomposetensor(lhs)
for (i, l) in enumerate(leftind)
indexmap[l] = Expr(:call, :space, _add_adjoint(obj), i)
end
for (i, l) in enumerate(rightind)
indexmap[l] = Expr(:call, :space, _add_adjoint(obj), length(leftind) + i)
end

Check warning on line 89 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L83-L89

Added lines #L83 - L89 were not covered by tests
end
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap)
success ||
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
pre = Expr(:macrocall, Symbol("@notensor"),
LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...))
return Expr(:block, pre, Expr(ex.head, lhs, newrhs))
elseif TO.istensorexpr(ex)
list = TO.gettensors(_conj_to_adjoint(ex))
preargs = Vector{Any}()
indexmap = Dict{Any,Any}()
newex, success = _construct_braidingtensors!(ex, preargs, indexmap)
success ||
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
pre = Expr(:macrocall, Symbol("@notensor"),
LineNumberNode(@__LINE__, Symbol(@__FILE__)), Expr(:block, preargs...))
return Expr(:block, pre, newex)
else
return Expr(ex.head, map(_construct_braidingtensors, ex.args)...)
end
end
_construct_braidingtensors(x) = x

i = 1
translatebraidings = Dict{Any,Any}()
while i <= length(list)
t = list[i]
if _remove_adjoint(TO.gettensorobject(t)) ==
translatebraidings[t] = Expr(:call, GlobalRef(TensorKit, :BraidingTensor))
deleteat!(list, i)
else
i += 1
end
end

unresolved = Any[] # list of indices that we couldn't yet figure out
indexmap = Dict{Any,Any}()
# indexmap[i] contains the expression to resolve the space for index i
for (t, construct_expr) in translatebraidings
obj, leftind, rightind = TO.decomposetensor(t)
length(leftind) == length(rightind) == 2 ||
throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices."))
if _is_adjoint(obj)
i1b, i2b, = leftind
i2a, i1a, = rightind
else
i2b, i1b, = leftind
i1a, i2a, = rightind
end

obj_and_pos1a = _findindex(i1a, list)
obj_and_pos2a = _findindex(i2a, list)
obj_and_pos1b = _findindex(i1b, list)
obj_and_pos2b = _findindex(i2b, list)
function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression
if TO.istensor(ex)
obj, leftind, rightind = TO.decomposetensor(ex)
if _remove_adjoint(obj) ==
# try to construct a braiding tensor
length(leftind) == length(rightind) == 2 ||
throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices."))
if _is_adjoint(obj)
i1b, i2b, = leftind
i2a, i1a, = rightind

Check warning on line 121 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
else
i2b, i1b, = leftind
i1a, i2a, = rightind
end

if !isnothing(obj_and_pos1a)
indexmap[i1b] = Expr(:call, :space, obj_and_pos1a...)
indexmap[i1a] = Expr(:call, :space, obj_and_pos1a...)
elseif !isnothing(obj_and_pos1b)
indexmap[i1b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...))
indexmap[i1a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...))
foundV1, foundV2 = false, false
if haskey(indexmap, i1a)
V1 = indexmap[i1a]
foundV1 = true
elseif haskey(indexmap, i1b)
V1 = Expr(:call, :dual, indexmap[i1b])
foundV1 = true
end
if haskey(indexmap, i2a)
V2 = indexmap[i2a]
foundV2 = true
elseif haskey(indexmap, i2b)
V2 = Expr(:call, :dual, indexmap[i2b])
foundV2 = true
end
if foundV1 && foundV2
s = gensym()
constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2)
push!(preargs, Expr(:(=), s, constructex))
obj = _is_adjoint(obj) ? _add_adjoint(s) : s
success = true
else
success = false

Check warning on line 149 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L149

Added line #L149 was not covered by tests
end
newex = Expr(:typed_vcat, obj, Expr(:tuple, leftind...),
Expr(:tuple, rightind...))
else
push!(unresolved, (i1a, i1b))
newex = ex
success = true
end

if !isnothing(obj_and_pos2a)
indexmap[i2b] = Expr(:call, :space, obj_and_pos2a...)
indexmap[i2a] = Expr(:call, :space, obj_and_pos2a...)
elseif !isnothing(obj_and_pos2b)
indexmap[i2b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...))
indexmap[i2a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...))
else
push!(unresolved, (i2a, i2b))
# add spaces of the tensor to the indexmap
for (i, l) in enumerate(leftind)
if !haskey(indexmap, l)
indexmap[l] = Expr(:call, :space, obj, i)
end
end
end
# simple loop that tries to resolve as many indices as possible
changed = true
while changed == true
changed = false
i = 1
while i <= length(unresolved)
(i1, i2) = unresolved[i]
if i1 in keys(indexmap)
changed = true
indexmap[i2] = indexmap[i1]
deleteat!(unresolved, i)
elseif i2 in keys(indexmap)
changed = true
indexmap[i1] = indexmap[i2]
deleteat!(unresolved, i)
else
i += 1
for (i, l) in enumerate(rightind)
if !haskey(indexmap, l)
indexmap[l] = Expr(:call, :space, obj, length(leftind) + i)
end
end
end
!isempty(unresolved) &&
throw(ArgumentError("cannot determine the spaces of indices " *
string(tuple(unresolved...)) *
"for the braiding tensors in $(ex)"))

pre = Expr(:block)
for (t, construct_expr) in translatebraidings
obj, leftind, rightind = TO.decomposetensor(t)
if _is_adjoint(obj)
i1b, i2b, = leftind
i2a, i1a, = rightind
else
i2b, i1b, = leftind
i1a, i2a, = rightind
return newex, success
elseif TO.isgeneraltensor(ex)
args = ex.args
newargs = Vector{Any}(undef, length(args))
success = true
for i in 1:length(ex.args)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap)
success = success && successa

Check warning on line 175 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L170-L175

Added lines #L170 - L175 were not covered by tests
end
push!(construct_expr.args, indexmap[i1b])
push!(construct_expr.args, indexmap[i2b])
s = gensym()
push!(pre.args, :(@notensor $s = $construct_expr))
ex = TO.replacetensorobjects(ex) do o, l, r
if o == obj && l == leftind && r == rightind
return obj == ? s : Expr(TO.prime, s)
else
return o
newex = Expr(ex.head, newargs...)
return newex, success

Check warning on line 178 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L177-L178

Added lines #L177 - L178 were not covered by tests
elseif isexpr(ex, :call) && ex.args[1] == :*
args = ex.args
newargs = Vector{Any}(undef, length(args))
newargs[1] = args[1]
successes = map(i -> false, args)
successes[1] = true
numsuccess = 1
while !all(successes)
for i in 2:length(ex.args)
successes[i] && continue
newargs[i], successa = _construct_braidingtensors!(args[i], preargs,
indexmap)
successes[i] = successa
end
if numsuccess == count(successes)
break

Check warning on line 194 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L194

Added line #L194 was not covered by tests
end
numsuccess = count(successes)
end
success = numsuccess == length(successes)
newex = Expr(ex.head, newargs...)
return newex, success
elseif isexpr(ex, :call) && ex.args[1] (:+, :-)
args = ex.args
newargs = Vector{Any}(undef, length(args))
newargs[1] = args[1]
success = true
indices = TO.getindices(ex)
for i in 2:length(ex.args)
indexmapa = copy(indexmap)
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa)
for l in indices[i]
if !haskey(indexmap, l) && haskey(indexmapa, l)
indexmap[l] = indexmapa[l]

Check warning on line 212 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L201-L212

Added lines #L201 - L212 were not covered by tests
end
end
success = success && successa

Check warning on line 215 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L215

Added line #L215 was not covered by tests
end
newex = Expr(ex.head, newargs...)
return newex, success

Check warning on line 218 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L217-L218

Added lines #L217 - L218 were not covered by tests
else
@show("huh?")
return ex, true

Check warning on line 221 in src/planar/preprocessors.jl

View check run for this annotation

Codecov / codecov/patch

src/planar/preprocessors.jl#L220-L221

Added lines #L220 - L221 were not covered by tests
end
return Expr(:block, pre, ex)
end
_construct_braidingtensors(x) = x

# used by non-planar parser of `@plansor`: remove explicit braiding tensors
function _remove_braidingtensors(ex)
Expand Down
4 changes: 4 additions & 0 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
#====================================================================#
"""
struct BraidingTensor{S<:IndexSpace} <: AbstractTensorMap{S, 2, 2}
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
braids the first input over the second input; its inverse can be obtained as the adjoint.
It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
"""
struct BraidingTensor{S<:IndexSpace,A} <: AbstractTensorMap{S,2,2}
V1::S
Expand Down
49 changes: 42 additions & 7 deletions test/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,29 @@ end

@testset "@planar" verbose = true begin
T = ComplexF64

@testset "contractcheck" begin
V =^2
A = TensorMap(rand, T, V V V)
B = TensorMap(rand, T, V V V')
@tensor C1[i j; k l] := A[i j; m] * B[k l; m]
@tensor contractcheck=true C2[i j; k l] := A[i j; m] * B[k l; m]
@tensor contractcheck = true C2[i j; k l] := A[i j; m] * B[k l; m]
@test C1 C2
B2 = TensorMap(rand, T, V V V) # wrong duality for third space
@test_throws SpaceMismatch("incompatible spaces for m: $V$(V')") begin
@tensor contractcheck = true C3[i j; k l] := A[i j; m] * B2[k l; m]
end

A = TensorMap(rand, T, V V V)
B = TensorMap(rand, T, V V V)
@planar C1[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j]
@planar contractcheck=true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j]
@planar contractcheck = true C2[i; j] := A[i; k l] * τ[k l; m n] * B[m n; j]
@test C1 C2
@test_throws SpaceMismatch("incompatible spaces for l: $V$(V')") begin
@planar contractcheck=true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m]
@test_throws SpaceMismatch("incompatible spaces for m: $V$(V')") begin
@planar contractcheck = true C3[i; j] := A[i; k l] * τ[k l; m n] * B[n j; m]
end
end

@testset "MPS networks" begin
P =^2
Vmps =^12
Expand Down Expand Up @@ -140,6 +140,23 @@ end
conj(x′[1 3; -1])
@test force_planar(ρ2) ρ2′
@test ρ2 ρ3

# Periodic boundary conditions
# ----------------------------
f1 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo Vmpo' Vmpo)
f2 = isomorphism(storagetype(O), fuse(Vmpo^3), Vmpo Vmpo' Vmpo)
f1′ = force_planar(f1)
f2′ = force_planar(f2)
@tensor O_periodic1[-1 -2; -3 -4] := O[1 -2; -3 2] * f1[-1; 1 3 4] *
conj(f2[-4; 2 3 4])
@plansor O_periodic2[-1 -2; -3 -4] := O[1 2; -3 6] * f1[-1; 1 3 5] *
conj(f2[-4; 6 7 8]) * τ[2 3; 7 4] *
τ[4 5; 8 -2]
@planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] *
conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] *
τ[4 5; 8 -2]
@test O_periodic1 O_periodic2
@test force_planar(O_periodic1) O_periodic′
end

@testset "MERA networks" begin
Expand Down Expand Up @@ -171,4 +188,22 @@ end
end
@test C C′
end

@testset "Issue 93" begin
T = Float64
V1 =^2
V2 =^3
t1 = TensorMap(rand, T, V1 V2)
t2 = TensorMap(rand, T, V2 V1)

tr1 = @planar opt = true t1[a; b] * t2[b; a]
tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d]
tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b]
tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f]
tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b]
tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f]
tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b]

@test tr1 tr2 tr3 tr4 tr5 tr6 tr7
end
end

0 comments on commit 596b038

Please sign in to comment.