Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow AD systems to register hooks so they can create new overloads in response to new rules #182

Merged
merged 43 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
684661a
move using MulAddMacro to right place
oxinabox Jul 7, 2020
90a87ea
Add frule and rrule decorator macros
oxinabox Jul 7, 2020
0ad8235
Update src/rules.jl
oxinabox Jul 10, 2020
8de4ca9
Initial sketch of capturing the AST and feeding it to new rule hooks
oxinabox Jul 17, 2020
4476c4e
sort out API for overload generation
oxinabox Jul 21, 2020
8a1fb9c
add ForwardDiffZero as an API integration test
oxinabox Jul 21, 2020
83d7e8b
Revert "Add frule and rrule decorator macros"
oxinabox Jul 24, 2020
ffeb861
use refresh_rules either manually or autoamtically on pkg load / file…
oxinabox Jul 24, 2020
78fd24f
directly interpolate function type in
oxinabox Jul 24, 2020
0a741db
replace missed opname with op [fixme]
oxinabox Jul 25, 2020
a0129d0
don't handle multi-input
oxinabox Aug 17, 2020
d4efa8e
Add ReverseDiffZero demo
oxinabox Aug 18, 2020
3375791
remove excess new lines
oxinabox Aug 18, 2020
463166d
Update test/demos/reversediffzero.jl
oxinabox Aug 18, 2020
536dac3
Update test/demos/forwarddiffzero.jl
oxinabox Aug 18, 2020
4de2b6b
Update test/demos/reversediffzero.jl
oxinabox Aug 18, 2020
f4ed7c9
Apply suggestions from code review
oxinabox Aug 18, 2020
1872cb1
more comments
oxinabox Aug 18, 2020
37dda58
Apply suggestions from code review
oxinabox Aug 19, 2020
6a04dac
use paritial for all deriviative parts in demos
oxinabox Aug 19, 2020
dd083af
remove debug stuff
oxinabox Aug 19, 2020
562fe72
tweak comments etc
oxinabox Aug 19, 2020
0422e9c
start writing docs for using overload generation
oxinabox Aug 19, 2020
65e8c3d
working on docs
oxinabox Aug 20, 2020
6e58754
finish first pass at docs
oxinabox Aug 20, 2020
faa2087
more docs
oxinabox Aug 20, 2020
a909e2d
handle Unionall Signatures
oxinabox Aug 20, 2020
fb8cdf6
Stop refreshing rules on include_callback
oxinabox Aug 20, 2020
b8d1581
tweaks
oxinabox Aug 20, 2020
27a8592
remove type_constraint_equal
oxinabox Aug 21, 2020
36b6410
Update test/demos/reversediffzero.jl
oxinabox Aug 21, 2020
e87b845
Style and comment fixes
oxinabox Aug 21, 2020
1d91366
Don't export clear_new_rule_hooks!
oxinabox Aug 21, 2020
4ec4981
Update docs/make.jl
oxinabox Aug 21, 2020
baf4431
move comemnt
oxinabox Aug 21, 2020
ada4822
fix dotpoints in docs
oxinabox Aug 21, 2020
cfee703
fix clear rule hooks in tests
oxinabox Aug 21, 2020
3320fba
bump version
oxinabox Aug 21, 2020
fdca95c
Apply suggestions from code review
oxinabox Aug 22, 2020
d97535f
Update docs/src/autodiff/operator_overloading.md
oxinabox Aug 24, 2020
7d32509
More docs on generation
oxinabox Aug 24, 2020
ecd2bb6
test clear hooks
oxinabox Aug 24, 2020
eccb894
wrap up code review
oxinabox Aug 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ version = "0.8.2"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "1c593d1efa27437ed9dd365d1143c594b563e138"
git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.25.1"
version = "0.25.2"

[[DocumenterTools]]
deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"]
Expand Down Expand Up @@ -78,9 +78,9 @@ version = "0.2.2"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "10134f2ee0b1978ae7752c41306e131a684e1f06"
git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.7"
version = "1.0.10"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand All @@ -91,7 +91,7 @@ deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"

[compat]
Documenter = "0.25"
Expand Down
5 changes: 5 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ChainRulesCore
using Documenter
using DocumenterTools: Themes
using Markdown

