Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChainRules #366

Merged
merged 35 commits into from
May 28, 2020
Merged

Add ChainRules #366

merged 35 commits into from
May 28, 2020

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Oct 7, 2019

This replaces #291

The bits from that OP that still matter

Step 1) Change Zygote to check for chainrules before doing its normal stuff,
and adapt the stuff it gets back from chainrules to play nice with Zygote's expectations

Step 2) adapt Zygote more deeply, so it can take full advantage of thunks etc.

This PR is Step 1.

TODO: workout why this seems to segfault for me.

@oxinabox
Copy link
Member Author

oxinabox commented Oct 7, 2019

@MikeInnes any idea why this segfaults?

@MikeInnes
Copy link
Member

It could be a Zygote compiler issue, but given that you haven't touched that, the second most likely option is that it's just a stackoverflow.

@oxinabox
Copy link
Member Author

oxinabox commented Oct 8, 2019

What makes you think stackoverflow?
I don't see anything in the logs pointing towards that

@MikeInnes
Copy link
Member

IME, Julia tends to segfault due to stack overflows. Doesn't happen every time, but often.

src/compiler/interface.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

It works!

1 of the blackslists can be removed immediately with JuliaDiff/ChainRules.jl#124

src/compiler/interface2.jl Outdated Show resolved Hide resolved
test/chainrules.jl Outdated Show resolved Hide resolved
test/chainrules.jl Outdated Show resolved Hide resolved
@oxinabox oxinabox force-pushed the ox/chainrules_step1b branch from e47440a to 74396b4 Compare October 19, 2019 14:02
@oxinabox oxinabox force-pushed the ox/chainrules_step1b branch from 310ae70 to bc8ac2c Compare October 21, 2019 09:25
@MikeInnes
Copy link
Member

I'm concerned about the performance implications of this, in its current form. In particular we have to be pretty careful to manage the call stack; recursing many times into _pullback will likely trigger recursion limiting heuristics that cause type inference to give up. (This is particularly pernicious because it doesn't show up in @code_typed.)

Also, presumably many of Zygote's current adjoints are now redundant, given this PR. It'd be nice to remove those to show the overall impact on Zygote's codebase (hopefully a significant net improvement).

@oxinabox
Copy link
Member Author

oxinabox commented Oct 21, 2019

I'm concerned about the performance implications of this, in its current form. In particular we have to be pretty careful to manage the call stack; recursing many times into _pullback will likely trigger recursion limiting heuristics that cause type inference to give up. (This is particularly pernicious because it doesn't show up in @code_typed.)

Which bits do you think are adding levels?
We don't recurse into _pullback anywhere in this PR.

How every you count it, adding deeper integation and usiing ChainRules types (rather than zexterning everythiing) would cut that to 1 I think,
and also give benifit of not unthunking everything.

Also, presumably many of Zygote's current adjoints are now redundant, given this PR. It'd be nice to remove those to show the overall impact on Zygote's codebase (hopefully a significant net improvement).

Easiest way to do would be to add convience functions to ChainRulesCore
for importing Zygote Rules's @adjoint defintions,
and then transferiing basically everything over.
Checking as one goes.
Then look at the method overwrite warnings.
JuliaDiff/ChainRulesCore.jl#44

I'ld rather not be doing that alone, and I'ld rather not have that blocking this PR.
I think it can be done in a follow up.
But if we need to use ChainRules types deeply then maybe not.

@MikeInnes
Copy link
Member

Which bits do you think are adding levels? We don't recurse into _pullback anywhere in this PR.

If the user-defined call tree is foo -> bar -> baz, the transformed tree under this PR is _pullback(foo) -> _pullback_sct(foo) -> _pullback(bar) -> _pullback_sct(bar) -> _pullback(baz) -> _pullback_sct(baz) (since _pullback_sct recursively calls _pullback for each call inside the function it transforms).

Previously, it would have been _pullback_sct(foo) -> _pullback_sct(bar) -> _pullback_sct(baz). Because we have a type inference hack that lets us trick the compiler into seeing foo -> bar -> baz again here, this doesn't kill the compiler. Whereas now the callstack has three recursions into the same function _pullback, which can cause issues.

@oxinabox
Copy link
Member Author

Because we have a type inference hack that lets us trick the compiler into seeing foo -> bar -> baz again here, this doesn't kill the compiler.

