From 8a2df708e0e2f23fef67c5179726a0a8039f29e7 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 1 Dec 2023 17:04:14 +0100 Subject: [PATCH] some more planar fixes --- src/planar/analyzers.jl | 4 ++++ src/planar/preprocessors.jl | 21 +++++++++++++++++---- src/tensors/braidingtensor.jl | 23 +++++++++++++++++++++++ test/planar.jl | 14 +++++++------- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/planar/analyzers.jl b/src/planar/analyzers.jl index b91ff2ba..4d22cf0a 100644 --- a/src/planar/analyzers.jl +++ b/src/planar/analyzers.jl @@ -36,6 +36,10 @@ function get_possible_planar_indices(ex) end end return inds + elseif isexpr(ex, :call) && ex.args[1] == :/ + return get_possible_planar_indices(ex.args[2]) + elseif isexpr(ex, :call) && ex.args[1] == :\ + return get_possible_planar_indices(ex.args[3]) else return Any[] end diff --git a/src/planar/preprocessors.jl b/src/planar/preprocessors.jl index 3d6abecc..9fbd51a5 100644 --- a/src/planar/preprocessors.jl +++ b/src/planar/preprocessors.jl @@ -110,7 +110,9 @@ end _construct_braidingtensors(x) = x function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression - if TO.istensor(ex) + if TO.isscalarexpr(ex) + return ex, true + elseif TO.istensor(ex) obj, leftind, rightind = TO.decomposetensor(ex) if _remove_adjoint(obj) == :τ # try to construct a braiding tensor @@ -216,9 +218,14 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t end newex = Expr(ex.head, newargs...) return newex, success + elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3 + newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap) + return Expr(:call, :/, newarg, ex.args[3]), success + elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3 + newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap) + return Expr(:call, :\, ex.args[2], newarg), success else - @show("huh?") - return ex, true + error("unexpected expression $ex") end end @@ -520,8 +527,14 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) for a in rhs.args[2:end]] return Expr(rhs.head, rhs.args[1], args...) + elseif isexpr(rhs, :call) && rhs.args[1] == :/ + newarg = _extract_contraction_pairs(rhs.args[2], lhs, pre, temporaries) + return Expr(:call, :/, newarg, rhs.args[3]) + elseif isexpr(rhs, :call) && rhs.args[1] == :\ + newarg = _extract_contraction_pairs(rhs.args[3], lhs, pre, temporaries) + return Expr(:call, :\, rhs.args[2], newarg) else - throw(ArgumentError("unknown tensor expression")) + throw(ArgumentError("unknown tensor expression $ex")) end end diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 2ee46f66..ab945e81 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -271,6 +271,18 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end +function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, + A::BraidingTensor{S}, + (oindA, cindA)::Index2Tuple{2,2}, + B::BraidingTensor{S}, + (cindB, oindB)::Index2Tuple{2,2}, + (p1, p2)::Index2Tuple{N₁,N₂}, + α::Number, β::Number, + backend::Backend...) where {S,N₁,N₂} + return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, + backend...) +end + function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, (oindA, cindA)::Index2Tuple{N₃,2}, @@ -317,6 +329,17 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end return C end +function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, + A::BraidingTensor{S}, + (oindA, cindA)::Index2Tuple{2,2}, + B::BraidingTensor{S}, + (cindB, oindB)::Index2Tuple{2,2}, + (p1, p2)::Index2Tuple{N₁,N₂}, + α::Number, β::Number, + backend::Backend...) where {S,N₁,N₂} + return planarcontract!(C, copy(A), (oindA, cindA), B, (cindB, oindB), (p1, p2), α, β, + backend...) +end # Fallback cases for planarcontract! # TODO: implement specialised cases for contracting 0, 1, 3 and 4 indices diff --git a/test/planar.jl b/test/planar.jl index 6254711d..5cbe1689 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -196,13 +196,13 @@ end t1 = TensorMap(rand, T, V1 ← V2) t2 = TensorMap(rand, T, V2 ← V1) - tr1 = @planar opt = true t1[a; b] * t2[b; a] - tr2 = @planar opt = true t1[d; a] * t2[b; c] * τ[c b; a d] - tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] - tr4 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[e b; a f] - tr5 = @planar opt = true t1[f; a] * t2[c; d] * τ[d b; c e] * τ[a e; f b] - tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[e b; a f] - tr7 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b] * τ[a e; f b] + tr1 = @planar opt = true t1[a; b] * t2[b; a]/2 + tr2 = @planar opt = true t1[d; a] * t2[b; c] * 1/2 * τ[c b; a d] + tr3 = @planar opt = true t1[d; a] * t2[b; c] * τ[a c; d b] /2 + tr4 = @planar opt = true t1[f; a] * 1/2 * t2[c; d] * τ[d b; c e] * τ[e b; a f] + tr5 = @planar opt = true t1[f; a] * t2[c; d]/2 * τ[d b; c e] * τ[a e; f b] + tr6 = @planar opt = true t1[f; a] * t2[c; d] * τ[c d; e b]/2 * τ[e b; a f] + tr7 = @planar opt = true t1[f; a] * t2[c; d] * (τ[c d; e b] * τ[a e; f b] /2) @test tr1 ≈ tr2 ≈ tr3 ≈ tr4 ≈ tr5 ≈ tr6 ≈ tr7 end