Allow AD systems to register hooks so they can create new overloads in response to new rules #182

Merged: 43 commits, Aug 25, 2020

oxinabox (Member) commented Jul 7, 2020

Docs preview: https://www.juliadiff.org/ChainRulesCore.jl/previews/PR182/

Closes #127

Basically, we require rule definitions to be wrapped in macros. We scoop up the AST, evaluate it ourselves, then find out which method was just defined, and record which method corresponds to which AST in a rule list. We then trigger hooks that the AD system can register; a hook can generate a new AST and `eval` it to define the AD system's own rule-equivalent overloads. Or it might choose not to: it can use `method.sig` to determine whether this is a rule it wants to deal with. When the AD system is initially loaded, it should trigger its hooks on the whole rule list.
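A minimal sketch of the register-and-replay idea described above, assuming nothing about the eventual ChainRulesCore API. Every name here (`RULE_LIST`, `HOOKS`, `register_hook!`, `record_rule!`) is hypothetical, and the signatures are simplified to `Tuple{typeof(f), ArgTypes...}` rather than a real `method.sig`:

```julia
# Toy stand-ins; none of these names are ChainRulesCore API.
const RULE_LIST = Any[]          # signatures of the rules seen so far
const HOOKS = Function[]         # callbacks registered by AD systems

# Called by an AD system when it loads: remember the hook, then replay old rules.
function register_hook!(hook)
    push!(HOOKS, hook)
    foreach(hook, RULE_LIST)
    return nothing
end

# Called whenever a new rule is defined: remember it and notify every hook.
function record_rule!(sig)
    push!(RULE_LIST, sig)
    for hook in HOOKS
        hook(sig)
    end
    return nothing
end

# A pretend AD system that only wants rules whose primal function is `sin`.
const seen_by_ad = Any[]
record_rule!(Tuple{typeof(cos), Float64})        # defined before the AD loads
register_hook!() do sig
    sig.parameters[1] === typeof(sin) || return  # not a rule this AD wants
    push!(seen_by_ad, sig)
end
record_rule!(Tuple{typeof(sin), Float64})        # defined after the AD loads
```

The hook is run both on rules that already existed when it was registered and on rules recorded afterwards, and it filters on the signature to decide which ones it cares about.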

TODO:

  • exporting the API for this
  • docstrings
  • documentation
  • integration tests
  • unit tests

This PR adds macros to wrap all definitions of rules. Right now they do nothing at all. We may later want to tag a breaking change for when they actually do something and are thus required.

But I am introducing them now so we can do things with them later.
In particular, for #127, just capturing the AST at the time it is created seems like a much simpler way to accomplish this goal.

Related: #44, e.g. this will allow us to set up splatting to work in frule((partials...), f, args...).

Once this is merged I will make the follow-ups to ChainRules and ChainRulesTestUtils.

nickrobinson251 (Contributor) commented Jul 7, 2020

It's difficult to see that @rrule function rrule(...) is the right pattern for this without seeing what usage you have in mind. Could you put up the other PRs so I can see how this pattern is going to be used?

I.e. since this PR does nothing, it's difficult to see why it shouldn't wait until it does something, or be part of a PR that actually adds functionality. I'm sure it is going to be useful... but it's pretty weird to add in isolation.

oxinabox (Member Author) commented Jul 7, 2020

Getting these annotations into ChainRules.jl, even if they do nothing, opens us up to solving either of the two proposals
without needing to change anything in ChainRules.jl, just ChainRulesCore.

It could wait for the first of those to be added.

Pseudocode for how this would be used to do #127 (comment):

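# Pseudocode: validate, newest_method, rrule_def and trigger_hooks are not defined here.
macro rrule(funcdef_expr)
    validate(funcdef_expr)  # check that this really is an rrule method definition
    return quote
        $(esc(funcdef_expr))  # define the rule in the caller's module

        let
            ast = $(Meta.quot(funcdef_expr))
            _method = newest_method(rrule)  # the method that was just defined
            rrule_def[_method] = ast        # remember which AST defined which method
            trigger_hooks(rrule, ast, _method)
        end
    end
end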

@nickrobinson251 nickrobinson251 added the pending-clear-need We are not certain we need this. So waiting for evidence to be presented label Jul 8, 2020
oxinabox (Member Author) commented

@nickrobinson251 so you are declining this PR, and would like one of the proposals that need it to be implemented without this part way step?

nickrobinson251 (Contributor) commented Jul 10, 2020

are declining this PR

Personally I would prefer us not to add this to the code until we are actually making use of it (i.e. until it does something). I'm happy for this to be its own PR that goes in alongside (or immediately before) other PRs that make it do something useful. Basically I just don't want us to have @frule function frule(....) code in master/docs until it's actually needed. But tbh I don't feel too strongly. If the follow-ups are going in soon, then yes, just merge as is.

@oxinabox oxinabox force-pushed the ox/decoratormacros branch from 44d2e1f to f3e1637 Compare July 17, 2020 22:30
@oxinabox oxinabox changed the title Add @frule and @rrule decorator macros, that are currently identity tranforms WIP: Allow AD systems to register hooks so they can create new overloads in response to new rules Jul 17, 2020
@oxinabox oxinabox requested a review from nickrobinson251 July 17, 2020 22:38
@oxinabox oxinabox dismissed nickrobinson251’s stale review July 17, 2020 22:38

Completely rewritten to actually use the macros now

@oxinabox oxinabox removed the pending-clear-need We are not certain we need this. So waiting for evidence to be presented label Jul 18, 2020
end
# @show fdef
@eval $fdef
end
Member Author commented

Because this code never uses the AST, I am now questioning whether we even need to capture it.
If we don't, that is great, because we can get rid of FRULES and RRULES
and just use methods(frule) and methods(rrule) instead.

oxinabox (Member Author) commented Jul 21, 2020

We could maybe even get rid of the @frule and @rrule macros then, along with the magic that detects the latest defined method, and go back to the original proposal of using hooks attached to on_package_load.

Contributor commented

using hooks attached to on_package_load

What is this proposal? Is it written down somewhere? If the hooks are not triggered by @frule/@rrule macros, how are they triggered? Some Revise-style magic?

Member Author commented

It is mentioned here:
#127 (comment)
and yes, it is the same hook in Base that powers Revise:
Base.package_callbacks

It would be triggered whenever a package is loaded.
We would probably also provide a manual ChainRulesCore.refresh for people to use in the REPL.
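For readers unfamiliar with that hook, here is a minimal sketch of attaching a callback to it. It assumes `Base.package_callbacks` is an internal, undocumented vector of functions that Julia calls with a `Base.PkgId` after a package has loaded; `note_package_loaded` and `LOADED_SINCE_HOOK` are hypothetical names, and a real hook would refresh rules rather than just record the id:

```julia
# Assumption: Base.package_callbacks is internal and may change between Julia versions.
const LOADED_SINCE_HOOK = Base.PkgId[]

function note_package_loaded(pkgid::Base.PkgId)
    push!(LOADED_SINCE_HOOK, pkgid)   # a real hook would trigger a rule refresh here
    return nothing
end

# Register the callback; Julia invokes it after each newly loaded package.
push!(Base.package_callbacks, note_package_loaded)

# Calling it by hand shows what Julia would pass in.
note_package_loaded(Base.PkgId("SomePackage"))
```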

Member commented

If @eval will be called by the AD package, then we run into namespace issues. The correct behavior for dispatch-based AD is to overload the AD package's methods in the correct namespace outside the AD package, i.e. in the module where the rule is defined. Is this possible using this hook mechanism?

Member commented

But what if Base is replaced with another module and ... uses code from that other module, and this module is not available in ForwardDiffZero?

Member commented

@eval here evaluates the code in ForwardDiffZero so it will complain if we use code not loaded in ForwardDiffZero iiuc.

Member Author commented

But what if Base is replaced with another module and ... uses code from that other module, and this module is not available in ForwardDiffZero?

The ... is being written by the author of ForwardDiffZero, and I don't think it has any reason to run anything other than functions from ForwardDiffZero or from ChainRulesCore (which is loaded by ForwardDiffZero).
This is probably a reason not to include the AST, as trying to use that for something in ... will run into this.
I think it's just not needed anyway.

The interesting bit, I guess, is in the sig for a source-to-source AD like Zygote,
which is not generating overloads of op(overloaded_equiv.(args)...), but rather of pullback(op, args...).
Maybe the types of the args would need the same opname = :($(parentmodule(op)).$(nameof(op))) style of escaping,
though even that wouldn't work because it would be a Symbol.
But can it be a type directly, and thus not need to be given a path in local scope?
Yes, that seems to work:

julia> K = Base.Fix2
Base.Fix2

julia> eval(:(foo(x::$K) = x))
foo (generic function with 1 method)

julia> foo(Base.Fix2(+,1))
(::Base.Fix2{typeof(+),Int64}) (generic function with 1 method)

Member Author commented

We can also use this trick to avoid qualifying names for the operation,
though it has to be in call-overload form for some reason.
(Still, I guess that is fine, since it is more generic if we are also overloading functors.)

julia> struct Foo end

julia> K = typeof(Base.:+)  # this would come in the sig tuple, just using Base as example
typeof(+)

julia> @eval (::$K)(::Foo, ::Foo) = 2

julia> +(Foo(), Foo())
2

julia> +(2,1)
3

julia> methods(+, (Foo, Foo))
# 1 method for generic function "+":
[1] +(::Foo, ::Foo) in Main at REPL[17]:1

julia> methods(+)
# 167 methods for generic function "+":
[1] +(x::Bool, z::Complex{Bool}) in Base at complex.jl:282
[2] +(x::Bool, y::Bool) in Base at bool.jl:96
[3] +(x::Bool) in Base at bool.jl:93
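A rough sketch of how a hook might use this trick, under stated assumptions: `ToyDual` and `define_dual_overload` are hypothetical, the derivative is passed in by hand rather than taken from a real frule, and in the PR the function type would come out of the rule's signature tuple:

```julia
struct ToyDual            # toy dual number: a value and its derivative
    value::Float64
    deriv::Float64
end

# Hypothetical hook body: define `op(::ToyDual)` given only the type of `op`.
function define_dual_overload(op_type::Type, derivative)
    op = op_type.instance                     # recover the function from its singleton type
    @eval function (::$op_type)(x::ToyDual)   # call-overload form, no qualified name needed
        return ToyDual($op(x.value), $derivative(x.value) * x.deriv)
    end
end

define_dual_overload(typeof(sin), cos)
y = sin(ToyDual(0.0, 1.0))    # dispatches to the generated method
```

Because the type object itself is interpolated, the generated method never needs to spell out `Base.sin`, so the same hook would work for functions from modules the AD package has not loaded.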

Member Author commented

I am now using this trick in both demos.
Can this be considered resolved?

willtebbutt (Member) left a comment

This is looking great so far. No serious concerns. It would be great if @mohamed82008 could comment as he's going to be the first proper consumer of this stuff in ReverseDiff.


nickrobinson251 (Contributor) left a comment

🎩

function derv(f, args...)
    duals = Dual.(args, one.(args))
    return diff(f(duals...))
end
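To make the snippet above concrete, here is a self-contained toy in the same spirit. It is not the demo file from the PR: the `Dual` layout, the module name `ToyForwardDiff`, and the two hand-written overloads are assumptions, standing in for the overloads the hooks would generate from frules:

```julia
module ToyForwardDiff

struct Dual
    primal::Float64
    diff::Float64
end

diff(d::Dual) = d.diff    # module-local accessor, deliberately not Base.diff

# Hand-written overloads; the PR's hooks would generate these from rules instead.
Base.sin(d::Dual) = Dual(sin(d.primal), cos(d.primal) * d.diff)
Base.:*(a::Dual, b::Dual) =
    Dual(a.primal * b.primal, a.diff * b.primal + a.primal * b.diff)

function derv(f, args...)
    duals = Dual.(args, one.(args))   # seed every input with tangent 1
    return diff(f(duals...))
end

end # module

dsquare = ToyForwardDiff.derv(x -> x * x, 3.0)   # derivative of x^2 at 3
dsin = ToyForwardDiff.derv(sin, 0.0)             # derivative of sin at 0
```

Since every input is seeded with tangent 1, a multi-argument call returns the sum of the partial derivatives rather than a gradient.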
nickrobinson251 (Contributor) commented Jul 22, 2020

60 LOC for an AD system's pretty good going 👏

Member Author commented

90 LOC for reverse mode. 😁

@oxinabox oxinabox force-pushed the ox/decoratormacros branch from f6a7179 to c632d35 Compare July 24, 2020 19:00
oxinabox (Member Author) commented

No more macros.
See new docstrings for what has changed

nickrobinson251 (Contributor) left a comment

Assuming this gives us everything we need (for #127), this seems an amazingly efficient use of code. Good work!

src/rules.jl Outdated
"""
refresh_rules() = (refresh_rules(frule); refresh_rules(rrule))
function refresh_rules(rule_kind)
already_done_world_age = last_refresh(rule_kind)[]
oxinabox (Member Author) commented Jul 25, 2020

Should we check here whether any hooks are registered, and if not, bail out early?
That way, if someone has something that uses ChainRulesCore but doesn't use any overloading AD package, we don't spend time going through the method table.
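A sketch of what that early exit could look like, combined with the world-age bookkeeping in the hunk above. This is not the code that was merged: `REFRESH_HOOKS`, `LAST_REFRESH`, `refresh_sketch` and `toy_rule` are hypothetical, and `method.primary_world` and `Base.get_world_counter` are Julia internals assumed to behave as they do in Julia 1.x:

```julia
const REFRESH_HOOKS = Function[]
const LAST_REFRESH = Ref(zero(UInt))

# Run every hook on each method of `rule_fn` defined since the last refresh.
# Returns the number of methods that were reported.
function refresh_sketch(rule_fn)
    isempty(REFRESH_HOOKS) && return 0      # nobody is listening: skip the scan
    already_done = LAST_REFRESH[]
    LAST_REFRESH[] = Base.get_world_counter()
    n = 0
    for method in methods(rule_fn)
        method.primary_world > already_done || continue   # handled by an earlier refresh
        foreach(hook -> hook(method.sig), REFRESH_HOOKS)
        n += 1
    end
    return n
end

toy_rule(x::Int) = 1                        # stand-in for a rule function with one method
n_without_hooks = refresh_sketch(toy_rule)  # bails out before touching the method table
seen_sigs = Any[]
push!(REFRESH_HOOKS, sig -> push!(seen_sigs, sig))
n_first = refresh_sketch(toy_rule)          # reports the one existing method
n_second = refresh_sketch(toy_rule)         # nothing new since the previous refresh
```

The early return deliberately leaves `LAST_REFRESH` untouched, so a hook registered later still gets to see every rule that was defined while nobody was listening.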

willtebbutt (Member) left a comment

Just a couple of truly minor points. Otherwise, this looks really good.

I particularly like that you make propagate a closure, so that a given Tracked knows how to propagate the output of its pullback to its parents. Very clean.

When I wrote Nabla.jl I had a separate bit of control that handled accumulating things to the right places.

The include callback runs before the file is included, so it is not useful to us.
I tested and the Package load callback runs after, so it is useful.
oxinabox (Member Author) commented

I had to remove the hook triggering on include as that callback runs before the file is included.
I tested and confirmed the package load one runs after.

As I mentioned above, we may like to use Revise.entr.
I think that triggers after Revise has done its reloading.

#182 (comment)

Possibly we should also hook into Revise.entr here.
I think that would be:

Revise.entr(refresh_rules, []; all=true, postpone=true)

Having that would make sure that modifications to rules are picked up and regenerated,
but it would not handle deletions.
We would need to wrap it in Requires.jl.

I think this is best left for a follow-up PR, and an issue should be opened about it.
It might also warrant a chat with Tim Holy, to see if he has bright ideas about how we can handle deletions.

@oxinabox oxinabox requested a review from mohamed82008 August 20, 2020 19:06
@oxinabox oxinabox requested a review from willtebbutt August 20, 2020 19:15
nickrobinson251 (Contributor) left a comment

Looks great. A few tiny comments from a quick read through. I'll finish reviewing tomorrow :)

willtebbutt (Member) left a comment

This is nearly there. Just a bunch of typos / docs improvements + a suggestion for an extra test.

nickrobinson251 (Contributor) left a comment

Finally got round to reading the docs. Really good work on those too!

willtebbutt (Member) left a comment

Nice work.

Successfully merging this pull request may close these issues.

Overload generation mode: support enumerating the rules
5 participants