DocMeta.setdocmeta!(
ChainRulesCore,
Expand Down Expand Up @@ -36,6 +37,10 @@ makedocs(
"Complex Numbers" => "complex.md",
"Deriving Array Rules" => "arrays.md",
"Debug Mode" => "debug_mode.md",
"Usage in an AD" => [
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
"Overview" => "autodiff/overview.md",
"Operator Overloading" => "autodiff/operator_overloading.md"
],
"Design" => [
"Many Differential Types" => "design/many_differentials.md",
],
Expand Down
8 changes: 8 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@ Pages = [
Private = false
```

## Ruleset Loading
```@autodocs
Modules = [ChainRulesCore]
Pages = ["ruleset_loading.jl"]
Private = false
```

## Internal
```@docs
ChainRulesCore.AbstractDifferential
ChainRulesCore.debug_mode
ChainRulesCore.clear_new_rule_hooks!
```
48 changes: 48 additions & 0 deletions docs/src/autodiff/operator_overloading.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Operator Overloading

The principle interface for using the operator overload generation method is [`on_new_rule`](@ref).
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
This function allows one to register a hook to be run every time a new rule is defined.
The hook receives a signature type-type as input, and generally will use `eval` to define
and overload of AD systems overloaded type.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
For example, using the signature type `Tuple{typeof(+), Real, Real}` to define
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
A signature type tuple always has the form:
`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
first positional argument.
One can dispatch on the signature type, to make rules with argument types your AD does not support not call `eval`;
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
or more simply you can just use conditions for this.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
The the hook is automatically triggered whenever a package is loaded.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

`refresh_rules`(@ref) is used to manually trigger the hook function on any new rules.
This is useful for example if new rules are define in the REPL, or if a package defining rules is modified.
(Revise.jl will not automatically trigger).

`clear_new_rule_hooks!`(@ref) clears all registered hooks.
It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function.

## Examples

### ForwardDiffZero
The overload generation hook in this example is: `define_dual_overload`.

````@eval
using Markdown
Markdown.parse("""
```julia
$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String))
```
""")
````

### ReverseDiffZero
The overload generation hook in this example is: `define_tracked_overload`.

````@eval
using Markdown
Markdown.parse("""
```julia
$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String))
```
""")
````

oxinabox marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions docs/src/autodiff/overview.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Using ChainRules in your AutoDiff system
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

This section is for authors of AD systems.
It assumes a pretty solid understanding of Julia, and of automatic differentiation.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
It explains how to make use of ChainRule's rule sets,
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
to avoid having to code all your own AD primitives / custom sensitives.

There are 3 main ways to access ChainRules rule sets in your AutoDiff system.
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved

1. [Operation Overloading Generation](operator_overloading.html)
- This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`.
- A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
- This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.
2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing`
- if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path
- In theory type inference optimizes these branchs out; in practice it may not.
- This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation.
3. Source code transform based on `rrule`/`frule` method-table
- Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path.
- This avoids having branches in your generated code.
- This requires maintaining your own back-edges
- This is pretty hard-code even by the standard of source code tranformations
10 changes: 7 additions & 3 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using MuladdMacro: @muladd

export frule, rrule
export @scalar_rule, @thunk
export canonicalize, extern, unthunk
export on_new_rule, refresh_rules, clear_new_rule_hooks! # generation tools
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
export frule, rrule # core function
export @scalar_rule, @thunk # defination helper macros
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
export canonicalize, extern, unthunk # differential operations
# differentials
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk
export NO_FIELDS

Expand All @@ -20,5 +23,6 @@ include("differential_arithmetic.jl")

include("rules.jl")
include("rule_definition_tools.jl")
include("ruleset_loading.jl")

end # module
2 changes: 0 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# These are some macros (and supporting functions) to make it easier to define rules.
using MuladdMacro: @muladd

"""
@scalar_rule(f(x₁, x₂, ...),
@setup(statement₁, statement₂, ...),
Expand Down
6 changes: 1 addition & 5 deletions src/rules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
#####
##### `frule`/`rrule`
#####

"""
frule((Δf, Δx...), f, x...)

Expand Down Expand Up @@ -95,4 +91,4 @@ true

See also: [`frule`](@ref), [`@scalar_rule`](@ref)
"""
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
138 changes: 138 additions & 0 deletions src/ruleset_loading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Infastructure to support generating overloads from rules.

function __init__()
# Need to refresh rules when a package is loaded
push!(Base.package_callbacks, pkgid -> refresh_rules())
end


oxinabox marked this conversation as resolved.
Show resolved Hide resolved
const NEW_RRULE_HOOKS = Function[]
const NEW_FRULE_HOOKS = Function[]
_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS
_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS

"""
on_new_rule(hook, frule | rrule)

Register a `hook` function to run when new rules are defined.
The hook receives a signature type-type as input, and generally will use `eval` to define
and overload of AD systems overloaded type.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
For example, using the signature type `Tuple{typeof(+), Real, Real}` to define
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
A signature type tuple always has the form:
`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
first positional argument
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

The hooks are automatically run on new rules whenever a package is loaded.
They can be manually triggered by [`refresh_rules`](@ref).
When a hook is first registered with `on_new_rule` it is run on all existing rules.
"""
function on_new_rule(hook_fun, rule_kind)
# apply the hook to the existing rules
ret = map(_rule_list(rule_kind)) do method
sig = _primal_sig(rule_kind, method)
_safe_hook_fun(hook_fun, sig)
end

# register hook for new rules -- so all new rules get this function applied
push!(_hook_list(rule_kind), hook_fun)
return ret
end

"""
clear_new_rule_hooks!(frule|rrule)

Clears all hooks that were registered with corresponding [`on_new_rule`](@ref).
This is useful for while working interactively to define your rule generating hooks.
If you previously wrong an incorrect hook, you can use this to get rid of the old one.

!!! warning
This absolutely should not be used in a package, as it will break any other AD system
using the rule hooks that might happen to be loaded.
"""
clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind))


"""
_rule_list(frule | rrule)

Returns a list of all the methods of the currently defined rules of the given kind.
Excluding the fallback rule that returns `nothing` for every input.
"""
_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__)
# ^ The fallback rules are the only rules defined in ChainRules core so that is how we skip them.
nickrobinson251 marked this conversation as resolved.
Show resolved Hide resolved



const LAST_REFRESH_RRULE = Ref(0)
const LAST_REFRESH_FRULE = Ref(0)
last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE
last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE

"""
refresh_rules()
refresh_rules(frule | rrule)

This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules.
It is *automatically* run when ever a package is loaded.
It can also be manually called to run it directly, for example if a rule was defined
in the REPL or with-in the same file as the AD function.
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
"""
refresh_rules() = (refresh_rules(frule); refresh_rules(rrule))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
function refresh_rules(rule_kind)
already_done_world_age = last_refresh(rule_kind)[]
for method in _rule_list(rule_kind)
_defined_world(method) < already_done_world_age && continue
sig = _primal_sig(rule_kind, method)
_trigger_new_rule_hooks(rule_kind, sig)
end

last_refresh(rule_kind)[] = _current_world()
return nothing
end

@static if VERSION >= v"1.2"
_current_world() = Base.get_world_counter()
_defined_world(method) = method.primary_world
else
_current_world() = ccall(:jl_get_world_counter, UInt, ())
_defined_world(method) = method.min_world
end

"""
_primal_sig(frule|rule, rule_method | rule_sig)

Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`.
"""
_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig)
function _primal_sig(::typeof(frule), rule_sig::DataType)
@assert rule_sig.parameters[1] == typeof(frule)
# need to skip frule and the deriviative info, so starting from the 3rd
return Tuple{rule_sig.parameters[3:end]...}
end
function _primal_sig(::typeof(rrule), rule_sig::DataType)
@assert rule_sig.parameters[1] == typeof(rrule)
# need to skip rrule so starting from the 2rd
return Tuple{rule_sig.parameters[2:end]...}
end
function _primal_sig(rule_kind, rule_sig::UnionAll)
# This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those
# seems not to work
p_sig = _primal_sig(rule_kind, rule_sig.body)
return UnionAll(rule_sig.var, p_sig)
end


function _trigger_new_rule_hooks(rule_kind, sig)
for hook_fun in _hook_list(rule_kind)
_safe_hook_fun(hook_fun, sig)
end
end

function _safe_hook_fun(hook_fun, sig)
try
hook_fun(sig)
catch err
@error "Error triggering hook" hook_fun sig exception=err
end
end
Loading