From eaa850646cf487ebb534163376f378613efe9751 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Thu, 11 Jul 2024 15:39:55 +0200 Subject: [PATCH] Parallelize AllSides CTMRG using hacky differentiable @fwdthreads macro, separate diffset into a new file --- src/PEPSKit.jl | 1 + src/algorithms/ctmrg_all_sides.jl | 43 ++++++++++--------- src/utility/diffset.jl | 47 ++++++++++++++++++++ src/utility/util.jl | 71 ++++++++++--------------------- 4 files changed, 93 insertions(+), 69 deletions(-) create mode 100644 src/utility/diffset.jl diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index 07ee4f40..2dd1057a 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -11,6 +11,7 @@ using ChainRulesCore, Zygote include("utility/util.jl") include("utility/svd.jl") include("utility/rotations.jl") +include("utility/diffset.jl") include("utility/hook_pullback.jl") include("utility/autoopt.jl") diff --git a/src/algorithms/ctmrg_all_sides.jl b/src/algorithms/ctmrg_all_sides.jl index 2d87fb9b..046ae3ee 100644 --- a/src/algorithms/ctmrg_all_sides.jl +++ b/src/algorithms/ctmrg_all_sides.jl @@ -13,35 +13,38 @@ function ctmrg_iter(state, env::CTMRGEnv, alg::CTMRG{:AllSides}) end # Compute enlarged corners and edges for all directions and unit cell entries -function enlarge_corners_edges(state, env::CTMRGEnv) - map(Iterators.product(axes(env.corners)...)) do (dir, r, c) +function enlarge_corners_edges(state, env::CTMRGEnv{C,T}) where {C,T} + Qtype = tensormaptype(spacetype(C), 3, 3, storagetype(C)) + Q = Zygote.Buffer(Array{Qtype,3}(undef, size(env.corners))) + drc_combinations = collect(Iterators.product(axes(env.corners)...)) + @fwdthreads for (dir, r, c) in drc_combinations rprev = _prev(r, size(state, 1)) rnext = _next(r, size(state, 1)) cprev = _prev(c, size(state, 2)) cnext = _next(c, size(state, 2)) - if dir == NORTHWEST - return northwest_corner( + Q[dir, r, c] = if dir == NORTHWEST + northwest_corner( env.edges[WEST, r, cprev], env.corners[NORTHWEST, rprev, cprev], env.edges[NORTH, rprev, 
c], state[r, c], ) elseif dir == NORTHEAST - return northeast_corner( + northeast_corner( env.edges[NORTH, rprev, c], env.corners[NORTHEAST, rprev, cnext], env.edges[EAST, r, cnext], state[r, c], ) elseif dir == SOUTHEAST - return southeast_corner( + southeast_corner( env.edges[EAST, r, cnext], env.corners[SOUTHEAST, rnext, cnext], env.edges[SOUTH, rnext, c], state[r, c], ) elseif dir == SOUTHWEST - return southwest_corner( + southwest_corner( env.edges[SOUTH, rnext, c], env.corners[SOUTHWEST, rnext, cprev], env.edges[WEST, r, cprev], @@ -49,31 +52,28 @@ function enlarge_corners_edges(state, env::CTMRGEnv) ) end end + + return copy(Q) end # Build projectors from SVD and enlarged corners -function build_projectors(Q, env::CTMRGEnv, alg::ProjectorAlg{A,T}) where {A,T} +function build_projectors(Q, env::CTMRGEnv{C,E}, alg::ProjectorAlg{A,T}) where {C,E,A,T} P_left, P_right = Zygote.Buffer.(projector_type(env.edges)) U, V = Zygote.Buffer.(projector_type(env.edges)) - Stype = tensormaptype( # Corner type but with real numbers - spacetype(env.corners[1]), - 1, - 1, - Matrix{real(scalartype(env.corners[1]))}, - ) + Stype = tensormaptype(spacetype(C), 1, 1, Matrix{real(scalartype(E))}) # Corner type but with real numbers S = Zygote.Buffer(Array{Stype,3}(undef, size(env.corners))) ϵ = 0.0 - rsize, csize = size(env.corners)[2:3] - for dir in 1:4, r in 1:rsize, c in 1:csize + drc_combinations = collect(Iterators.product(axes(env.corners)...)) + @fwdthreads for (dir, r, c) in drc_combinations # Row-column index of next enlarged corner next_rc = if dir == 1 - (r, _next(c, csize)) + (r, _next(c, size(env.corners, 3))) elseif dir == 2 - (_next(r, rsize), c) + (_next(r, size(env.corners, 2)), c) elseif dir == 3 - (r, _prev(c, csize)) + (r, _prev(c, size(env.corners, 3))) elseif dir == 4 - (_prev(r, rsize), c) + (_prev(r, size(env.corners, 2)), c) end # SVD half-infinite environment @@ -122,7 +122,8 @@ end function renormalize_corners_edges(state, env::CTMRGEnv, Q, P_left, P_right) 
corners::typeof(env.corners) = copy(env.corners) edges::typeof(env.edges) = copy(env.edges) - for c in 1:size(state, 2), r in 1:size(state, 1) + rc_combinations = collect(Iterators.product(axes(state)...)) + @fwdthreads for (r, c) in rc_combinations rprev = _prev(r, size(state, 1)) rnext = _next(r, size(state, 1)) cprev = _prev(c, size(state, 2)) diff --git a/src/utility/diffset.jl b/src/utility/diffset.jl new file mode 100644 index 00000000..80fc14cf --- /dev/null +++ b/src/utility/diffset.jl @@ -0,0 +1,47 @@ +""" + @diffset assign + +Helper macro which allows in-place operations in the forward-pass of Zygote, but +resorts to non-mutating operations in the backwards-pass. The expression `assign` +should assign an object to a pre-existing `AbstractArray` and the use of updating +operators is also possible. This is especially needed when in-place assigning +tensors to unit-cell arrays of environments. +""" +macro diffset(ex) + return esc(parse_ex(ex)) +end +parse_ex(ex) = ex +function parse_ex(ex::Expr) + oppheads = (:(./=), :(.*=), :(.+=), :(.-=)) + opprep = (:(./), :(.*), :(.+), :(.-)) + if ex.head == :macrocall + parse_ex(macroexpand(PEPSKit, ex)) + elseif ex.head in (:(.=), :(=)) && length(ex.args) == 2 && is_indexing(ex.args[1]) + lhs = ex.args[1] + rhs = ex.args[2] + + vname = lhs.args[1] + args = lhs.args[2:end] + quote + $vname = _setindex($vname, $rhs, $(args...)) + end + elseif ex.head in oppheads && length(ex.args) == 2 && is_indexing(ex.args[1]) + hit = findfirst(x -> x == ex.head, oppheads) + rep = opprep[hit] + + lhs = ex.args[1] + rhs = ex.args[2] + + vname = lhs.args[1] + args = lhs.args[2:end] + + quote + $vname = _setindex($vname, $(rep)($lhs, $rhs), $(args...)) + end + else + return Expr(ex.head, parse_ex.(ex.args)...) 
+ end +end + +is_indexing(ex) = false +is_indexing(ex::Expr) = ex.head == :ref diff --git a/src/utility/util.jl b/src/utility/util.jl index e66b059b..3d31c90b 100644 --- a/src/utility/util.jl +++ b/src/utility/util.jl @@ -142,54 +142,6 @@ function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args... return t, _setindex_pullback end -""" - @diffset assign - -Helper macro which allows in-place operations in the forward-pass of Zygote, but -resorts to non-mutating operations in the backwards-pass. The expression `assign` -should assign an object to an pre-existing `AbstractArray` and the use of updating -operators is also possible. This is especially needed when in-place assigning -tensors to unit-cell arrays of environments. -""" -macro diffset(ex) - return esc(parse_ex(ex)) -end -parse_ex(ex) = ex -function parse_ex(ex::Expr) - oppheads = (:(./=), :(.*=), :(.+=), :(.-=)) - opprep = (:(./), :(.*), :(.+), :(.-)) - if ex.head == :macrocall - parse_ex(macroexpand(PEPSKit, ex)) - elseif ex.head in (:(.=), :(=)) && length(ex.args) == 2 && is_indexing(ex.args[1]) - lhs = ex.args[1] - rhs = ex.args[2] - - vname = lhs.args[1] - args = lhs.args[2:end] - quote - $vname = _setindex($vname, $rhs, $(args...)) - end - elseif ex.head in oppheads && length(ex.args) == 2 && is_indexing(ex.args[1]) - hit = findfirst(x -> x == ex.head, oppheads) - rep = opprep[hit] - - lhs = ex.args[1] - rhs = ex.args[2] - - vname = lhs.args[1] - args = lhs.args[2:end] - - quote - $vname = _setindex($vname, $(rep)($lhs, $rhs), $(args...)) - end - else - return Expr(ex.head, parse_ex.(ex.args)...) - end -end - -is_indexing(ex) = false -is_indexing(ex::Expr) = ex.head == :ref - """ @showtypeofgrad(x) @@ -205,3 +157,26 @@ macro showtypeofgrad(x) end ) end + +""" + @fwdthreads(ex) + +Apply `Threads.@threads` only in the forward pass of the program. 
+ +It works by wrapping the for-loop expression in an if statement where in the forward pass +the loop is computed in parallel using `Threads.@threads`, whereas in the backwards pass +the `Threads.@threads` is omitted in order to make the expression differentiable. +""" +macro fwdthreads(ex) + @assert ex.head === :for "@fwdthreads expects a for loop:\n$ex" + + diffable_ex = quote + if Zygote.isderiving() + $ex + else + Threads.@threads $ex + end + end + + return esc(diffable_ex) +end