Implement multi-threading using OhMyThreads and make it differentiable #70

Merged
lkdvos merged 32 commits into master from pb-diffable-threads on Nov 4, 2024

Collaborator

@pbrehmer pbrehmer commented Sep 27, 2024

Here we'll replace the @fwdthreads macro with tmap and foreach calls. Additionally, we will code up reverse rules such that the backwards pass also runs in parallel.
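As a rough illustration of what such a reverse rule could look like, here is a minimal sketch that assumes ChainRulesCore, OhMyThreads and Zygote are available. The wrapper name `dtmap` follows the discussion below; the rule body is an assumption for illustration, not necessarily the code that was merged, and it only covers a function of a single array argument.

```julia
using ChainRulesCore
using OhMyThreads: tmap
using Zygote

# Thin wrapper around `tmap`, so that a rule can be attached to our own function.
dtmap(f, A::AbstractArray; kwargs...) = tmap(f, A; kwargs...)

function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(dtmap), f, A::AbstractArray; kwargs...
)
    # Forward pass: differentiate every element call in parallel and keep its pullback.
    el_rrules = tmap(a -> rrule_via_ad(config, f, a), A; kwargs...)
    y = map(first, el_rrules)

    function dtmap_pullback(dy_raw)
        dys = unthunk(dy_raw)
        # Reverse pass: evaluate the element-wise pullbacks in parallel as well.
        backs = tmap(el_rrules, dys; kwargs...) do el_rrule, dy
            last(el_rrule)(dy)
        end
        df = sum(first, backs)  # accumulated tangent of `f` (NoTangent for plain functions)
        dA = map(last, backs)   # one tangent per element of `A`
        return NoTangent(), df, dA
    end
    return y, dtmap_pullback
end

# Usage: the gradient of sum(sin.(x)) is cos.(x).
x = [0.0, 0.5, 1.0]
g = only(Zygote.gradient(v -> sum(dtmap(sin, v)), x))
```

The sketch does not handle a zero or missing cotangent for the whole output, which a production rule would need to consider.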

@pbrehmer
Collaborator Author

@lkdvos I could use some help on this one... I still find writing rrules very confusing sometimes, so could you maybe take a look at the rrule for dtforeach, once you have time? And while the dtmap rrule already seems to work, it might not be perfect yet :-)

Other than that, we have to think about how to pass along the threading kwargs to the dtmap and dtforeach calls. One option is to store those kwargs inside the CTMRG struct, but that feels a bit wrong. (Also, there are some calls which wouldn't have access to a CTMRG instance.) Not sure what would be the best solution for that. Would global variables that the user can set work for this?

I'll also add some rrule tests at the end.

@lkdvos
Member

lkdvos commented Sep 30, 2024

I think tforeach is going to be a bit hard, because it has no outputs, and is thus necessarily in-place... I have no clue how the Zygote buffer magic works, so I can't really say I know how to deal with that either.
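For reference, this is how `Zygote.Buffer` is typically used to make in-place writes differentiable in a serial loop. It is only a small sketch assuming Zygote is installed; the function name `squares` is made up, and it does not show (or claim) that a buffer can safely be filled from multiple threads.

```julia
using Zygote

# Fill an output array element by element without mutating a plain Array,
# which Zygote cannot differentiate through.
function squares(x)
    buf = Zygote.Buffer(x)      # write-only buffer with the shape and eltype of `x`
    for i in eachindex(x)
        buf[i] = x[i]^2
    end
    return copy(buf)            # `copy` turns the buffer back into a regular array
end

x = [1.0, 2.0, 3.0]
g = only(Zygote.gradient(v -> sum(squares(v)), x))
```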

I should have some more time to think this through next week though!

For the global variables, I would maybe suggest ScopedValues.jl instead. This is a little more flexible and shouldn't incur too much runtime cost.

@pbrehmer
Collaborator Author

pbrehmer commented Oct 1, 2024

> I think tforeach is going to be a bit hard, because it has no outputs, and is thus necessarily in-place... I have no clue how the Zygote buffer magic works, so I can't really say I know how to deal with that either.

How about we just stick to tmap then? The only slightly annoying thing is having to separate multiple return values at different indices but that should not incur too much overhead.
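A small sketch of that pattern, assuming OhMyThreads is installed: each parallel call returns a tuple, and the components are separated afterwards. The function `site_results` is a hypothetical stand-in for a computation that produces two results per index.

```julia
using OhMyThreads: tmap

# Hypothetical per-index computation that yields two results at once.
site_results(i) = (i^2, -i)

results = tmap(site_results, 1:4)   # one tuple per index, computed in parallel
squares = map(first, results)       # first component of every tuple
negated = map(last, results)        # second component of every tuple
```

The extra `map` calls are serial but only copy references, so their cost should be small compared to the threaded work itself.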

> For the global variables, I would maybe suggest ScopedValues.jl instead. This is a little more flexible and shouldn't incur too much runtime cost.

Wasn't aware of ScopedValues.jl yet, that looks like a great solution. But I don't quite understand the necessity of a scoped value here since we never need to access the threading settings inside a multi-threaded map, right? In any case, we can probably just have a global scoped Dict with the threading settings that are passed on to the dtmap calls, and that can be mutated by some set_thread_settings function.

Anyways, I will give these things a go and then we can review next week, when you have time :)

Member

@lkdvos lkdvos left a comment


I think I should have elaborated a bit more on what I had in mind with the scoped values: I would keep the threading strategies in scoped values, such that they can always be accessed as if they were global values. (This prevents them from bloating all of our algorithms.)
The benefit of scoped values, however, is that users can still change them by calling the PEPS function from within a scope with a modified scoped value, thus changing the scheduler.
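To make the idea concrete, here is a minimal sketch using the `ScopedValues` module that ships with Julia 1.11 and later; on older versions the ScopedValues.jl package provides the same interface. The names `THREADING_KWARGS` and `current_ntasks` are hypothetical.

```julia
using Base.ScopedValues  # Julia 1.11+; on older versions use `using ScopedValues` (the package)

# Hypothetical global default for the threading keyword arguments.
const THREADING_KWARGS = ScopedValue((; ntasks=4))

# Any function can read the current settings without receiving them as an argument.
current_ntasks() = THREADING_KWARGS[].ntasks

outside = current_ntasks()
inside = with(THREADING_KWARGS => (; ntasks=1)) do
    current_ntasks()   # sees the overridden value only within this dynamic scope
end
after = current_ntasks()
```

Compared to a mutable global, the override is limited to the `with` block and to tasks spawned inside it, so concurrent callers cannot change each other's settings.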

@pbrehmer
Collaborator Author

I was trying to fix the Zygote error but had no luck. I really don't know how to handle the NoTangents (dA is an Array{NoTangent,3}), as they are being converted to Nothing. I'll copy the error here for future reference:

ERROR: MethodError: no method matching length(::Nothing)
The function `length` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  length(::Zygote.Grads, Any...; kwargs...)
   @ Zygote ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:17
  length(::Combinatorics.Partition)
   @ Combinatorics ~/.julia/packages/Combinatorics/Udg6X/src/youngdiagrams.jl:8
  length(::Base.MethodSpecializations)
   @ Base reflection.jl:1317
  ...

Stacktrace:
  [1] productfunc(xs::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}, dy::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/array.jl:278
  [2] (::Zygote.var"#collect_product_pullback#744"{Base.Iterators.ProductIterator{Tuple{…}}})(dy::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/lib/array.jl:300
  [3] (::Zygote.var"#2941#back#745"{Zygote.var"#collect_product_pullback#744"{…}})(Δ::Array{Nothing, 3})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [4] ctmrg_renormalize
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:383 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
  [6] ctmrg_iter
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:163 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CTMRGEnv{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
  [8] #221
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:118 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [10] #35
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:117 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#ad_pullback#61"{Tuple{…}, Zygote.Pullback{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:264
 [13] with_logger_pullback
    @ ~/.julia/packages/ChainRules/vdf7M/src/rulesets/Base/CoreLogging.jl:12 [inlined]
 [14] (::Zygote.ZBack{ChainRules.var"#with_logger_pullback#862"{…}})(dy::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/chainrules.jl:212
 [15] #withlevel#34
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:113 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [17] withlevel
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:107 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [19] withlevel
    @ ~/.julia/packages/LoggingExtras/cFgEq/src/verbosity.jl:107 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [21] leading_boundary
    @ ~/repos/PEPSKit.jl/src/algorithms/ctmrg/ctmrg.jl:115 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CTMRGEnv{TrivialTensorMap{…}, TrivialTensorMap{…}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [23] #37
    @ ./REPL[24]:2 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:91
 [26] withgradient(f::Function, args::InfinitePEPS{TrivialTensorMap{ComplexSpace, 1, 4, Matrix{ComplexF64}}})
    @ Zygote ~/.julia/packages/Zygote/NRp5C/src/compiler/interface.jl:213
 [27] top-level scope
    @ REPL[24]:1

@lkdvos
Member

lkdvos commented Oct 29, 2024

I'm honestly not so sure how no one has ever run into this, but I seem to have been able to circumvent the issue by just not differentiating through the collect(Iterators.product()) calls. It might have something to do with our tmap gradient returning an array of nothing instead of simply nothing, but I don't really have the time to investigate that and this seems to work.
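One way to exclude such an index computation from differentiation is `ChainRulesCore.@ignore_derivatives`. The following is a sketch assuming ChainRulesCore and Zygote; the function `total` is made up for illustration, and the thread does not state which mechanism the actual fix uses.

```julia
using ChainRulesCore: @ignore_derivatives
using Zygote

function total(A)
    # The index grid carries no derivative information, so keep it out of the AD tape.
    coords = @ignore_derivatives collect(Iterators.product(axes(A)...))
    vals = map(c -> A[c...]^2, coords)
    return sum(vals)
end

A = [1.0 2.0; 3.0 4.0]
g = only(Zygote.gradient(total, A))
```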

@lkdvos lkdvos force-pushed the pb-diffable-threads branch from cbc1fe7 to f3d7192 Compare October 29, 2024 21:57
@lkdvos lkdvos force-pushed the pb-diffable-threads branch from f3d7192 to cab6c69 Compare October 29, 2024 23:12

codecov bot commented Oct 30, 2024

Codecov Report

Attention: Patch coverage is 95.45455% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/operators/infinitepepo.jl | 0.00% | 2 Missing ⚠️ |
| src/PEPSKit.jl | 85.71% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/algorithms/ctmrg/ctmrg.jl | 92.30% <100.00%> (+0.95%) ⬆️ |
| src/algorithms/ctmrg/gaugefix.jl | 94.87% <100.00%> (+0.13%) ⬆️ |
| src/algorithms/toolbox.jl | 96.72% <100.00%> (+0.05%) ⬆️ |
| src/environments/ctmrg_environments.jl | 68.24% <100.00%> (+1.34%) ⬆️ |
| src/states/infinitepeps.jl | 67.77% <100.00%> (+1.49%) ⬆️ |
| src/utility/diffable_threads.jl | 100.00% <100.00%> (ø) |
| src/utility/util.jl | 54.21% <100.00%> (-3.09%) ⬇️ |
| src/PEPSKit.jl | 87.50% <85.71%> (-12.50%) ⬇️ |
| src/operators/infinitepepo.jl | 18.51% <0.00%> (-0.35%) ⬇️ |

@pbrehmer
Collaborator Author

I noticed that the default ntasks is always set to one, since Threads.nthreads() always returns one during precompilation. That's why I defaulted to ntasks=4: on most machines (if Julia is started with the corresponding number of threads) this will multi-thread CTMRG in the four spatial directions. What do you think about that?
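The underlying pitfall is that a top-level `const` in a package is evaluated when the package is precompiled, not when it is loaded. A hedged sketch of two common ways around it, using only Base; the module and variable names are hypothetical and this is an alternative to the fixed `ntasks=4` default, not what the PR does.

```julia
module ThreadDefaults

# In a precompiled package this is evaluated at precompile time, so the value is baked in.
const NTASKS_BAKED_IN = Threads.nthreads()

# Option 1: fill a Ref at load time; `__init__` runs every time the module is loaded.
const NTASKS = Ref(1)
__init__() = (NTASKS[] = Threads.nthreads())

# Option 2: query the thread count lazily on every call.
default_ntasks() = Threads.nthreads()

end # module

ntasks_at_load = ThreadDefaults.NTASKS[]
ntasks_lazy = ThreadDefaults.default_ntasks()
```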

Other than that this seems mergeable, thanks a lot for pinpointing and circumventing the error! It is really weird how this Zygote problem never came up before...

@pbrehmer pbrehmer requested a review from lkdvos November 1, 2024 17:07
@pbrehmer
Collaborator Author

pbrehmer commented Nov 1, 2024

Whoops, somehow the tests have started to fail, which was not the case before the merge. I need to investigate what's happening...

So somehow there are segmentation faults appearing in the Julia LTS tests, probably related to the multi-threaded reverse rule. This doesn't seem to be a problem on the latest Julia channel. What is weird is that even when I run the tests locally on Julia 1.10, they don't fail on my machine. I find this quite confusing.

@lkdvos What do you think is the best way to proceed here? Should we restrict the number of threads for the LTS tests to one?

@lkdvos
Member

lkdvos commented Nov 1, 2024

I'm definitely not a big fan of fixing problems by just turning off the tests, so let's maybe try and see if we can reproduce this anyways? Given that it happens on all platforms, presumably it is a Julia-specific problem, but I would still love to either know what causes it, so we can avoid it, or just fix it in the first place...

I guess the reason it only started turning up now is that we were not testing on Julia 1.10 before, and instead had 1.9 and 1.11.

@pbrehmer
Collaborator Author

pbrehmer commented Nov 1, 2024

> I'm definitely not a big fan of fixing problems by just turning off the tests, so let's maybe try and see if we can reproduce this anyways?

I definitely agree that turning off tests is not the way to go. However, I haven't managed to reproduce it in any way on Ubuntu with Julia 1.10.6 (and also other Julia versions) - the tests just run through without erroring or causing a segmentation fault, regardless of the number of threads.

> Given that it happens on all platforms, presumably it is a Julia-specific problem

Weirdly, the examples run through on the LTS on macOS (whereas they do not on Ubuntu and Windows). So there might also be some interaction with the OS and multi-threading?

Unfortunately, I don't have access to macOS or Windows machines so I don't really know how to approach this or narrow this down. Also, I don't have that much more time to spend on this (and probably neither do you), so that's why I was looking for a quick solution :-) Anyways, if you come up with ideas to try out, let me know and I can revisit this next week.

@pbrehmer
Collaborator Author

pbrehmer commented Nov 4, 2024

Just a quick thought: Could it be that the OhMyThreads multi-threading is interfering with OpenBLAS multi-threading? If the CI runners have only 4 cores in total then choosing JULIA_NUM_THREADS=4 might lead to problems if OpenBLAS also tries to claim more than one thread.
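To test a hypothesis like this, the BLAS thread count can be inspected and pinned from within Julia. A minimal sketch using only the LinearAlgebra standard library; pinning BLAS to one thread is an illustrative choice for such an experiment, not a recommendation made in this thread.

```julia
using LinearAlgebra

julia_threads = Threads.nthreads()   # threads available to Julia tasks
BLAS.set_num_threads(1)              # restrict the BLAS library to a single thread
blas_threads = BLAS.get_num_threads()

println("Julia threads: ", julia_threads, ", BLAS threads: ", blas_threads)
```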

@lkdvos
Member

lkdvos commented Nov 4, 2024

This is unlikely to be the case (but of course, without an actual solution, I can't exclude anything), because Julia's parallelism is task-based: tasks are multiplexed onto the available threads by Julia's scheduler, and the threads themselves are handled by the operating system. In principle there is nothing stopping you from running 20 tasks on a single-threaded setup; the scheduler will switch between these tasks as usual, and everything should still work the same. Of course, only one task will ever be running at a time in that case, while you still pay the overhead of the tasks and context switching, so typically this hinders performance a bit, but it shouldn't break anything.

I'll try to spend some more time investigating this week. I definitely don't feel comfortable releasing something that breaks on the LTS on all platforms, so if we can't figure this out we might have to restrict to v1.11, which I'd rather avoid.
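The point about running more tasks than threads can be checked with Base alone. This sketch spawns 20 tasks regardless of how many threads Julia was started with; the variable names are illustrative.

```julia
# Spawn 20 tasks; with a single thread they simply take turns running.
tasks = [Threads.@spawn(i^2) for i in 1:20]
results = fetch.(tasks)   # wait for every task and collect its result

println("threads: ", Threads.nthreads(), ", tasks completed: ", length(results))
```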

@pbrehmer
Collaborator Author

pbrehmer commented Nov 4, 2024

Oh well, you're right. I just tried to come up with some reason why the tests run through on all my Ubuntu machines on 1.10.6 and why they don't on the CI runners. Also the fact that the macOS examples do run through on 1.10.6 is very weird. This seems to happen deterministically (I reran the CI earlier) but it is somehow platform-specific.

Are the tests running through for you locally on an LTS release? I'm currently stumped on this one.

@lkdvos
Member

lkdvos commented Nov 4, 2024

If this runs through, I'll re-enable the tests for all OSes, and then it is definitely my fault. I think I might have screwed up something with the caching of precompiled data, which might not have included the Julia version in the hash, thus possibly leading to these weird kinds of errors. (Just guessing, still need to check.)
If this is the case, I'll fix and merge?

@pbrehmer
Collaborator Author

pbrehmer commented Nov 4, 2024

Alright, I see no more segmentation faults, thanks for taking a look and fixing it! This looks mergeable now :-)

@lkdvos
Member

lkdvos commented Nov 4, 2024

Let's just wait a second until the tests actually all turn green (it seems they are not running now, for some reason that I don't understand).

@lkdvos lkdvos force-pushed the pb-diffable-threads branch 2 times, most recently from 9cebff2 to a9c0b82 Compare November 4, 2024 16:52
@lkdvos lkdvos force-pushed the pb-diffable-threads branch from a9c0b82 to 3ea4833 Compare November 4, 2024 16:52
@lkdvos lkdvos force-pushed the pb-diffable-threads branch from 46edd69 to 4844ea0 Compare November 4, 2024 17:51
Member

@lkdvos lkdvos left a comment


If everything now turns green, good to go for me

@lkdvos lkdvos enabled auto-merge (squash) November 4, 2024 18:12
@lkdvos lkdvos merged commit 3a9eb40 into master Nov 4, 2024
27 checks passed
@pbrehmer pbrehmer deleted the pb-diffable-threads branch November 4, 2024 19:20
2 participants