Skip to content

Commit

Permalink
Parallelize AllSides CTMRG using hacky differentiable @fwdthreads mac…
Browse files Browse the repository at this point in the history
…ro, separate diffset into a new file
  • Loading branch information
pbrehmer committed Jul 11, 2024
1 parent 8e9ca87 commit eaa8506
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 69 deletions.
1 change: 1 addition & 0 deletions src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using ChainRulesCore, Zygote
include("utility/util.jl")
include("utility/svd.jl")
include("utility/rotations.jl")
include("utility/diffset.jl")
include("utility/hook_pullback.jl")
include("utility/autoopt.jl")

Expand Down
43 changes: 22 additions & 21 deletions src/algorithms/ctmrg_all_sides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,67 @@ function ctmrg_iter(state, env::CTMRGEnv, alg::CTMRG{:AllSides})
end

# Compute enlarged corners and edges for all directions and unit cell entries
function enlarge_corners_edges(state, env::CTMRGEnv)
map(Iterators.product(axes(env.corners)...)) do (dir, r, c)
function enlarge_corners_edges(state, env::CTMRGEnv{C,T}) where {C,T}
Qtype = tensormaptype(spacetype(C), 3, 3, storagetype(C))
Q = Zygote.Buffer(Array{Qtype,3}(undef, size(env.corners)))
drc_combinations = collect(Iterators.product(axes(env.corners)...))
@fwdthreads for (dir, r, c) in drc_combinations
rprev = _prev(r, size(state, 1))
rnext = _next(r, size(state, 1))
cprev = _prev(c, size(state, 2))
cnext = _next(c, size(state, 2))
if dir == NORTHWEST
return northwest_corner(
Q[dir, r, c] = if dir == NORTHWEST
northwest_corner(
env.edges[WEST, r, cprev],
env.corners[NORTHWEST, rprev, cprev],
env.edges[NORTH, rprev, c],
state[r, c],
)
elseif dir == NORTHEAST
return northeast_corner(
northeast_corner(
env.edges[NORTH, rprev, c],
env.corners[NORTHEAST, rprev, cnext],
env.edges[EAST, r, cnext],
state[r, c],
)
elseif dir == SOUTHEAST
return southeast_corner(
southeast_corner(
env.edges[EAST, r, cnext],
env.corners[SOUTHEAST, rnext, cnext],
env.edges[SOUTH, rnext, c],
state[r, c],
)
elseif dir == SOUTHWEST
return southwest_corner(
southwest_corner(
env.edges[SOUTH, rnext, c],
env.corners[SOUTHWEST, rnext, cprev],
env.edges[WEST, r, cprev],
state[r, c],
)
end
end

return copy(Q)
end

# Build projectors from SVD and enlarged corners
function build_projectors(Q, env::CTMRGEnv, alg::ProjectorAlg{A,T}) where {A,T}
function build_projectors(Q, env::CTMRGEnv{C,E}, alg::ProjectorAlg{A,T}) where {C,E,A,T}
P_left, P_right = Zygote.Buffer.(projector_type(env.edges))
U, V = Zygote.Buffer.(projector_type(env.edges))
Stype = tensormaptype( # Corner type but with real numbers
spacetype(env.corners[1]),
1,
1,
Matrix{real(scalartype(env.corners[1]))},
)
Stype = tensormaptype(spacetype(C), 1, 1, Matrix{real(scalartype(E))}) # Corner type but with real numbers
S = Zygote.Buffer(Array{Stype,3}(undef, size(env.corners)))
ϵ = 0.0
rsize, csize = size(env.corners)[2:3]
for dir in 1:4, r in 1:rsize, c in 1:csize
drc_combinations = collect(Iterators.product(axes(env.corners)...))
@fwdthreads for (dir, r, c) in drc_combinations
# Row-column index of next enlarged corner
next_rc = if dir == 1
(r, _next(c, csize))
(r, _next(c, size(env.corners, 3)))
elseif dir == 2
(_next(r, rsize), c)
(_next(r, size(env.corners, 2)), c)
elseif dir == 3
(r, _prev(c, csize))
(r, _prev(c, size(env.corners, 3)))
elseif dir == 4
(_prev(r, rsize), c)
(_prev(r, size(env.corners, 2)), c)
end

# SVD half-infinite environment
Expand Down Expand Up @@ -122,7 +122,8 @@ end
function renormalize_corners_edges(state, env::CTMRGEnv, Q, P_left, P_right)
corners::typeof(env.corners) = copy(env.corners)
edges::typeof(env.edges) = copy(env.edges)
for c in 1:size(state, 2), r in 1:size(state, 1)
rc_combinations = collect(Iterators.product(axes(state)...))
@fwdthreads for (r, c) in rc_combinations
rprev = _prev(r, size(state, 1))
rnext = _next(r, size(state, 1))
cprev = _prev(c, size(state, 2))
Expand Down
47 changes: 47 additions & 0 deletions src/utility/diffset.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
@diffset assign
Helper macro which allows in-place operations in the forward-pass of Zygote, but
resorts to non-mutating operations in the backwards-pass. The expression `assign`
should assign an object to an pre-existing `AbstractArray` and the use of updating
operators is also possible. This is especially needed when in-place assigning
tensors to unit-cell arrays of environments.
"""
macro diffset(ex)
return esc(parse_ex(ex))
end
parse_ex(ex) = ex
function parse_ex(ex::Expr)
oppheads = (:(./=), :(.*=), :(.+=), :(.-=))
opprep = (:(./), :(.*), :(.+), :(.-))
if ex.head == :macrocall
parse_ex(macroexpand(PEPSKit, ex))
elseif ex.head in (:(.=), :(=)) && length(ex.args) == 2 && is_indexing(ex.args[1])
lhs = ex.args[1]
rhs = ex.args[2]

vname = lhs.args[1]
args = lhs.args[2:end]
quote
$vname = _setindex($vname, $rhs, $(args...))
end
elseif ex.head in oppheads && length(ex.args) == 2 && is_indexing(ex.args[1])
hit = findfirst(x -> x == ex.head, oppheads)
rep = opprep[hit]

lhs = ex.args[1]
rhs = ex.args[2]

vname = lhs.args[1]
args = lhs.args[2:end]

quote
$vname = _setindex($vname, $(rep)($lhs, $rhs), $(args...))
end
else
return Expr(ex.head, parse_ex.(ex.args)...)
end
end

is_indexing(ex) = false
is_indexing(ex::Expr) = ex.head == :ref
71 changes: 23 additions & 48 deletions src/utility/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,54 +142,6 @@ function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...
return t, _setindex_pullback
end

"""
@diffset assign
Helper macro which allows in-place operations in the forward-pass of Zygote, but
resorts to non-mutating operations in the backwards-pass. The expression `assign`
should assign an object to an pre-existing `AbstractArray` and the use of updating
operators is also possible. This is especially needed when in-place assigning
tensors to unit-cell arrays of environments.
"""
macro diffset(ex)
return esc(parse_ex(ex))
end
parse_ex(ex) = ex
function parse_ex(ex::Expr)
oppheads = (:(./=), :(.*=), :(.+=), :(.-=))
opprep = (:(./), :(.*), :(.+), :(.-))
if ex.head == :macrocall
parse_ex(macroexpand(PEPSKit, ex))
elseif ex.head in (:(.=), :(=)) && length(ex.args) == 2 && is_indexing(ex.args[1])
lhs = ex.args[1]
rhs = ex.args[2]

vname = lhs.args[1]
args = lhs.args[2:end]
quote
$vname = _setindex($vname, $rhs, $(args...))
end
elseif ex.head in oppheads && length(ex.args) == 2 && is_indexing(ex.args[1])
hit = findfirst(x -> x == ex.head, oppheads)
rep = opprep[hit]

lhs = ex.args[1]
rhs = ex.args[2]

vname = lhs.args[1]
args = lhs.args[2:end]

quote
$vname = _setindex($vname, $(rep)($lhs, $rhs), $(args...))
end
else
return Expr(ex.head, parse_ex.(ex.args)...)
end
end

is_indexing(ex) = false
is_indexing(ex::Expr) = ex.head == :ref

"""
@showtypeofgrad(x)
Expand All @@ -205,3 +157,26 @@ macro showtypeofgrad(x)
end
)
end

"""
@fwdthreads(ex)
Apply `Threads.@threads` only in the forward pass of the program.
It works by wrapping the for-loop expression in an if statement where in the forward pass
the loop in computed in parallel using `Threads.@threads`, whereas in the backwards pass
the `Threads.@threads` is omitted in order to make the expression differentiable.
"""
macro fwdthreads(ex)
@assert ex.head === :for "@fwdthreads expects a for loop:\n$ex"

diffable_ex = quote
if Zygote.isderiving()
$ex
else
Threads.@threads $ex
end
end

return esc(diffable_ex)
end

0 comments on commit eaa8506

Please sign in to comment.