-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow AD systems to register hooks so they can create new overloads in response to new rules #182
Conversation
it's difficult to see that i.e. since this PR does nothing it's difficult to see why it shouldn't wait til it does something / be part of a PR that actually adds functionality. I'm sure it is going to be useful... but it's pretty weird to add in isolation. |
Getting these annotation into ChainRules.jl, even if they do nothing open us up to solving either of the two proposals. It could wait for the first of those to be added. Psedudocode for how this would be used to do: #127 (comment)
|
@nickrobinson251 so you are declining this PR, and would like one of the proposals that need it to be implemented without this part way step? |
Persoanlly I would prefer us not to add this to the code until we are actually making use of it (i.e. til it does something). I'm happy for this to be it's own PR, that goes in alongside (/immediately before) other PRs that make it do something useful. Basically I just don't want us to have |
44d2e1f
to
f3e1637
Compare
Completely rewritten to actually use the macros now
end | ||
# @show fdef | ||
@eval $fdef | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because this code never uses the AST
I am now questioning if we even need to capture that.
If we don't then that is great because we can get rid of the FRULES
and RRULES
,
and just use methods(frule)
and methods(rrule)
instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We maybe could even get rid of the @frule
and @rrule
macros then, and the magic that detects the latest defined method, and go back to original proposal of using hooks attached to on_package_load
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using hooks attached to
on_package_load
What is this proposal? Is it written down somewhere? If the hooks are not triggered by @frule
/@rrule
macros, how are they triggered? Some Revise-style magic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is mentioned here
#127 (comment)
and yes it is the same hook in Base
that power's Revise:
Base.package_callbacks
it would be triggerd whenever a package is loaded.
Probably we would provide a manual ChainRulesCore.refresh
also for people to use in the REPL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If @eval
will be called by the AD package, then we run into namespace issues. The correct behavior for dispatch-based AD is to overload the AD package's methods in the correct namespace outside the AD package, i.e. in the module where the rule is defined. Is this possible using this hook mechanism?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But what if Base
is replaced with another module and ...
uses code from that other module, and this module is not available in ForwardDiffZero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eval
here evaluates the code in ForwardDiffZero
so it will complain if we use code not loaded in ForwardDiffZero iiuc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But what if Base is replaced with another module and ... uses code from that other module, and this module is not available in ForwardDiffZero?
the ...
is being written by the author of ForwardDiffZero
and i don't think has any reason to run anything other than functions from ForwardDiffZero
, or from ChainRuleCore
(which is loaded by ForwardDiffZero
).
Probably this is a reason not to include the AST as trying to use that for something in ...
will run into this.
I think its just not needed anyway.
The interesting bit is I guess is in the sig, for a Source to Source AD, like Zygote.
that is not generating overloads of op(overloaded_equiv.(args)....)
, but rather pullback(op, args...)
Maybe the types of the args would need the same opname = :($(parentmodule(op)).$(nameof(op)))
type escaping,
though even that wouldn't work because it would be a Symbol
.
But can it be a type directly and thus not need to be given a path in local scope?
Yes, that seems to work
julia> K = Base.Fix2
Base.Fix2
julia> eval(:(foo(x::$K) = x))
foo (generic function with 1 method)
julia> foo(Base.Fix2(+,1))
(::Base.Fix2{typeof(+),Int64}) (generic function with 1 method)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can also do this trick to avoid qualifying names for the operation.
Though it has to be in call overload form for some reason.
(Still i guess that is fine since more generic if also overloading functors)
julia> struct Foo end
julia> K = typeof(Base.:+) # this would come in the sig tuple, just using Base as example
typeof(+)
julia> @eval (::$K)(::Foo, ::Foo) = 2
julia> +(Foo(), Foo())
2
julia> +(2,1)
3
julia> methods(+, (Foo, Foo))
# 1 method for generic function "+":
[1] +(::Foo, ::Foo) in Main at REPL[17]:1
julia> methods(+)
# 167 methods for generic function "+":
[1] +(x::Bool, z::Complex{Bool}) in Base at complex.jl:282
[2] +(x::Bool, y::Bool) in Base at bool.jl:96
[3] +(x::Bool) in Base at bool.jl:93
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am now using this trick in both demos.
Can this be considered resolved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking great so far. No serious concerns. It would be great if @mohamed82008 could comment as he's going to be the first proper consumer of this stuff in ReverseDiff.
end | ||
# @show fdef | ||
@eval $fdef | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using hooks attached to
on_package_load
What is this proposal? Is it written down somewhere? If the hooks are not triggered by @frule
/@rrule
macros, how are they triggered? Some Revise-style magic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎩
test/demos/forwarddiffzero.jl
Outdated
function derv(f, args...) | ||
duals = Dual.(args, one.(args)) | ||
return diff(f(duals...)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
60 LOC for an AD system's pretty good going 👏
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
90 LOC for reverse mode. 😁
f6a7179
to
c632d35
Compare
No more macros. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming this gives us everything we need (for #127), this seems an amazingly efficient use of code. Good work!
src/rules.jl
Outdated
""" | ||
refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) | ||
function refresh_rules(rule_kind) | ||
already_done_world_age = last_refresh(rule_kind)[] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we check here if there are any hooks regisistered, and if not bail out early?
So that if someone has something that uses ChainRulesCore, but doesn't use any overloadinging AD package, then we don't spend the time going though the method table?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a couple of truly minor points. Otherwise, this looks really good.
I particularly like that you make propagate
a closure, so that a given Tracked
knows how to propagate the output of its pullback to its parents. Very clean.
When I wrote Nabla.jl I had a separate bit of control that handled accumulating things to the right places.
The include callback runs before the file is included, so it is not useful to us. I tested and the Package load callback runs after, so it is useful.
I had to remove the hook triggering on As i mentioned above we may like to use
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. A few tiny comments from a quick read through. I'll finish reviewing tomorrow :)
Co-authored-by: Nick Robinson <[email protected]>
Co-authored-by: Nick Robinson <[email protected]>
Co-authored-by: Nick Robinson <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is nearly there. Just a bunch of typos / docs improvements + a suggestion for an extra test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finally got round to reading the docs. Really good work on those too!
Co-authored-by: willtebbutt <[email protected]> Co-authored-by: Nick Robinson <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work.
Docs preview: https://www.juliadiff.org/ChainRulesCore.jl/previews/PR182/
Closes #127
Basically we require the rule definitions to be wrapped in macros then we scope up the AST, then evaluate the AST outselves then fineout what method was just defined.Then we remember what method goes to what AST in a rule list. And we trigger hooks that the AD system can register which will be used to generate a new AST that it will `eval` to define its own rule equivelent overloads. Or it might choose not to, it can use `method.sig` to determine if this is a rule it wants to deal with. When the AD is initially loaded it should trigger its hooks on the whole rule_list.TODO:
This PR adds macros to wrap all defintions of rules. Right now they do nothing at all. We may like to later tag a breaking change for when they are actually doing something and are thus required.But I am introducing them now so we can do things with them later.
in particular #127 just capturing the AST at the time it is created seems like a much simplier way to accomplish this goal.
Related: #44, e.g. this will allow us to setup splatting to work in
frule((partials...), f, args...)
.Once this is merged I will make the follow-ups to ChainRules and ChainRulesTestUtils.