Where can I learn about this type inference hack?
And is the solution to just inline _pullback_sct back into _pullback by hand?

@MikeInnes
Copy link
Member

I don't think it's documented, but you just set a flag on the codeinfo, as here.

I think manual inlining / having one uber-function for _pullback is probably the way to go (it can still be factored of course, as long as any function calls are compile-time).

@oxinabox
Copy link
Member Author

oxinabox commented Oct 25, 2019

Re inference
I am starting to feel like the easiest way
might be to more down to the Mike’s Little IR level
and also to just do a bunch of edge hooking myself

rather than using the intended if nothing==res=ChainRules.rrule(…_
thing I can just check the method it would hit,
and then check if it is the fallback,
and if not stick a call into it and and attach an edge
Or if it is hitting the fallback then attack an edge to the method table instead

2 problems:

  1. this removes ability to bail out of a chainrule via returning nothing
  2. this doesn’t work in 1.0

So I guess I have to put the whole thing including the decision of if to use a ChainRule or source2source
into a generated function that returns a CodeInfo,
so I can set the
method_for_inference_limit_heuristics flag on that.

But if I do that, don't I end up not having the compile know about what functions are called,
and so new ChainRules won't be picked up, without hand connecting edges,
and so it won't be possible to make it work in 1.0.

Is there a way to set method_for_inference_limit_heuristics flag
without having to use a generated function returning a CodeInfo?

Maybe I could overload https://github.com/JuliaLang/julia/blob/a63f2e9b26751ef2d1522fa2634ee9d56db8528d/base/compiler/utilities.jl#L138|
on _pullback as a plain function.

@oxinabox
Copy link
Member Author

Maybe could go the otherway and make pullback_source2source(ctx, f, args....) set its method_for_inference_limit_heuristics to be pullback(ctx, f, args....)
rather than being set to f(args...) like it currently seems to be?

src/compiler/interface2.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

oxinabox commented Jan 4, 2020

@MikeInnes @dhairyagandhi96 how are we going to resolve this?
Will my plan from above work?
#366 (comment)
Do we need to get compiler people in to a Google meet call to work though things?

@oxinabox oxinabox force-pushed the ox/chainrules_step1b branch from 951dac6 to 7087cbd Compare January 13, 2020 16:32
@oxinabox
Copy link
Member Author

This is much nicer now.
Because some of the things we were blacklisting are gone,
others are fixed,
and I realized we don't need to actually get rid of most of ChainRule's types since they define + in the way accum needs.

oxinabox and others added 22 commits May 28, 2020 10:51
Update docs/src/adjoints.md

Co-Authored-By: Nick Robinson <[email protected]>

fix typo in docs

delete debug printing
linkl top chainrules issue about fastmath
fix typo

Pin IRTools to 0.3.2 because FluxML/IRTools.jl#58
Co-authored-by: Pietro Vertechi <[email protected]>

Update docs/src/adjoints.md
make kwargs work

Update src/compiler/chainrules.jl

Update src/compiler/chainrules.jl

and chainrules kwarg tests
Co-authored-by: Carlo Lucibello <[email protected]>
Decide if is kwfunc at compile time.
…es (#1)

* ChainRules pullbacks always have 1 input JuliaDiff/ChainRulesCore.jl#152

* swap to version of chainrules that don't use multiarg pullbacks

* update tests

* make so don't need custom rule anymore

* add comment

* Update src/compiler/chainrules.jl

Co-authored-by: willtebbutt <[email protected]>

Co-authored-by: willtebbutt <[email protected]>
@oxinabox oxinabox force-pushed the ox/chainrules_step1b branch from 25b9aa6 to 06ad874 Compare May 28, 2020 10:47
@oxinabox oxinabox force-pushed the ox/chainrules_step1b branch from 06ad874 to 41f4c17 Compare May 28, 2020 11:16
@oxinabox
Copy link
Member Author

bors r+

@bors
Copy link
Contributor

bors bot commented May 28, 2020

Build succeeded:

@bors bors bot merged commit 64c02dc into FluxML:master May 28, 2020
@oxinabox oxinabox deleted the ox/chainrules_step1b branch May 28, 2020 15:52
@mcabbott mcabbott mentioned this pull request May 9, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants