From 684661a5734f2042b983686191d98a27770b9e32 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 7 Jul 2020 12:36:16 +0100 Subject: [PATCH 01/43] move using MulAddMacro to right place --- src/ChainRulesCore.jl | 1 + src/rule_definition_tools.jl | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 73b6231d3..3b436814c 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,5 +1,6 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! +using MuladdMacro: @muladd export frule, rrule export @scalar_rule, @thunk diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 83330e23b..2675c7598 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,6 +1,4 @@ # These are some macros (and supporting functions) to make it easier to define rules. -using MuladdMacro: @muladd - """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), From 90a87ea2f45d876af8b2aee0db73a23c0b23371f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 7 Jul 2020 13:54:30 +0100 Subject: [PATCH 02/43] Add frule and rrule decorator macros --- docs/src/index.md | 6 ++--- docs/src/writing_good_rules.md | 4 ++-- src/ChainRulesCore.jl | 2 +- src/differentials/abstract_zero.jl | 2 +- src/rule_definition_tools.jl | 10 ++++---- src/rules.jl | 38 ++++++++++++++++++++++++++++++ test/rules.jl | 12 ++++++---- 7 files changed, 57 insertions(+), 17 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 40991e199..26a832123 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,7 +40,7 @@ The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reve The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written: ```julia -function rrule(::typeof(foo), args...; kwargs...) +@rrule function rrule(::typeof(foo), args...; kwargs...) ... return y, pullback end @@ -55,7 +55,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi The `frule` is written: ```julia -function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) +@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂Y end @@ -218,7 +218,7 @@ end ``` But because it is fused into frule we see it as part of: ```julia -function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) +@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂y end diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 23895f5de..18c8d0d85 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -42,7 +42,7 @@ Use named local functions for the `pullback` in an `rrule`. ```julia # good: -function rrule(::typeof(foo), x) +@rrule function rrule(::typeof(foo), x) Y = foo(x) function foo_pullback(x̄) return NO_FIELDS, bar(x̄) @@ -55,7 +55,7 @@ julia> rrule(foo, 2) ==# # bad: -function rrule(::typeof(foo), x) +@rrule function rrule(::typeof(foo), x) return foo(x), x̄ -> (NO_FIELDS, bar(x̄)) end #== output: diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 3b436814c..5cb93a3a1 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -3,7 +3,7 @@ using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, mate using MuladdMacro: @muladd export frule, rrule -export @scalar_rule, @thunk +export @frule, @rrule, @scalar_rule, @thunk export canonicalize, extern, unthunk export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk export NO_FIELDS diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index 5790323ad..bfb040ea2 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -62,7 +62,7 @@ An optimization package making use of this might want to check for such a case. This mostly shows up as the derivative with respect to dimension, index, or size arguments. ``` - function rrule(fill, x, len::Int) + @rrule function rrule(fill, x, len::Int) y = fill(x, len) fill_pullback(ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist()) return y, fill_pullback diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2675c7598..aee4a3b76 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -10,7 +10,7 @@ A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for `frule` and `rrule`: - function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) + @frule function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ( @@ -20,7 +20,7 @@ methods for `frule` and `rrule`: ) end - function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) + @rrule function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> ( @@ -98,7 +98,7 @@ end _normalize_scalarrules_macro_input(call, maybe_setup, partials) returns (in order) the correctly escaped: - - `call` with out any type constraints + - `call` without any type constraints - `setup_stmts`: the content of `@setup` or `nothing` if that is not provided, - `inputs`: with all args having the constraints removed from call, or defaulting to `Number` @@ -164,7 +164,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) return quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) + @frule function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pushforward_returns @@ -195,7 +195,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) end return quote - function ChainRulesCore.rrule(::typeof($f), $(inputs...)) + @rrule function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pullback diff --git a/src/rules.jl b/src/rules.jl index a55e6d4bf..0b53221ab 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -45,6 +45,8 @@ julia> Δsincosx == (cos(x), -sin(x)) true ``` +When defining overloads they should be wrapped with the [`@frule`](@ref) macro. + See also: [`rrule`](@ref), [`@scalar_rule`](@ref) """ frule(::Any, ::Vararg{Any}; kwargs...) = nothing @@ -93,6 +95,42 @@ julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y))) true ``` +When defining overloads they should be wrapped with the [`@rrule`](@ref) macro. + See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing + +""" + @frule(function ...) + +[`frule`](@ref) defining functions should be decorated with this macro. + +Example: +```julia +@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) + ... + return y, ∂Y +end +``` +""" +macro frule(expr) + return esc(expr) +end + +""" + @rrule(function ...) + +[`rrule`](@ref) defining functions should be decorated with this macro. + +Example: +```julia +@rrule function rrule(::typeof(foo), args...; kwargs...) + ... + return y, pullback +end +``` +""" +macro rrule(expr) + return esc(expr) +end diff --git a/test/rules.jl b/test/rules.jl index fba02b5e2..2440e6030 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -22,13 +22,13 @@ complex_times(x) = (1 + 2im) * x # hard to implement. varargs_function(x...) = sum(x) -function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) +@frule function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) Δx = Base.tail(dargs) return sum(x), sum(Δx) end mixed_vararg(x, y, z...) = x + y + sum(z) -function ChainRulesCore.frule( +@frule function ChainRulesCore.frule( dargs::Tuple{Any, Any, Any, Vararg}, ::typeof(mixed_vararg), x, y, z..., ) @@ -39,7 +39,7 @@ function ChainRulesCore.frule( end type_constraints(x::Int, y::Float64) = x + y -function ChainRulesCore.frule( +@frule function ChainRulesCore.frule( (_, Δx, Δy)::Tuple{Any, Int, Float64}, ::typeof(type_constraints), x::Int, y::Float64, ) @@ -47,7 +47,7 @@ function ChainRulesCore.frule( end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) -function ChainRulesCore.frule( +@frule function ChainRulesCore.frule( dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, ) @@ -57,7 +57,9 @@ function ChainRulesCore.frule( return x + y + sum(z), Δx + Δy + sum(Δz) end -ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) +@frule function ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) + return frule(dargs[2:end], f, x...) +end ####### From 0ad823502fda32ac260c82496250b4d8910f8fa4 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 10 Jul 2020 14:06:40 +0100 Subject: [PATCH 03/43] Update src/rules.jl --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 0b53221ab..15768f915 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -110,7 +110,7 @@ Example: ```julia @frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... - return y, ∂Y + return y, ∂y end ``` """ From 8de4ca97829d3b3955409f73fa93d5937c06c35d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 17 Jul 2020 23:29:51 +0100 Subject: [PATCH 04/43] Initial sketch of capturing the AST and feeding it to new rule hooks --- src/rules.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 15768f915..bcb276fb0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -114,8 +114,11 @@ Example: end ``` """ -macro frule(expr) - return esc(expr) +macro frule(ast) + return quote + $(esc(ast)) + $register_new_rule(frule, $(QuoteNode(ast))) + end end """ @@ -132,5 +135,62 @@ end ``` """ macro rrule(expr) - return esc(expr) + return quote + $(esc(ast)) + $register_new_rule(rrule, $(QuoteNode(ast))) + end end + +const FRULES = Vector{Pair{Method, Expr}}[] +const RRULES = Vector{Pair{Method, Expr}}[] +rule_list(::typeof(rrule)) = RRULES +rule_list(::typeof(frule)) = FRULES + +function register_new_rule(rule_kind, ast) + method = _just_defined_method(rule_kind) + push!(rule_list(rule_kind), method=>ast) + trigger_new_rule_hooks(rule_kind, method, ast) + return nothing +end + + +""" + _just_defined_method(f) + +Finds the method of `f` that was defined in the current world-age. +Errors if not found. +""" +function _just_defined_method(f) + @static if VERSION >= v"1.2" + current_world_age = Base.get_world_counter() + defined_world = :primary_world + else + current_world_age = ccall(:jl_get_world_counter, UInt, ()) + defined_world = :min_world + end + + for m in methods(f) + getproperty(m, defined_world) == current_world_age && return m + end + error("No method of `f` was defined in current world age") +end + + +NEW_RRULE_HOOKS = Function[] +NEW_FRULE_HOOKS = Function[] +hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS +hook_list(::typeof(frule)) = NEW_FRULE_HOOKS + +function trigger_new_rule_hooks(rule_kind, method, ast) + for hook_fun in hook_list(rule_kind) + try + hook_fun(method, ast) + catch err + @warn "Error triggering hooks" hook_fun method ast exception=err + end + end +end + +on_new_rule(hook_fun, rule_kind) = push!(hook_list(rule_kind), hook_fun) +on_new_rrule(hook_fun) = on_new_rule(hook_fun, rrule) +on_new_frule(hook_fun) = on_new_rule(hook_fun, frule) From 4476c4e967af7a89cebe611057243c34b8a40129 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Jul 2020 19:23:07 +0100 Subject: [PATCH 05/43] sort out API for overload generation --- src/ChainRulesCore.jl | 8 +++-- src/rules.jl | 69 +++++++++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5cb93a3a1..428498104 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,9 +2,11 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using MuladdMacro: @muladd -export frule, rrule -export @frule, @rrule, @scalar_rule, @thunk -export canonicalize, extern, unthunk +export on_new_rule # generation tools +export frule, rrule # core function +export @frule, @rrule, @scalar_rule, @thunk # defination helper macros +export canonicalize, extern, unthunk # differnetial operations +# differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk export NO_FIELDS diff --git a/src/rules.jl b/src/rules.jl index bcb276fb0..dc12cb3d8 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -117,7 +117,7 @@ end macro frule(ast) return quote $(esc(ast)) - $register_new_rule(frule, $(QuoteNode(ast))) + $_register_new_rule(frule, $(QuoteNode(ast))) end end @@ -134,25 +134,46 @@ Example: end ``` """ -macro rrule(expr) +macro rrule(ast) return quote $(esc(ast)) - $register_new_rule(rrule, $(QuoteNode(ast))) + $_register_new_rule(rrule, $(QuoteNode(ast))) end end -const FRULES = Vector{Pair{Method, Expr}}[] -const RRULES = Vector{Pair{Method, Expr}}[] -rule_list(::typeof(rrule)) = RRULES -rule_list(::typeof(frule)) = FRULES +""" + _rule_list(frule | rrule) + +Returns a list of all rules currently defined rules of the given kind. +""" +_rule_list +_rule_list(::typeof(rrule)) = RRULES +_rule_list(::typeof(frule)) = FRULES +const SIGT = Type +const FRULES = Vector{Pair{SIGT, Expr}}() +const RRULES = Vector{Pair{SIGT, Expr}}() -function register_new_rule(rule_kind, ast) + +function _register_new_rule(rule_kind, ast) method = _just_defined_method(rule_kind) - push!(rule_list(rule_kind), method=>ast) - trigger_new_rule_hooks(rule_kind, method, ast) + rule_sig::SIGT = method.sig + sig = _primal_sig(rule_kind, rule_sig) + push!(_rule_list(rule_kind), sig=>ast) + _trigger_new_rule_hooks(rule_kind, sig, ast) return nothing end +function _primal_sig(::typeof(frule), rule_sig) + @assert rule_sig.parameters[1] == typeof(frule) + # need to skip frule and the deriviative info, so starting from the 3rd + return Tuple{rule_sig.parameters[3:end]...} +end + +function _primal_sig(::typeof(rrule), rule_sig) + @assert rule_sig.parameters[1] == typeof(rrule) + # need to skip rrule so starting from the 2rd + return Tuple{rule_sig.parameters[2:end]...} +end """ _just_defined_method(f) @@ -178,19 +199,29 @@ end NEW_RRULE_HOOKS = Function[] NEW_FRULE_HOOKS = Function[] -hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS -hook_list(::typeof(frule)) = NEW_FRULE_HOOKS +_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS +_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS -function trigger_new_rule_hooks(rule_kind, method, ast) - for hook_fun in hook_list(rule_kind) +function _trigger_new_rule_hooks(rule_kind, sig, ast) + for hook_fun in _hook_list(rule_kind) try - hook_fun(method, ast) + hook_fun(sig, ast) catch err - @warn "Error triggering hooks" hook_fun method ast exception=err + @warn "Error triggering hooks" hook_fun sig ast exception=err end end end -on_new_rule(hook_fun, rule_kind) = push!(hook_list(rule_kind), hook_fun) -on_new_rrule(hook_fun) = on_new_rule(hook_fun, rrule) -on_new_frule(hook_fun) = on_new_rule(hook_fun, frule) +""" + on_new_rule(((sig, ast)->eval...), frule | rrule) +""" +function on_new_rule(hook_fun, rule_kind) + # get all the existing rules + ret = map(_rule_list(rule_kind)) do (sig, ast) + hook_fun(sig, ast) + end + + # register hook for new rules + push!(_hook_list(rule_kind), hook_fun) + return ret +end From 8a1fb9c4343dbd9c06cea3ec7046869d312715cf Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Jul 2020 19:25:37 +0100 Subject: [PATCH 06/43] add ForwardDiffZero as an API integration test --- test/demos/forwarddiffzero.jl | 76 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++ 2 files changed, 80 insertions(+) create mode 100644 test/demos/forwarddiffzero.jl diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl new file mode 100644 index 000000000..f6441e00a --- /dev/null +++ b/test/demos/forwarddiffzero.jl @@ -0,0 +1,76 @@ +"The simplest viable forward mode a AD, only supports `Float64`" +module ForwardDiffZero + +using ChainRulesCore +using Test + +struct Dual <: Real + primal::Float64 + diff::Float64 +end + +primal(d::Dual) = d.primal +diff(d::Dual) = d.diff + +primal(d::Real) = d +diff(d::Real) = 0.0 + +# needed for ^ to work from having `*` defined +Base.to_power_type(x::Dual) = x + + +function define_dual_overload(sig, ast) + opT, argTs = Iterators.peel(sig.parameters) + fieldcount(opT) == 0 || return # not handling functors + all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. + op = opT.instance + opname = :($(parentmodule(op)).$(nameof(op))) + N = length(sig.parameters) - 1 # skip the op + fdef = quote + function $opname(dual_args::Vararg{Union{Dual, Float64},$N}; kwargs...) + ȧrgs = (NO_FIELDS, diff.(dual_args)...) + args = ($opname, primal.(dual_args)...) + res = frule(ȧrgs, args...; kwargs...) + res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") + y, ẏ = res + return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. + end + end + # @show fdef + @eval $fdef +end + +######################################### +# Initial rule setup +@scalar_rule x + y (One(), One()) +@scalar_rule x - y (One(), -1) + +on_new_rule(frule) do sig, ast + return define_dual_overload(sig, ast) +end + +# add a rule later also +@frule function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) + return (x * y, (Δx * y + x * Δy)) +end + +function derv(f, args...) + duals = Dual.(args, one.(args)) + return diff(f(duals...)) +end + +@testset "ForwardDiffZero" begin + foo(x) = x + x + @test derv(foo, 1.6) == 2 + + bar(x) = x + 2.1 * x + @test derv(bar, 1.2) == 3.1 + + baz(x) = 2.0 * x^2 + 3.0*x + 1.2 + @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0 + + qux(x) = foo(x) + bar(x) + baz(x) + @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2 +end + +end # module \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5d58db5c9..0f0b80bc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,4 +14,8 @@ using Test end include("rules.jl") + + @testset "demos" begin + include("demos/forwarddiffzero.jl") + end end From 83d7e8b439794ae2e9f99fab0699bd699d514be7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 24 Jul 2020 20:07:46 +0100 Subject: [PATCH 07/43] Revert "Add frule and rrule decorator macros" This reverts commit 543aa4d8eb4d3925637993333d031d6a0c5d8c32. --- docs/src/index.md | 6 ++-- docs/src/writing_good_rules.md | 4 +-- src/ChainRulesCore.jl | 2 +- src/differentials/abstract_zero.jl | 2 +- src/rule_definition_tools.jl | 10 +++---- src/rules.jl | 44 ------------------------------ test/rules.jl | 12 ++++---- 7 files changed, 17 insertions(+), 63 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 26a832123..40991e199 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,7 +40,7 @@ The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reve The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written: ```julia -@rrule function rrule(::typeof(foo), args...; kwargs...) +function rrule(::typeof(foo), args...; kwargs...) ... return y, pullback end @@ -55,7 +55,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi The `frule` is written: ```julia -@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂Y end @@ -218,7 +218,7 @@ end ``` But because it is fused into frule we see it as part of: ```julia -@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂y end diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 18c8d0d85..23895f5de 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -42,7 +42,7 @@ Use named local functions for the `pullback` in an `rrule`. ```julia # good: -@rrule function rrule(::typeof(foo), x) +function rrule(::typeof(foo), x) Y = foo(x) function foo_pullback(x̄) return NO_FIELDS, bar(x̄) @@ -55,7 +55,7 @@ julia> rrule(foo, 2) ==# # bad: -@rrule function rrule(::typeof(foo), x) +function rrule(::typeof(foo), x) return foo(x), x̄ -> (NO_FIELDS, bar(x̄)) end #== output: diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 428498104..e2b16669a 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,7 @@ using MuladdMacro: @muladd export on_new_rule # generation tools export frule, rrule # core function -export @frule, @rrule, @scalar_rule, @thunk # defination helper macros +export @scalar_rule, @thunk # defination helper macros export canonicalize, extern, unthunk # differnetial operations # differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index bfb040ea2..5790323ad 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -62,7 +62,7 @@ An optimization package making use of this might want to check for such a case. This mostly shows up as the derivative with respect to dimension, index, or size arguments. ``` - @rrule function rrule(fill, x, len::Int) + function rrule(fill, x, len::Int) y = fill(x, len) fill_pullback(ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist()) return y, fill_pullback diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index aee4a3b76..2675c7598 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -10,7 +10,7 @@ A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for `frule` and `rrule`: - @frule function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) + function ChainRulesCore.frule((NO_FIELDS, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ( @@ -20,7 +20,7 @@ methods for `frule` and `rrule`: ) end - @rrule function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) + function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> ( @@ -98,7 +98,7 @@ end _normalize_scalarrules_macro_input(call, maybe_setup, partials) returns (in order) the correctly escaped: - - `call` without any type constraints + - `call` with out any type constraints - `setup_stmts`: the content of `@setup` or `nothing` if that is not provided, - `inputs`: with all args having the constraints removed from call, or defaulting to `Number` @@ -164,7 +164,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) return quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - @frule function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) + function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pushforward_returns @@ -195,7 +195,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) end return quote - @rrule function ChainRulesCore.rrule(::typeof($f), $(inputs...)) + function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pullback diff --git a/src/rules.jl b/src/rules.jl index dc12cb3d8..6031b7dde 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -45,8 +45,6 @@ julia> Δsincosx == (cos(x), -sin(x)) true ``` -When defining overloads they should be wrapped with the [`@frule`](@ref) macro. - See also: [`rrule`](@ref), [`@scalar_rule`](@ref) """ frule(::Any, ::Vararg{Any}; kwargs...) = nothing @@ -95,52 +93,10 @@ julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y))) true ``` -When defining overloads they should be wrapped with the [`@rrule`](@ref) macro. - See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing -""" - @frule(function ...) - -[`frule`](@ref) defining functions should be decorated with this macro. - -Example: -```julia -@frule function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) - ... - return y, ∂y -end -``` -""" -macro frule(ast) - return quote - $(esc(ast)) - $_register_new_rule(frule, $(QuoteNode(ast))) - end -end - -""" - @rrule(function ...) - -[`rrule`](@ref) defining functions should be decorated with this macro. - -Example: -```julia -@rrule function rrule(::typeof(foo), args...; kwargs...) - ... - return y, pullback -end -``` -""" -macro rrule(ast) - return quote - $(esc(ast)) - $_register_new_rule(rrule, $(QuoteNode(ast))) - end -end - """ _rule_list(frule | rrule) diff --git a/test/rules.jl b/test/rules.jl index 2440e6030..fba02b5e2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -22,13 +22,13 @@ complex_times(x) = (1 + 2im) * x # hard to implement. varargs_function(x...) = sum(x) -@frule function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) +function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) Δx = Base.tail(dargs) return sum(x), sum(Δx) end mixed_vararg(x, y, z...) = x + y + sum(z) -@frule function ChainRulesCore.frule( +function ChainRulesCore.frule( dargs::Tuple{Any, Any, Any, Vararg}, ::typeof(mixed_vararg), x, y, z..., ) @@ -39,7 +39,7 @@ mixed_vararg(x, y, z...) = x + y + sum(z) end type_constraints(x::Int, y::Float64) = x + y -@frule function ChainRulesCore.frule( +function ChainRulesCore.frule( (_, Δx, Δy)::Tuple{Any, Int, Float64}, ::typeof(type_constraints), x::Int, y::Float64, ) @@ -47,7 +47,7 @@ type_constraints(x::Int, y::Float64) = x + y end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) -@frule function ChainRulesCore.frule( +function ChainRulesCore.frule( dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, ) @@ -57,9 +57,7 @@ mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + s return x + y + sum(z), Δx + Δy + sum(Δz) end -@frule function ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) - return frule(dargs[2:end], f, x...) -end +ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) ####### From ffeb861969794c229b8b6a7c889166173f2e75d1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 24 Jul 2020 23:17:00 +0100 Subject: [PATCH 08/43] use refresh_rules either manually or autoamtically on pkg load / file include in order to generate new rules --- src/ChainRulesCore.jl | 2 +- src/rules.jl | 144 ++++++++++++++++++++-------------- test/demos/forwarddiffzero.jl | 15 ++-- 3 files changed, 92 insertions(+), 69 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index e2b16669a..0d1515365 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using MuladdMacro: @muladd -export on_new_rule # generation tools +export on_new_rule, refresh_rules # generation tools export frule, rrule # core function export @scalar_rule, @thunk # defination helper macros export canonicalize, extern, unthunk # differnetial operations diff --git a/src/rules.jl b/src/rules.jl index 6031b7dde..241ae44d6 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -97,87 +97,109 @@ See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing + +####################################################################### +# Infastructure to support generating overloads from rules. + +const NEW_RRULE_HOOKS = Function[] +const NEW_FRULE_HOOKS = Function[] +_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS +_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS + +""" + on_new_rule(sig->eval...), frule | rrule) +""" +function on_new_rule(hook_fun, rule_kind) + # get all the existing rules + ret = map(_rule_list(rule_kind)) do method + sig = _primal_sig(rule_kind, method) + _safe_hook_fun(hook_fun, sig) + end + + # register hook for new rules + push!(_hook_list(rule_kind), hook_fun) + return ret +end + +function __init__() + push!(Base.package_callbacks, pkgid -> refresh_rules()) + push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) +end + + """ _rule_list(frule | rrule) -Returns a list of all rules currently defined rules of the given kind. +Returns a list of all the methods of the currently defined rules of the given kind. +Excluding the fallback rule that returns `nothing` for every input. +""" +_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) +# ^ The fallback rules are the only rules defined in ChainRules core so that is how we skip them. + + + +const LAST_REFRESH_RRULE = Ref(0) +const LAST_REFRESH_FRULE = Ref(0) +last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE +last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE + +""" + refresh_rules(); refresh_rules(frule); refresh_rules(rrule) + +This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. +It is *automatically* run when ever a package is loaded, or a file is `include`d. +It can also be manually called to run it directly, for example if a rule was defined +in the REPL or with-in the same file as the AD function. """ -_rule_list -_rule_list(::typeof(rrule)) = RRULES -_rule_list(::typeof(frule)) = FRULES -const SIGT = Type -const FRULES = Vector{Pair{SIGT, Expr}}() -const RRULES = Vector{Pair{SIGT, Expr}}() - - -function _register_new_rule(rule_kind, ast) - method = _just_defined_method(rule_kind) - rule_sig::SIGT = method.sig - sig = _primal_sig(rule_kind, rule_sig) - push!(_rule_list(rule_kind), sig=>ast) - _trigger_new_rule_hooks(rule_kind, sig, ast) +refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) +function refresh_rules(rule_kind) + already_done_world_age = last_refresh(rule_kind)[] + for method in _rule_list(rule_kind) + _defined_world(method) < already_done_world_age && continue + sig = _primal_sig(rule_kind, method) + _trigger_new_rule_hooks(rule_kind, sig) + end + + last_refresh(rule_kind)[] = _current_world() return nothing end -function _primal_sig(::typeof(frule), rule_sig) +@static if VERSION >= v"1.2" + _current_world() = Base.get_world_counter() + _defined_world(method) = method.primary_world +else + _current_world() = ccall(:jl_get_world_counter, UInt, ()) + _defined_world(method) = method.min_world +end + +""" + _primal_sig(frule|rule, rule_method | rule_sig) + +Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. +""" +_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) +function _primal_sig(::typeof(frule), rule_sig::Type) @assert rule_sig.parameters[1] == typeof(frule) # need to skip frule and the deriviative info, so starting from the 3rd return Tuple{rule_sig.parameters[3:end]...} end -function _primal_sig(::typeof(rrule), rule_sig) +function _primal_sig(::typeof(rrule), rule_sig::Type) @assert rule_sig.parameters[1] == typeof(rrule) # need to skip rrule so starting from the 2rd return Tuple{rule_sig.parameters[2:end]...} end -""" - _just_defined_method(f) - -Finds the method of `f` that was defined in the current world-age. -Errors if not found. -""" -function _just_defined_method(f) - @static if VERSION >= v"1.2" - current_world_age = Base.get_world_counter() - defined_world = :primary_world - else - current_world_age = ccall(:jl_get_world_counter, UInt, ()) - defined_world = :min_world - end - - for m in methods(f) - getproperty(m, defined_world) == current_world_age && return m - end - error("No method of `f` was defined in current world age") -end - - -NEW_RRULE_HOOKS = Function[] -NEW_FRULE_HOOKS = Function[] -_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS -_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS - -function _trigger_new_rule_hooks(rule_kind, sig, ast) +function _trigger_new_rule_hooks(rule_kind, sig) for hook_fun in _hook_list(rule_kind) - try - hook_fun(sig, ast) - catch err - @warn "Error triggering hooks" hook_fun sig ast exception=err - end + _safe_hook_fun(hook_fun, sig) end end -""" - on_new_rule(((sig, ast)->eval...), frule | rrule) -""" -function on_new_rule(hook_fun, rule_kind) - # get all the existing rules - ret = map(_rule_list(rule_kind)) do (sig, ast) - hook_fun(sig, ast) +function _safe_hook_fun(hook_fun, sig) + try + hook_fun(sig) + catch err + @error "Error triggering hook" hook_fun sig exception=err end - - # register hook for new rules - push!(_hook_list(rule_kind), hook_fun) - return ret end diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index f6441e00a..aa7be5d13 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -19,7 +19,7 @@ diff(d::Real) = 0.0 Base.to_power_type(x::Dual) = x -function define_dual_overload(sig, ast) +function define_dual_overload(sig) opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. @@ -36,8 +36,8 @@ function define_dual_overload(sig, ast) return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. end end - # @show fdef - @eval $fdef + #@show fdef + eval(fdef) end ######################################### @@ -45,15 +45,16 @@ end @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) -on_new_rule(frule) do sig, ast - return define_dual_overload(sig, ast) -end +on_new_rule(define_dual_overload, frule) # add a rule later also -@frule function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) return (x * y, (Δx * y + x * Δy)) end +# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call +refresh_rules(); + function derv(f, args...) duals = Dual.(args, one.(args)) return diff(f(duals...)) From 78fd24f63e3227e4d8782cc907f182b9fad9c624 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 24 Jul 2020 23:20:37 +0100 Subject: [PATCH 09/43] directly interpolate function type in --- test/demos/forwarddiffzero.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index aa7be5d13..cadeef289 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -23,11 +23,12 @@ function define_dual_overload(sig) opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. - op = opT.instance - opname = :($(parentmodule(op)).$(nameof(op))) + N = length(sig.parameters) - 1 # skip the op fdef = quote - function $opname(dual_args::Vararg{Union{Dual, Float64},$N}; kwargs...) + # we use the function call overloading form as it lets us avoid namespacing issues + # as we can directly interpolate the function type into to the AST. + function (::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) ȧrgs = (NO_FIELDS, diff.(dual_args)...) args = ($opname, primal.(dual_args)...) res = frule(ȧrgs, args...; kwargs...) From 0a741db15f8682f67006ee76bd423fd1f8fce849 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sat, 25 Jul 2020 10:54:02 +0100 Subject: [PATCH 10/43] replace missed opname with op [fixme] --- test/demos/forwarddiffzero.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index cadeef289..efa1b2efc 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -28,9 +28,9 @@ function define_dual_overload(sig) fdef = quote # we use the function call overloading form as it lets us avoid namespacing issues # as we can directly interpolate the function type into to the AST. - function (::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) + function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) ȧrgs = (NO_FIELDS, diff.(dual_args)...) - args = ($opname, primal.(dual_args)...) + args = (op, primal.(dual_args)...) res = frule(ȧrgs, args...; kwargs...) res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") y, ẏ = res From a0129d0070e300047970bfc83e38f46ea72da323 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 17 Aug 2020 16:50:00 +0100 Subject: [PATCH 11/43] don't handle multi-input --- test/demos/forwarddiffzero.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index efa1b2efc..80bf07ac3 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -56,8 +56,8 @@ end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call refresh_rules(); -function derv(f, args...) - duals = Dual.(args, one.(args)) +function derv(f, arg) + duals = Dual(arg, one(arg)) return diff(f(duals...)) end From d4efa8e6ad93f0a446a878bc60229b8cca8420f0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 16:49:10 +0100 Subject: [PATCH 12/43] Add ReverseDiffZero demo --- test/demos/forwarddiffzero.jl | 6 ++ test/demos/reversediffzero.jl | 133 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 140 insertions(+) create mode 100644 test/demos/reversediffzero.jl diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 80bf07ac3..2154ccd9a 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -73,6 +73,12 @@ end qux(x) = foo(x) + bar(x) + baz(x) @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2 + + function quux(x) + y = 2.0*x + 3.0*x + return 4.0*y + 5.0*y + end + @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) end end # module \ No newline at end of file diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl new file mode 100644 index 000000000..1098e3736 --- /dev/null +++ b/test/demos/reversediffzero.jl @@ -0,0 +1,133 @@ +"The simplest viable reverse mode a AD, only supports `Float64`" +module ReverseDiffZero + +using ChainRulesCore +using Test + +struct Tracked <: Real + propagate::Function + primal::Float64 + tape::Vector{Any} # a reference to a shared tape + grad::Base.RefValue{Float64} # current accumulated gradient +end + +"An intermediate value, a Branch in Nabla terms." +function Tracked(propagate, primal, tape) + v = Tracked(propagate, primal, tape, Ref(zero(primal))) + push!(tape, v) + return v +end + +"An input, a Leaf in Nabla terms. No inputs of its on to propagate to." +function Tracked(primal, tape) + # don't actually need to put these on the tape, since they don't need to + # propagate anything + return Tracked(_->nothing, primal, tape, Ref(zero(primal))) +end + +primal(d::Tracked) = d.primal +primal(d) = d + +grad(d::Tracked) = d.grad[] +grad(d) = nothing + +tape(d::Tracked) = d.tape +tape(d) = nothing + +# we have many inputs grab the tape from the first one that is tracked +get_tape(ds) = something(tape.(ds)...) + +# propagate the currently stored gradient back to my inputs. +propagate!(d) = nothing +propagate!(d::Tracked) = d.propagate(d.grad[]) + +# Accumulate gradient, if the value is being tracked. +accum!(d, x̄) = nothing +accum!(d::Tracked, x̄) = d.grad[] += x̄ + +function define_tracked_overload(sig) + opT, argTs = Iterators.peel(sig.parameters) + fieldcount(opT) == 0 || return # not handling functors + all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. + + N = length(sig.parameters) - 1 # skip the op + fdef = quote + # we use the function call overloading form as it lets us avoid namespacing issues + # as we can directly interpolate the function type into to the AST. + function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...) + args = (op, primal.(tracked_args)...) + res = rrule(args...; kwargs...) + res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") + y, y_pullback = res + t = get_tape(tracked_args) + y_tracked = Tracked(y, t) do ȳ + # pull this gradient back and propagate it to the inputs gradient stores + _, ārgs = Iterators.peel(y_pullback(ȳ)) + accum!.(tracked_args, ārgs) + end + return y_tracked + end + end + eval(fdef) +end + +function derv(f, args::Vararg; kwargs...) + the_tape = Vector{Any}() + tracked_inputs = Tracked.(args, Ref(the_tape)) + tracked_output = f(tracked_inputs...; kwargs...) + @assert tape(tracked_output) === the_tape + + out = primal(tracked_output) + function back(ōut) + accum!(tracked_output, ōut) + for op in reverse(the_tape) + # by going down the tape backwards we know we will + # have accumulated its gradient fully + propagate!(op) + end + return grad.(tracked_inputs) + end + return back(one(out)) +end + +# needed for ^ to work from having `*` defined +Base.to_power_type(x::Tracked) = x + +######################################### +# Initial rule setup +@scalar_rule x + y (One(), One()) +@scalar_rule x - y (One(), -1) + +on_new_rule(define_tracked_overload, rrule) + +# add a rule later also +function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) + function times_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + end + return x * y, times_pullback +end + +# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call +refresh_rules(); + +@testset "ReversedDiffZero" begin + foo(x) = x + x + @test derv(foo, 1.6) == (2.0,) + + bar(x) = x + 2.1 * x + @test derv(bar, 1.2) == (3.1,) + + baz(x) = 2.0 * x^2 + 3.0*x + 1.2 + @test derv(baz, 1.7) == (2*2.0*1.7 + 3.0,) + + qux(x) = foo(x) + bar(x) + baz(x) + @test derv(qux, 1.7) == ((2*2.0*1.7 + 3.0) + 3.1 + 2,) + + function quux(x) + y = 2.0*x + 3.0*x + return 4.0*y + 5.0*y + end + @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) +end +end # module \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0f0b80bc8..90c997e04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,5 +17,6 @@ using Test @testset "demos" begin include("demos/forwarddiffzero.jl") + include("demos/reversediffzero.jl") end end From 3375791d6bd92b4efb436323092344eb9f7001d1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 16:51:37 +0100 Subject: [PATCH 13/43] remove excess new lines --- test/demos/reversediffzero.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 1098e3736..f1db2a81f 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -1,6 +1,5 @@ "The simplest viable reverse mode a AD, only supports `Float64`" module ReverseDiffZero - using ChainRulesCore using Test From 463166d00f5c74680f6109519af79f5515272196 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 18:01:13 +0100 Subject: [PATCH 14/43] Update test/demos/reversediffzero.jl Co-authored-by: mattBrzezinski --- test/demos/reversediffzero.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index f1db2a81f..77fb9e3a0 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -37,8 +37,8 @@ tape(d) = nothing get_tape(ds) = something(tape.(ds)...) # propagate the currently stored gradient back to my inputs. -propagate!(d) = nothing propagate!(d::Tracked) = d.propagate(d.grad[]) +propagate!(d) = nothing # Accumulate gradient, if the value is being tracked. accum!(d, x̄) = nothing @@ -129,4 +129,4 @@ refresh_rules(); end @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) end -end # module \ No newline at end of file +end # module From 536dac39eac0baceb02edbd1e7f0223a4b8625e3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 18:01:28 +0100 Subject: [PATCH 15/43] Update test/demos/forwarddiffzero.jl Co-authored-by: mattBrzezinski --- test/demos/forwarddiffzero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 2154ccd9a..f4bc4bf4d 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -81,4 +81,4 @@ end @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) end -end # module \ No newline at end of file +end # module From 4de2b6b521316bff0d0926139bd0184e18ca78ec Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 18:01:52 +0100 Subject: [PATCH 16/43] Update test/demos/reversediffzero.jl Co-authored-by: mattBrzezinski --- test/demos/reversediffzero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 77fb9e3a0..05f9b12ee 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -41,8 +41,8 @@ propagate!(d::Tracked) = d.propagate(d.grad[]) propagate!(d) = nothing # Accumulate gradient, if the value is being tracked. -accum!(d, x̄) = nothing accum!(d::Tracked, x̄) = d.grad[] += x̄ +accum!(d, x̄) = nothing function define_tracked_overload(sig) opT, argTs = Iterators.peel(sig.parameters) From f4ed7c9630491801db7d4ab8becfec50440e5f2f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 18:24:03 +0100 Subject: [PATCH 17/43] Apply suggestions from code review Co-authored-by: Nick Robinson Co-authored-by: willtebbutt --- src/ChainRulesCore.jl | 2 +- src/rules.jl | 3 ++- test/demos/forwarddiffzero.jl | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 0d1515365..139428ae9 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,7 @@ using MuladdMacro: @muladd export on_new_rule, refresh_rules # generation tools export frule, rrule # core function export @scalar_rule, @thunk # defination helper macros -export canonicalize, extern, unthunk # differnetial operations +export canonicalize, extern, unthunk # differential operations # differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk export NO_FIELDS diff --git a/src/rules.jl b/src/rules.jl index 241ae44d6..9bb603e48 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -144,7 +144,8 @@ last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE """ - refresh_rules(); refresh_rules(frule); refresh_rules(rrule) + refresh_rules() + refresh_rules(frule | rrule) This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. It is *automatically* run when ever a package is loaded, or a file is `include`d. diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index f4bc4bf4d..2ba8d7928 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -43,14 +43,14 @@ end ######################################### # Initial rule setup -@scalar_rule x + y (One(), One()) -@scalar_rule x - y (One(), -1) +@scalar_rule x + y (1, 1) +@scalar_rule x - y (1, -1) on_new_rule(define_dual_overload, frule) # add a rule later also function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) - return (x * y, (Δx * y + x * Δy)) + return (x * y, Δx * y + x * Δy) end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call From 1872cb1769ae63e53a5de1491700bddb2a897dae Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 18 Aug 2020 18:25:48 +0100 Subject: [PATCH 18/43] more comments Co-authored-by: Nick Robinson Co-authored-by: willtebbutt --- src/rules.jl | 1 + test/demos/forwarddiffzero.jl | 2 ++ test/demos/reversediffzero.jl | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 9bb603e48..c8cd3ff59 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -122,6 +122,7 @@ function on_new_rule(hook_fun, rule_kind) end function __init__() + # Need to refresh rules when a module is loaded or a file is `include`d. push!(Base.package_callbacks, pkgid -> refresh_rules()) push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) end diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 2ba8d7928..7e1a406ed 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -4,6 +4,8 @@ module ForwardDiffZero using ChainRulesCore using Test +# Note that we never directly define Dual Number Arithmetic on Dual numbers +# instread it is automatically defined from the `frules` struct Dual <: Real primal::Float64 diff::Float64 diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 05f9b12ee..930f9c521 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -17,7 +17,7 @@ function Tracked(propagate, primal, tape) return v end -"An input, a Leaf in Nabla terms. No inputs of its on to propagate to." +"An input, a Leaf in Nabla terms. No inputs of its own to propagate to." function Tracked(primal, tape) # don't actually need to put these on the tape, since they don't need to # propagate anything From 37dda5868c3db72b6875eaf81818ae9a9c430886 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 19 Aug 2020 14:10:36 +0100 Subject: [PATCH 19/43] Apply suggestions from code review --- test/demos/reversediffzero.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 930f9c521..04ed026aa 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -3,10 +3,10 @@ module ReverseDiffZero using ChainRulesCore using Test -struct Tracked <: Real - propagate::Function +struct Tracked{F} <: Real + propagate::F primal::Float64 - tape::Vector{Any} # a reference to a shared tape + tape::Vector{Tracked} # a reference to a shared tape grad::Base.RefValue{Float64} # current accumulated gradient end @@ -38,7 +38,6 @@ get_tape(ds) = something(tape.(ds)...) # propagate the currently stored gradient back to my inputs. propagate!(d::Tracked) = d.propagate(d.grad[]) -propagate!(d) = nothing # Accumulate gradient, if the value is being tracked. accum!(d::Tracked, x̄) = d.grad[] += x̄ From 6a04dac573427321f946bd1a27a37b94eb6f2a67 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 19 Aug 2020 14:12:25 +0100 Subject: [PATCH 20/43] use paritial for all deriviative parts in demos --- test/demos/forwarddiffzero.jl | 10 +++++----- test/demos/reversediffzero.jl | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 7e1a406ed..af6957a31 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -8,14 +8,14 @@ using Test # instread it is automatically defined from the `frules` struct Dual <: Real primal::Float64 - diff::Float64 + partial::Float64 end primal(d::Dual) = d.primal -diff(d::Dual) = d.diff +partial(d::Dual) = d.partial primal(d::Real) = d -diff(d::Real) = 0.0 +partial(d::Real) = 0.0 # needed for ^ to work from having `*` defined Base.to_power_type(x::Dual) = x @@ -31,7 +31,7 @@ function define_dual_overload(sig) # we use the function call overloading form as it lets us avoid namespacing issues # as we can directly interpolate the function type into to the AST. function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) - ȧrgs = (NO_FIELDS, diff.(dual_args)...) + ȧrgs = (NO_FIELDS, partial.(dual_args)...) args = (op, primal.(dual_args)...) res = frule(ȧrgs, args...; kwargs...) res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") @@ -60,7 +60,7 @@ refresh_rules(); function derv(f, arg) duals = Dual(arg, one(arg)) - return diff(f(duals...)) + return partial(f(duals...)) end @testset "ForwardDiffZero" begin diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 04ed026aa..3d1c95fa8 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -7,7 +7,7 @@ struct Tracked{F} <: Real propagate::F primal::Float64 tape::Vector{Tracked} # a reference to a shared tape - grad::Base.RefValue{Float64} # current accumulated gradient + partial::Base.RefValue{Float64} # current accumulated sensitivity end "An intermediate value, a Branch in Nabla terms." @@ -27,8 +27,8 @@ end primal(d::Tracked) = d.primal primal(d) = d -grad(d::Tracked) = d.grad[] -grad(d) = nothing +partial(d::Tracked) = d.partial[] +partial(d) = nothing tape(d::Tracked) = d.tape tape(d) = nothing @@ -36,11 +36,11 @@ tape(d) = nothing # we have many inputs grab the tape from the first one that is tracked get_tape(ds) = something(tape.(ds)...) -# propagate the currently stored gradient back to my inputs. -propagate!(d::Tracked) = d.propagate(d.grad[]) +# propagate the currently stored partialient back to my inputs. +propagate!(d::Tracked) = d.propagate(d.partial[]) -# Accumulate gradient, if the value is being tracked. -accum!(d::Tracked, x̄) = d.grad[] += x̄ +# Accumulate partialient, if the value is being tracked. +accum!(d::Tracked, x̄) = d.partial[] += x̄ accum!(d, x̄) = nothing function define_tracked_overload(sig) @@ -59,7 +59,7 @@ function define_tracked_overload(sig) y, y_pullback = res t = get_tape(tracked_args) y_tracked = Tracked(y, t) do ȳ - # pull this gradient back and propagate it to the inputs gradient stores + # pull this partialient back and propagate it to the inputs partialient stores _, ārgs = Iterators.peel(y_pullback(ȳ)) accum!.(tracked_args, ārgs) end @@ -80,10 +80,10 @@ function derv(f, args::Vararg; kwargs...) accum!(tracked_output, ōut) for op in reverse(the_tape) # by going down the tape backwards we know we will - # have accumulated its gradient fully + # have accumulated its partialient fully propagate!(op) end - return grad.(tracked_inputs) + return partial.(tracked_inputs) end return back(one(out)) end From dd083afa0d8c62017e550ef2968b92011c94532f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 19 Aug 2020 14:13:36 +0100 Subject: [PATCH 21/43] remove debug stuff --- test/demos/forwarddiffzero.jl | 5 +---- test/demos/reversediffzero.jl | 8 +++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index af6957a31..02ef41c16 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -33,13 +33,10 @@ function define_dual_overload(sig) function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) ȧrgs = (NO_FIELDS, partial.(dual_args)...) args = (op, primal.(dual_args)...) - res = frule(ȧrgs, args...; kwargs...) - res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") - y, ẏ = res + y, ẏ = frule(ȧrgs, args...; kwargs...) return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. end end - #@show fdef eval(fdef) end diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 3d1c95fa8..b3626bf0f 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -54,11 +54,9 @@ function define_tracked_overload(sig) # as we can directly interpolate the function type into to the AST. function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...) args = (op, primal.(tracked_args)...) - res = rrule(args...; kwargs...) - res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") - y, y_pullback = res - t = get_tape(tracked_args) - y_tracked = Tracked(y, t) do ȳ + y, y_pullback = rrule(args...; kwargs...) + the_tape = get_tape(tracked_args) + y_tracked = Tracked(y, the_tape) do ȳ # pull this partialient back and propagate it to the inputs partialient stores _, ārgs = Iterators.peel(y_pullback(ȳ)) accum!.(tracked_args, ārgs) From 562fe725b46db2190b7bc34e7a2546faf813b455 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 19 Aug 2020 14:44:16 +0100 Subject: [PATCH 22/43] tweak comments etc --- test/demos/forwarddiffzero.jl | 3 ++- test/demos/reversediffzero.jl | 23 ++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 02ef41c16..50d543c0a 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -5,7 +5,7 @@ using ChainRulesCore using Test # Note that we never directly define Dual Number Arithmetic on Dual numbers -# instread it is automatically defined from the `frules` +# instead it is automatically defined from the `frules` struct Dual <: Real primal::Float64 partial::Float64 @@ -55,6 +55,7 @@ end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call refresh_rules(); +"Do a calculus. `f` should have a single input." function derv(f, arg) duals = Dual(arg, one(arg)) return partial(f(duals...)) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index b3626bf0f..d4aeb36ef 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -17,11 +17,13 @@ function Tracked(propagate, primal, tape) return v end +"Maker for inputs (leaves) that don't need to propagate." +struct NoPropagate end + "An input, a Leaf in Nabla terms. No inputs of its own to propagate to." function Tracked(primal, tape) - # don't actually need to put these on the tape, since they don't need to - # propagate anything - return Tracked(_->nothing, primal, tape, Ref(zero(primal))) + # don't actually need to put these on the tape, since they don't need to propagate + return Tracked(NoPropagate(), primal, tape, Ref(zero(primal))) end primal(d::Tracked) = d.primal @@ -33,16 +35,17 @@ partial(d) = nothing tape(d::Tracked) = d.tape tape(d) = nothing -# we have many inputs grab the tape from the first one that is tracked +"we have many inputs grab the tape from the first one that is tracked" get_tape(ds) = something(tape.(ds)...) -# propagate the currently stored partialient back to my inputs. +"propagate the currently stored partialient back to my inputs." propagate!(d::Tracked) = d.propagate(d.partial[]) -# Accumulate partialient, if the value is being tracked. +"Accumulate the sensitivity, if the value is being tracked." accum!(d::Tracked, x̄) = d.partial[] += x̄ accum!(d, x̄) = nothing +"What to do when a new rrule is declared" function define_tracked_overload(sig) opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors @@ -67,8 +70,9 @@ function define_tracked_overload(sig) eval(fdef) end +"Do a calculus. `f` should have a single output." function derv(f, args::Vararg; kwargs...) - the_tape = Vector{Any}() + the_tape = Vector{Tracked}() tracked_inputs = Tracked.(args, Ref(the_tape)) tracked_output = f(tracked_inputs...; kwargs...) @assert tape(tracked_output) === the_tape @@ -76,9 +80,9 @@ function derv(f, args::Vararg; kwargs...) out = primal(tracked_output) function back(ōut) accum!(tracked_output, ōut) + # by going down the tape backwards we know we will have fully accumulated partials + # before propagating them onwards for op in reverse(the_tape) - # by going down the tape backwards we know we will - # have accumulated its partialient fully propagate!(op) end return partial.(tracked_inputs) @@ -127,3 +131,4 @@ refresh_rules(); @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) end end # module + From 0422e9c8f9b6d525277344829b7cacfadedffa29 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 19 Aug 2020 19:02:46 +0100 Subject: [PATCH 23/43] start writing docs for using overload generation --- docs/Manifest.toml | 2 +- docs/Project.toml | 1 + docs/make.jl | 5 ++++ docs/src/autodiff/operator_overloading.md | 30 +++++++++++++++++++++++ docs/src/autodiff/overview.md | 22 +++++++++++++++++ test/demos/forwarddiffzero.jl | 1 - 6 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 docs/src/autodiff/operator_overloading.md create mode 100644 docs/src/autodiff/overview.md diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c58a79af6..1eb32c944 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -91,7 +91,7 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] diff --git a/docs/Project.toml b/docs/Project.toml index 66770708a..a39e1b1da 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,6 +2,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" [compat] Documenter = "0.25" diff --git a/docs/make.jl b/docs/make.jl index 8d1012ed1..7abed7a56 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,7 @@ using ChainRulesCore using Documenter using DocumenterTools: Themes +using Markdown DocMeta.setdocmeta!( ChainRulesCore, @@ -39,6 +40,10 @@ makedocs( "Design" => [ "Many Differential Types" => "design/many_differentials.md", ], + "Usage in an AD" => [ + "Overview" => "autodiff/overview.md", + "Operator Overloading" => "autodiff/operator_overloading.md" + ], "API" => "api.md", ], strict=true, diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md new file mode 100644 index 000000000..f4a01b20e --- /dev/null +++ b/docs/src/autodiff/operator_overloading.md @@ -0,0 +1,30 @@ +# Operator Overloading + +## Examples + +### ForwardDiffZero + +````@eval +using Markdown +code = read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String) +code = replace(code, raw"$" => raw"\$") +Markdown.parse(""" +```julia +$(code) +``` +""") +```` + +### ReverseDiffZero + +````@eval +using Markdown +code = read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String) +code = replace(code, raw"$" => raw"\$") +Markdown.parse(""" +```julia +$(code) +``` +""") +```` + diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md new file mode 100644 index 000000000..eaa78be75 --- /dev/null +++ b/docs/src/autodiff/overview.md @@ -0,0 +1,22 @@ +# Using ChainRules in your AutoDiff system + +This section is for authors of AD systems. +It assumes a pretty solid understanding of Julia, and of Automatic Differentiation. +It explains how to make use of ChainRule's rule-sets, +to avoid having to code all your own AD primitives / custom sensitives. + +There are 3 main ways to use ChainRules in your AutoDiff system. + +1. [Operation Overloading Generation](autodiff/operator_overloading) + - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. + - A source code generation based AD can also use this by overloading their tranform generating function directly so as not to generate a tranform but to just return a result. + - This does not play nice with Revise.jl, modifying or especially deleting rules may not be reflected. +2. Source code transform based on `rrule`/`frule` method-table + - Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path. + - This avoids having branches in your generated code. + - This requires maintaining your own back-edges + - This is pretty hard-code even by the standard of source code tranformations +3. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` + - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path + - In theory type inference optimizes these branchs out; in practice it may not. + - This is a fairly simple Cassette overdub of all calls, and is suitable for overloading based AD or source code transformation. diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 50d543c0a..3bc6c86d7 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -1,6 +1,5 @@ "The simplest viable forward mode a AD, only supports `Float64`" module ForwardDiffZero - using ChainRulesCore using Test From 65e8c3dbb6f97f8c9b3029300f1085483e900b29 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 13:17:21 +0100 Subject: [PATCH 24/43] working on docs --- docs/src/autodiff/operator_overloading.md | 8 ++--- docs/src/autodiff/overview.md | 20 +++++------ test/demos/forwarddiffzero.jl | 30 +++++++++------- test/demos/reversediffzero.jl | 42 +++++++++++++---------- 4 files changed, 54 insertions(+), 46 deletions(-) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index f4a01b20e..8506f3f69 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -6,11 +6,9 @@ ````@eval using Markdown -code = read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String) -code = replace(code, raw"$" => raw"\$") Markdown.parse(""" ```julia -$(code) +$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) ``` """) ```` @@ -19,11 +17,9 @@ $(code) ````@eval using Markdown -code = read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String) -code = replace(code, raw"$" => raw"\$") Markdown.parse(""" ```julia -$(code) +$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) ``` """) ```` diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md index eaa78be75..6acbfe35f 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/autodiff/overview.md @@ -1,22 +1,22 @@ # Using ChainRules in your AutoDiff system This section is for authors of AD systems. -It assumes a pretty solid understanding of Julia, and of Automatic Differentiation. -It explains how to make use of ChainRule's rule-sets, +It assumes a pretty solid understanding of Julia, and of automatic differentiation. +It explains how to make use of ChainRule's rule sets, to avoid having to code all your own AD primitives / custom sensitives. -There are 3 main ways to use ChainRules in your AutoDiff system. +There are 3 main ways to access ChainRules rule sets in your AutoDiff system. 1. [Operation Overloading Generation](autodiff/operator_overloading) - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. - - A source code generation based AD can also use this by overloading their tranform generating function directly so as not to generate a tranform but to just return a result. - - This does not play nice with Revise.jl, modifying or especially deleting rules may not be reflected. -2. Source code transform based on `rrule`/`frule` method-table + - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. + - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. +2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` + - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path + - In theory type inference optimizes these branchs out; in practice it may not. + - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. +3. Source code transform based on `rrule`/`frule` method-table - Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path. - This avoids having branches in your generated code. - This requires maintaining your own back-edges - This is pretty hard-code even by the standard of source code tranformations -3. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` - - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path - - In theory type inference optimizes these branchs out; in practice it may not. - - This is a fairly simple Cassette overdub of all calls, and is suitable for overloading based AD or source code transformation. diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 3bc6c86d7..2f5dfc690 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -3,6 +3,13 @@ module ForwardDiffZero using ChainRulesCore using Test +######################################### +# Initial rule setup +@scalar_rule x + y (1, 1) +@scalar_rule x - y (1, -1) +########################## +# Define the AD + # Note that we never directly define Dual Number Arithmetic on Dual numbers # instead it is automatically defined from the `frules` struct Dual <: Real @@ -39,13 +46,18 @@ function define_dual_overload(sig) eval(fdef) end -######################################### -# Initial rule setup -@scalar_rule x + y (1, 1) -@scalar_rule x - y (1, -1) - +# !Important!: Attach the define function to the `on_new_rule` hook on_new_rule(define_dual_overload, frule) +"Do a calculus. `f` should have a single input." +function derv(f, arg) + duals = Dual(arg, one(arg)) + return partial(f(duals...)) +end + +# End AD definition +################################ + # add a rule later also function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) return (x * y, Δx * y + x * Δy) @@ -54,12 +66,6 @@ end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call refresh_rules(); -"Do a calculus. `f` should have a single input." -function derv(f, arg) - duals = Dual(arg, one(arg)) - return partial(f(duals...)) -end - @testset "ForwardDiffZero" begin foo(x) = x + x @test derv(foo, 1.6) == 2 @@ -80,4 +86,4 @@ end @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) end -end # module +end # module \ No newline at end of file diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index d4aeb36ef..31a4737b3 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -3,6 +3,13 @@ module ReverseDiffZero using ChainRulesCore using Test +######################################### +# Initial rule setup +@scalar_rule x + y (1, 1) +@scalar_rule x - y (1, -1) +########################## +#Define the AD + struct Tracked{F} <: Real propagate::F primal::Float64 @@ -45,6 +52,9 @@ propagate!(d::Tracked) = d.propagate(d.partial[]) accum!(d::Tracked, x̄) = d.partial[] += x̄ accum!(d, x̄) = nothing +# needed for ^ to work from having `*` defined +Base.to_power_type(x::Tracked) = x + "What to do when a new rrule is declared" function define_tracked_overload(sig) opT, argTs = Iterators.peel(sig.parameters) @@ -70,6 +80,9 @@ function define_tracked_overload(sig) eval(fdef) end +# !Important!: Attach the define function to the `on_new_rule` hook +on_new_rule(define_tracked_overload, rrule) + "Do a calculus. `f` should have a single output." function derv(f, args::Vararg; kwargs...) the_tape = Vector{Tracked}() @@ -77,32 +90,25 @@ function derv(f, args::Vararg; kwargs...) tracked_output = f(tracked_inputs...; kwargs...) @assert tape(tracked_output) === the_tape + # Now the backward pass out = primal(tracked_output) - function back(ōut) - accum!(tracked_output, ōut) - # by going down the tape backwards we know we will have fully accumulated partials - # before propagating them onwards - for op in reverse(the_tape) - propagate!(op) - end - return partial.(tracked_inputs) + ōut = one(out) + accum!(tracked_output, ōut) + # By going down the tape backwards we know we will have fully accumulated partials + # before propagating them onwards + for op in reverse(the_tape) + propagate!(op) end - return back(one(out)) + return partial.(tracked_inputs) end -# needed for ^ to work from having `*` defined -Base.to_power_type(x::Tracked) = x - -######################################### -# Initial rule setup -@scalar_rule x + y (One(), One()) -@scalar_rule x - y (One(), -1) - -on_new_rule(define_tracked_overload, rrule) +# End AD definition +################################ # add a rule later also function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) + # we will use thunks here to show we handle them fine. return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) end return x * y, times_pullback From 6e587543ed0736e03ecca7b04b56adbbc83f5936 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 14:50:35 +0100 Subject: [PATCH 25/43] finish first pass at docs --- docs/Manifest.toml | 8 +- docs/src/api.md | 8 ++ docs/src/autodiff/operator_overloading.md | 14 +++ docs/src/autodiff/overview.md | 2 +- src/rules.jl | 115 +------------------ src/ruleset_loading.jl | 132 ++++++++++++++++++++++ 6 files changed, 160 insertions(+), 119 deletions(-) create mode 100644 src/ruleset_loading.jl diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 1eb32c944..da934f718 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -31,9 +31,9 @@ version = "0.8.2" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "1c593d1efa27437ed9dd365d1143c594b563e138" +git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.25.1" +version = "0.25.2" [[DocumenterTools]] deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"] @@ -78,9 +78,9 @@ version = "0.2.2" [[Parsers]] deps = ["Dates", "Test"] -git-tree-sha1 = "10134f2ee0b1978ae7752c41306e131a684e1f06" +git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.7" +version = "1.0.10" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] diff --git a/docs/src/api.md b/docs/src/api.md index 0aa9d44d2..ad34569e9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -27,8 +27,16 @@ Pages = [ Private = false ``` +## Ruleset Loading +```@autodocs +Modules = [ChainRulesCore] +Pages = ["ruleset_loading.jl"] +Private = false +``` + ## Internal ```@docs ChainRulesCore.AbstractDifferential ChainRulesCore.debug_mode +ChainRulesCore.clear_new_rule_hooks! ``` diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index 8506f3f69..1149800ff 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -1,8 +1,21 @@ # Operator Overloading +The principle interface for using the operator overload generation method is [`on_new_rule`](@ref). +This function allows one to register a hook to be run every time a new rule is defined. +The hook receives a signature type-type as input, and generally will use `eval` to define +and overload of AD systems overloaded type. +For example, using the signature type `Tuple{typeof(+), Real, Real}` to define +`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`. +A signature type tuple always has the form: +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the +first positional argument. +One can dispatch on the signature type, to make rules with argument types your AD does not support not call `eval`; +or more simply you can just use conditions for this. + ## Examples ### ForwardDiffZero +The overload generation hook in this example is: `define_dual_overload`. ````@eval using Markdown @@ -14,6 +27,7 @@ $(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) ```` ### ReverseDiffZero +The overload generation hook in this example is: `define_tracked_overload`. ````@eval using Markdown diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md index 6acbfe35f..ce1a97f32 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/autodiff/overview.md @@ -7,7 +7,7 @@ to avoid having to code all your own AD primitives / custom sensitives. There are 3 main ways to access ChainRules rule sets in your AutoDiff system. -1. [Operation Overloading Generation](autodiff/operator_overloading) +1. [Operation Overloading Generation](operator_overloading.html) - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. diff --git a/src/rules.jl b/src/rules.jl index c8cd3ff59..b83cda45a 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,7 +1,3 @@ -##### -##### `frule`/`rrule` -##### - """ frule((Δf, Δx...), f, x...) @@ -95,113 +91,4 @@ true See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ -rrule(::Any, ::Vararg{Any}; kwargs...) = nothing - - -####################################################################### -# Infastructure to support generating overloads from rules. - -const NEW_RRULE_HOOKS = Function[] -const NEW_FRULE_HOOKS = Function[] -_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS -_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS - -""" - on_new_rule(sig->eval...), frule | rrule) -""" -function on_new_rule(hook_fun, rule_kind) - # get all the existing rules - ret = map(_rule_list(rule_kind)) do method - sig = _primal_sig(rule_kind, method) - _safe_hook_fun(hook_fun, sig) - end - - # register hook for new rules - push!(_hook_list(rule_kind), hook_fun) - return ret -end - -function __init__() - # Need to refresh rules when a module is loaded or a file is `include`d. - push!(Base.package_callbacks, pkgid -> refresh_rules()) - push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) -end - - -""" - _rule_list(frule | rrule) - -Returns a list of all the methods of the currently defined rules of the given kind. -Excluding the fallback rule that returns `nothing` for every input. -""" -_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) -# ^ The fallback rules are the only rules defined in ChainRules core so that is how we skip them. - - - -const LAST_REFRESH_RRULE = Ref(0) -const LAST_REFRESH_FRULE = Ref(0) -last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE -last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE - -""" - refresh_rules() - refresh_rules(frule | rrule) - -This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. -It is *automatically* run when ever a package is loaded, or a file is `include`d. -It can also be manually called to run it directly, for example if a rule was defined -in the REPL or with-in the same file as the AD function. -""" -refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) -function refresh_rules(rule_kind) - already_done_world_age = last_refresh(rule_kind)[] - for method in _rule_list(rule_kind) - _defined_world(method) < already_done_world_age && continue - sig = _primal_sig(rule_kind, method) - _trigger_new_rule_hooks(rule_kind, sig) - end - - last_refresh(rule_kind)[] = _current_world() - return nothing -end - -@static if VERSION >= v"1.2" - _current_world() = Base.get_world_counter() - _defined_world(method) = method.primary_world -else - _current_world() = ccall(:jl_get_world_counter, UInt, ()) - _defined_world(method) = method.min_world -end - -""" - _primal_sig(frule|rule, rule_method | rule_sig) - -Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. -""" -_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) -function _primal_sig(::typeof(frule), rule_sig::Type) - @assert rule_sig.parameters[1] == typeof(frule) - # need to skip frule and the deriviative info, so starting from the 3rd - return Tuple{rule_sig.parameters[3:end]...} -end - -function _primal_sig(::typeof(rrule), rule_sig::Type) - @assert rule_sig.parameters[1] == typeof(rrule) - # need to skip rrule so starting from the 2rd - return Tuple{rule_sig.parameters[2:end]...} -end - -function _trigger_new_rule_hooks(rule_kind, sig) - for hook_fun in _hook_list(rule_kind) - _safe_hook_fun(hook_fun, sig) - end -end - -function _safe_hook_fun(hook_fun, sig) - try - hook_fun(sig) - catch err - @error "Error triggering hook" hook_fun sig exception=err - end -end +rrule(::Any, ::Vararg{Any}; kwargs...) = nothing \ No newline at end of file diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl new file mode 100644 index 000000000..cfd280c41 --- /dev/null +++ b/src/ruleset_loading.jl @@ -0,0 +1,132 @@ +# Infastructure to support generating overloads from rules. + +const NEW_RRULE_HOOKS = Function[] +const NEW_FRULE_HOOKS = Function[] +_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS +_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS + +""" + on_new_rule(hook, frule | rrule) + +Register a `hook` function to run when new rules are defined. +The hook receives a signature type-type as input, and generally will use `eval` to define +and overload of AD systems overloaded type. +For example, using the signature type `Tuple{typeof(+), Real, Real}` to define +`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`. +A signature type tuple always has the form: +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the +first positional argument + +The hooks are automatically run on new rules whenever a package is loaded, or a file is +`include`d. They can be manually triggered by [`refresh_rules`](@ref). +When a hook is first registered with `on_new_rule` it is run on all existing rules. +""" +function on_new_rule(hook_fun, rule_kind) + # get all the existing rules + ret = map(_rule_list(rule_kind)) do method + sig = _primal_sig(rule_kind, method) + _safe_hook_fun(hook_fun, sig) + end + + # register hook for new rules + push!(_hook_list(rule_kind), hook_fun) + return ret +end + +""" + clear_new_rule_hooks!(frule|rrule) + +Clears all hooks that were registered with corresponding [`on_new_rule`](@ref). +This is useful for while working interactively to define your rule generating hooks. +If you previously wrong an incorrect hook, you can use this to get rid of the old one. + +!!! warning + This absolutely should not be used in a package, as it will break any other AD system + using the rule hooks that might happen to be loaded. +""" +clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) + +function __init__() + # Need to refresh rules when a module is loaded or a file is `include`d. + push!(Base.package_callbacks, pkgid -> refresh_rules()) + push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) +end + + +""" + _rule_list(frule | rrule) + +Returns a list of all the methods of the currently defined rules of the given kind. +Excluding the fallback rule that returns `nothing` for every input. +""" +_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) +# ^ The fallback rules are the only rules defined in ChainRules core so that is how we skip them. + + + +const LAST_REFRESH_RRULE = Ref(0) +const LAST_REFRESH_FRULE = Ref(0) +last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE +last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE + +""" + refresh_rules() + refresh_rules(frule | rrule) + +This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. +It is *automatically* run when ever a package is loaded, or a file is `include`d. +It can also be manually called to run it directly, for example if a rule was defined +in the REPL or with-in the same file as the AD function. +""" +refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) +function refresh_rules(rule_kind) + already_done_world_age = last_refresh(rule_kind)[] + for method in _rule_list(rule_kind) + _defined_world(method) < already_done_world_age && continue + sig = _primal_sig(rule_kind, method) + _trigger_new_rule_hooks(rule_kind, sig) + end + + last_refresh(rule_kind)[] = _current_world() + return nothing +end + +@static if VERSION >= v"1.2" + _current_world() = Base.get_world_counter() + _defined_world(method) = method.primary_world +else + _current_world() = ccall(:jl_get_world_counter, UInt, ()) + _defined_world(method) = method.min_world +end + +""" + _primal_sig(frule|rule, rule_method | rule_sig) + +Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. +""" +_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) +function _primal_sig(::typeof(frule), rule_sig::Type) + @assert rule_sig.parameters[1] == typeof(frule) + # need to skip frule and the deriviative info, so starting from the 3rd + return Tuple{rule_sig.parameters[3:end]...} +end + +function _primal_sig(::typeof(rrule), rule_sig::Type) + @assert rule_sig.parameters[1] == typeof(rrule) + # need to skip rrule so starting from the 2rd + return Tuple{rule_sig.parameters[2:end]...} +end + +function _trigger_new_rule_hooks(rule_kind, sig) + for hook_fun in _hook_list(rule_kind) + _safe_hook_fun(hook_fun, sig) + end +end + +function _safe_hook_fun(hook_fun, sig) + try + hook_fun(sig) + catch err + @error "Error triggering hook" hook_fun sig exception=err + end +end From faa20873dbafdf99d9231b70e4479d032ab51c81 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 18:10:02 +0100 Subject: [PATCH 26/43] more docs --- docs/src/autodiff/operator_overloading.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index 1149800ff..5ea4dd3fc 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -12,6 +12,13 @@ first positional argument. One can dispatch on the signature type, to make rules with argument types your AD does not support not call `eval`; or more simply you can just use conditions for this. +`refresh_rules`(@ref) is used to manually trigger the hook function on any new rules. +This is useful for example if new rules are define in the REPL, or if files defining rules were modified. +(Revise.jl will not automatically trigger). + +`clear_new_rule_hooks!`(@ref) clears all registered hooks. +It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. + ## Examples ### ForwardDiffZero From a909e2d764d1bbb46a78dd80fbec7f49fb1ab3c5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 19:10:40 +0100 Subject: [PATCH 27/43] handle Unionall Signatures --- src/ChainRulesCore.jl | 3 +- src/ruleset_loading.jl | 25 +++++++++------ test/demos/forwarddiffzero.jl | 1 + test/demos/reversediffzero.jl | 1 + test/ruleset_loading.jl | 57 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 6 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 test/ruleset_loading.jl diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 139428ae9..a163cb874 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using MuladdMacro: @muladd -export on_new_rule, refresh_rules # generation tools +export on_new_rule, refresh_rules, clear_new_rule_hooks! # generation tools export frule, rrule # core function export @scalar_rule, @thunk # defination helper macros export canonicalize, extern, unthunk # differential operations @@ -23,5 +23,6 @@ include("differential_arithmetic.jl") include("rules.jl") include("rule_definition_tools.jl") +include("ruleset_loading.jl") end # module diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index cfd280c41..9c76545d0 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -1,5 +1,12 @@ # Infastructure to support generating overloads from rules. +function __init__() + # Need to refresh rules when a module is loaded or a file is `include`d. + push!(Base.package_callbacks, pkgid -> refresh_rules()) + push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) +end + + const NEW_RRULE_HOOKS = Function[] const NEW_FRULE_HOOKS = Function[] _hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS @@ -46,12 +53,6 @@ If you previously wrong an incorrect hook, you can use this to get rid of the ol """ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) -function __init__() - # Need to refresh rules when a module is loaded or a file is `include`d. - push!(Base.package_callbacks, pkgid -> refresh_rules()) - push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) -end - """ _rule_list(frule | rrule) @@ -105,17 +106,23 @@ end Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. """ _primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) -function _primal_sig(::typeof(frule), rule_sig::Type) +function _primal_sig(::typeof(frule), rule_sig::DataType) @assert rule_sig.parameters[1] == typeof(frule) # need to skip frule and the deriviative info, so starting from the 3rd return Tuple{rule_sig.parameters[3:end]...} end - -function _primal_sig(::typeof(rrule), rule_sig::Type) +function _primal_sig(::typeof(rrule), rule_sig::DataType) @assert rule_sig.parameters[1] == typeof(rrule) # need to skip rrule so starting from the 2rd return Tuple{rule_sig.parameters[2:end]...} end +function _primal_sig(rule_kind, rule_sig::UnionAll) + # This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those + # seems not to work + p_sig = _primal_sig(rule_kind, rule_sig.body) + return UnionAll(rule_sig.var, p_sig) +end + function _trigger_new_rule_hooks(rule_kind, sig) for hook_fun in _hook_list(rule_kind) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 2f5dfc690..4cc1559a2 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -28,6 +28,7 @@ Base.to_power_type(x::Dual) = x function define_dual_overload(sig) + sig = Base.unwrap_unionall(sig) opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 31a4737b3..390c0e3fe 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -57,6 +57,7 @@ Base.to_power_type(x::Tracked) = x "What to do when a new rrule is declared" function define_tracked_overload(sig) + sig = Base.unwrap_unionall(sig) opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl new file mode 100644 index 000000000..72696e302 --- /dev/null +++ b/test/ruleset_loading.jl @@ -0,0 +1,57 @@ +@testset "ruleset_loading.jl" begin +#== + @testset "on_new_rule" begin + frule_history = [] + rrule_history = [] + on_new_rule(frule) do sig + op = sig.parameters[1] + push!(frule_history, op) + end + on_new_rule(rrule) do sig + op = sig.parameters[1] + push!(rrule_history, op) + end + + include("example_rules.jl") # hook should trigger on include + #refresh_rules() + @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) + @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) + + clear_new_rule_hooks!(frule) + clear_new_rule_hooks!(rrule) + end + ==# + @testset "_primal_sig" begin + _primal_sig = ChainRulesCore._primal_sig + type_constraint_equal(T1, T2) = (T1 <: T2) && (T2 <: T1) + @testset "frule" begin + @test isequal( # DataType without shared type but with constraint + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}), + Tuple{typeof(*), Int, Vector{Int}} + ) + @test isequal( # UnionAall without shared type but with constraint + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Int} where T<:Real), + Tuple{typeof(*), T, Int} where T<:Real + ) + @test isequal( # UnionAall with share type + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Vector{T}} where T), + Tuple{typeof(*), T, Vector{T}} where T + ) + end + + @testset "rrule" begin + @test isequal( # DataType without shared type but with constraint + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), Int, Vector{Int}}), + Tuple{typeof(*), Int, Vector{Int}} + ) + @test isequal( # UnionAall without shared type but with constraint + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Int} where T<:Real), + Tuple{typeof(*), T, Int} where T<:Real + ) + @test isequal( # UnionAall with share type + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Vector{T}} where T), + Tuple{typeof(*), T, Vector{T}} where T + ) + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 90c997e04..8f995b354 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,10 @@ using Test include("differentials/composite.jl") end + include("ruleset_loading.jl") include("rules.jl") + @testset "demos" begin include("demos/forwarddiffzero.jl") include("demos/reversediffzero.jl") From fb8cdf673d18638ba5fee0e26ec94f938597c574 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 19:57:29 +0100 Subject: [PATCH 28/43] Stop refreshing rules on include_callback The include callback runs before the file is included, so it is not useful to us. I tested and the Package load callback runs after, so it is useful. --- docs/src/autodiff/operator_overloading.md | 3 ++- src/ruleset_loading.jl | 7 +++---- test/ruleset_loading.jl | 10 ++++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index 5ea4dd3fc..be2533f87 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -11,9 +11,10 @@ A signature type tuple always has the form: first positional argument. One can dispatch on the signature type, to make rules with argument types your AD does not support not call `eval`; or more simply you can just use conditions for this. +The the hook is automatically triggered whenever a package is loaded. `refresh_rules`(@ref) is used to manually trigger the hook function on any new rules. -This is useful for example if new rules are define in the REPL, or if files defining rules were modified. +This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. (Revise.jl will not automatically trigger). `clear_new_rule_hooks!`(@ref) clears all registered hooks. diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 9c76545d0..76c381c49 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -3,7 +3,6 @@ function __init__() # Need to refresh rules when a module is loaded or a file is `include`d. push!(Base.package_callbacks, pkgid -> refresh_rules()) - push!(Base.include_callbacks, (mod, filename) -> refresh_rules()) end @@ -24,8 +23,8 @@ A signature type tuple always has the form: `Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the first positional argument -The hooks are automatically run on new rules whenever a package is loaded, or a file is -`include`d. They can be manually triggered by [`refresh_rules`](@ref). +The hooks are automatically run on new rules whenever a package is loaded. +They can be manually triggered by [`refresh_rules`](@ref). When a hook is first registered with `on_new_rule` it is run on all existing rules. """ function on_new_rule(hook_fun, rule_kind) @@ -75,7 +74,7 @@ last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE refresh_rules(frule | rrule) This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. -It is *automatically* run when ever a package is loaded, or a file is `include`d. +It is *automatically* run when ever a package is loaded. It can also be manually called to run it directly, for example if a rule was defined in the REPL or with-in the same file as the AD function. """ diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 72696e302..81af1c9f8 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -1,5 +1,4 @@ @testset "ruleset_loading.jl" begin -#== @testset "on_new_rule" begin frule_history = [] rrule_history = [] @@ -12,15 +11,18 @@ push!(rrule_history, op) end - include("example_rules.jl") # hook should trigger on include - #refresh_rules() + # Now define some rules + @scalar_rule x + y (1, 1) + @scalar_rule x - y (1, -1) + refresh_rules() + @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) clear_new_rule_hooks!(frule) clear_new_rule_hooks!(rrule) end - ==# + @testset "_primal_sig" begin _primal_sig = ChainRulesCore._primal_sig type_constraint_equal(T1, T2) = (T1 <: T2) && (T2 <: T1) From b8d158166a1790b37a8bbbafebe110772533791d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 20 Aug 2020 20:13:19 +0100 Subject: [PATCH 29/43] tweaks --- docs/make.jl | 6 +++--- src/ruleset_loading.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 7abed7a56..0e160bf9e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -37,13 +37,13 @@ makedocs( "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", "Debug Mode" => "debug_mode.md", - "Design" => [ - "Many Differential Types" => "design/many_differentials.md", - ], "Usage in an AD" => [ "Overview" => "autodiff/overview.md", "Operator Overloading" => "autodiff/operator_overloading.md" ], + "Design" => [ + "Many Differential Types" => "design/many_differentials.md", + ], "API" => "api.md", ], strict=true, diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 76c381c49..5b1306954 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -1,7 +1,7 @@ # Infastructure to support generating overloads from rules. function __init__() - # Need to refresh rules when a module is loaded or a file is `include`d. + # Need to refresh rules when a package is loaded push!(Base.package_callbacks, pkgid -> refresh_rules()) end @@ -28,13 +28,13 @@ They can be manually triggered by [`refresh_rules`](@ref). When a hook is first registered with `on_new_rule` it is run on all existing rules. """ function on_new_rule(hook_fun, rule_kind) - # get all the existing rules + # apply the hook to the existing rules ret = map(_rule_list(rule_kind)) do method sig = _primal_sig(rule_kind, method) _safe_hook_fun(hook_fun, sig) end - # register hook for new rules + # register hook for new rules -- so all new rules get this function applied push!(_hook_list(rule_kind), hook_fun) return ret end From 27a8592c459b7f335c15cf79b665e328b03b3fda Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 10:14:11 +0100 Subject: [PATCH 30/43] remove type_constraint_equal --- test/ruleset_loading.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 81af1c9f8..1239d8c8e 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -25,7 +25,6 @@ @testset "_primal_sig" begin _primal_sig = ChainRulesCore._primal_sig - type_constraint_equal(T1, T2) = (T1 <: T2) && (T2 <: T1) @testset "frule" begin @test isequal( # DataType without shared type but with constraint _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}), @@ -56,4 +55,4 @@ ) end end -end \ No newline at end of file +end From 36b641089952c4c46728fff4b2f8231b4362f949 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 10:14:59 +0100 Subject: [PATCH 31/43] Update test/demos/reversediffzero.jl --- test/demos/reversediffzero.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 390c0e3fe..c12fb07c5 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -24,7 +24,7 @@ function Tracked(propagate, primal, tape) return v end -"Maker for inputs (leaves) that don't need to propagate." +"Marker for inputs (leaves) that don't need to propagate." struct NoPropagate end "An input, a Leaf in Nabla terms. No inputs of its own to propagate to." @@ -138,4 +138,3 @@ refresh_rules(); @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) end end # module - From e87b845455aa61646a8be9642eb6a7cdd4c0307b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 12:40:15 +0100 Subject: [PATCH 32/43] Style and comment fixes Co-authored-by: Nick Robinson --- src/ChainRulesCore.jl | 2 +- test/demos/forwarddiffzero.jl | 4 ++-- test/demos/reversediffzero.jl | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index a163cb874..23e581c79 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,7 @@ using MuladdMacro: @muladd export on_new_rule, refresh_rules, clear_new_rule_hooks! # generation tools export frule, rrule # core function -export @scalar_rule, @thunk # defination helper macros +export @scalar_rule, @thunk # definition helper macros export canonicalize, extern, unthunk # differential operations # differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 4cc1559a2..92edc5334 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -23,7 +23,7 @@ partial(d::Dual) = d.partial primal(d::Real) = d partial(d::Real) = 0.0 -# needed for ^ to work from having `*` defined +# needed for `^` to work from having `*` defined Base.to_power_type(x::Dual) = x @@ -87,4 +87,4 @@ refresh_rules(); @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) end -end # module \ No newline at end of file +end # module diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index c12fb07c5..5694f8af1 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -45,14 +45,14 @@ tape(d) = nothing "we have many inputs grab the tape from the first one that is tracked" get_tape(ds) = something(tape.(ds)...) -"propagate the currently stored partialient back to my inputs." +"propagate the currently stored partial back to my inputs." propagate!(d::Tracked) = d.propagate(d.partial[]) "Accumulate the sensitivity, if the value is being tracked." accum!(d::Tracked, x̄) = d.partial[] += x̄ accum!(d, x̄) = nothing -# needed for ^ to work from having `*` defined +# needed for `^` to work from having `*` defined Base.to_power_type(x::Tracked) = x "What to do when a new rrule is declared" @@ -71,7 +71,7 @@ function define_tracked_overload(sig) y, y_pullback = rrule(args...; kwargs...) the_tape = get_tape(tracked_args) y_tracked = Tracked(y, the_tape) do ȳ - # pull this partialient back and propagate it to the inputs partialient stores + # pull this partial back and propagate it to the inputs partial stores _, ārgs = Iterators.peel(y_pullback(ȳ)) accum!.(tracked_args, ārgs) end @@ -126,10 +126,10 @@ refresh_rules(); @test derv(bar, 1.2) == (3.1,) baz(x) = 2.0 * x^2 + 3.0*x + 1.2 - @test derv(baz, 1.7) == (2*2.0*1.7 + 3.0,) + @test derv(baz, 1.7) == (2 * 2.0 * 1.7 + 3.0,) qux(x) = foo(x) + bar(x) + baz(x) - @test derv(qux, 1.7) == ((2*2.0*1.7 + 3.0) + 3.1 + 2,) + @test derv(qux, 1.7) == ((2 * 2.0 * 1.7 + 3.0) + 3.1 + 2,) function quux(x) y = 2.0*x + 3.0*x From 1d9136684ee345be2bb3a1e2993cbba1750f390a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 12:40:53 +0100 Subject: [PATCH 33/43] Don't export clear_new_rule_hooks! Co-authored-by: Nick Robinson --- src/ChainRulesCore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 23e581c79..ca2b0a3ce 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using MuladdMacro: @muladd -export on_new_rule, refresh_rules, clear_new_rule_hooks! # generation tools +export on_new_rule, refresh_rules # generation tools export frule, rrule # core function export @scalar_rule, @thunk # definition helper macros export canonicalize, extern, unthunk # differential operations From 4ec4981d9bd86d60425cff9274713d23518a5d6e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 12:49:05 +0100 Subject: [PATCH 34/43] Update docs/make.jl Co-authored-by: Nick Robinson --- docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 0e160bf9e..8688bc02e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -37,7 +37,7 @@ makedocs( "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", "Debug Mode" => "debug_mode.md", - "Usage in an AD" => [ + "Usage in AD" => [ "Overview" => "autodiff/overview.md", "Operator Overloading" => "autodiff/operator_overloading.md" ], From baf443180be5d823cdefa3cb97ded8915707af5e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 12:50:51 +0100 Subject: [PATCH 35/43] move comemnt --- src/ruleset_loading.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 5b1306954..9cda91f54 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -1,5 +1,4 @@ # Infastructure to support generating overloads from rules. - function __init__() # Need to refresh rules when a package is loaded push!(Base.package_callbacks, pkgid -> refresh_rules()) @@ -52,16 +51,15 @@ If you previously wrong an incorrect hook, you can use this to get rid of the ol """ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) - """ _rule_list(frule | rrule) Returns a list of all the methods of the currently defined rules of the given kind. Excluding the fallback rule that returns `nothing` for every input. """ +function _rule_list end +# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them _rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) -# ^ The fallback rules are the only rules defined in ChainRules core so that is how we skip them. - const LAST_REFRESH_RRULE = Ref(0) From ada48227ede2b61fa685883feb5be41f8b62c4da Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 13:42:12 +0100 Subject: [PATCH 36/43] fix dotpoints in docs --- docs/src/autodiff/overview.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md index ce1a97f32..c31c092f1 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/autodiff/overview.md @@ -8,15 +8,15 @@ to avoid having to code all your own AD primitives / custom sensitives. There are 3 main ways to access ChainRules rule sets in your AutoDiff system. 1. [Operation Overloading Generation](operator_overloading.html) - - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. - - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. - - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. + - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. + - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. + - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. 2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` - - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path - - In theory type inference optimizes these branchs out; in practice it may not. - - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. + - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path + - In theory type inference optimizes these branchs out; in practice it may not. + - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. 3. Source code transform based on `rrule`/`frule` method-table - - Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path. - - This avoids having branches in your generated code. - - This requires maintaining your own back-edges - - This is pretty hard-code even by the standard of source code tranformations + - Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path. + - This avoids having branches in your generated code. + - This requires maintaining your own back-edges + - This is pretty hard-code even by the standard of source code tranformations From cfee7039dd2d68319f5f7e219e5f521ff722042b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 13:47:02 +0100 Subject: [PATCH 37/43] fix clear rule hooks in tests --- test/ruleset_loading.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 1239d8c8e..1a39f0a01 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -19,8 +19,8 @@ @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) - clear_new_rule_hooks!(frule) - clear_new_rule_hooks!(rrule) + ChainRulesCore.clear_new_rule_hooks!(frule) + ChainRulesCore.clear_new_rule_hooks!(rrule) end @testset "_primal_sig" begin From 3320fba7cb230b7f12cf6611e692596289651d23 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 21 Aug 2020 13:54:47 +0100 Subject: [PATCH 38/43] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6a14b179c..0d786877d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.5" +version = "0.9.6" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" From fdca95cccfbb123d1825f64077a3c01a849f49b1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sat, 22 Aug 2020 11:17:13 +0100 Subject: [PATCH 39/43] Apply suggestions from code review Co-authored-by: willtebbutt Co-authored-by: Nick Robinson --- docs/src/autodiff/operator_overloading.md | 13 ++++++------- docs/src/autodiff/overview.md | 14 +++++++------- src/ruleset_loading.jl | 8 ++++---- test/demos/reversediffzero.jl | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index be2533f87..b16f92bc9 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -1,17 +1,17 @@ # Operator Overloading -The principle interface for using the operator overload generation method is [`on_new_rule`](@ref). +The principal interface for using the operator overload generation method is [`on_new_rule`](@ref). This function allows one to register a hook to be run every time a new rule is defined. The hook receives a signature type-type as input, and generally will use `eval` to define -and overload of AD systems overloaded type. -For example, using the signature type `Tuple{typeof(+), Real, Real}` to define -`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`. +an overload of an AD system's overloaded type. +For example, using the signature type `Tuple{typeof(+), Real, Real}` to make +`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. A signature type tuple always has the form: `Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the first positional argument. -One can dispatch on the signature type, to make rules with argument types your AD does not support not call `eval`; +One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; or more simply you can just use conditions for this. -The the hook is automatically triggered whenever a package is loaded. +The hook is automatically triggered whenever a package is loaded. `refresh_rules`(@ref) is used to manually trigger the hook function on any new rules. This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. @@ -45,4 +45,3 @@ $(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) ``` """) ```` - diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md index c31c092f1..277873ac9 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/autodiff/overview.md @@ -1,22 +1,22 @@ -# Using ChainRules in your AutoDiff system +# Using ChainRules in your AD system This section is for authors of AD systems. -It assumes a pretty solid understanding of Julia, and of automatic differentiation. -It explains how to make use of ChainRule's rule sets, +It assumes a pretty solid understanding of both Julia and automatic differentiation. +It explains how to make use of ChainRule's "rulesets" ([`frule`](@ref)s, [`rrule`](@ref)s,) to avoid having to code all your own AD primitives / custom sensitives. There are 3 main ways to access ChainRules rule sets in your AutoDiff system. 1. [Operation Overloading Generation](operator_overloading.html) - - This is primarily intended for operator overloading based AD systems which will generate overloads for primal function based for their overloaded types based on the existance of an `rrule`/`frule`. + - This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`. - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. 2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` - - if the `rrule`/`frule` returns a rule result then use it, if it return `nothing` then do normal AD path + - If the `rrule`/`frule` returns a rule result then use it, if it returns `nothing` then do normal AD path. - In theory type inference optimizes these branchs out; in practice it may not. - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. 3. Source code transform based on `rrule`/`frule` method-table - - Always use `rrule`/`frule` iff and only if use the rules that exist, else generate normal AD path. + - If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path. - This avoids having branches in your generated code. - - This requires maintaining your own back-edges + - This requires maintaining your own back-edges. - This is pretty hard-code even by the standard of source code tranformations diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 9cda91f54..b16d8f1ca 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -16,11 +16,11 @@ _hook_list(::typeof(frule)) = NEW_FRULE_HOOKS Register a `hook` function to run when new rules are defined. The hook receives a signature type-type as input, and generally will use `eval` to define and overload of AD systems overloaded type. -For example, using the signature type `Tuple{typeof(+), Real, Real}` to define -`+(::DualNumber, ::DualNumber)` as calling the `frule` for `+`. +For example, using the signature type `Tuple{typeof(+), Real, Real}` to make +`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. A signature type tuple always has the form: `Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the -first positional argument +first positional argument. The hooks are automatically run on new rules whenever a package is loaded. They can be manually triggered by [`refresh_rules`](@ref). @@ -74,7 +74,7 @@ last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. It is *automatically* run when ever a package is loaded. It can also be manually called to run it directly, for example if a rule was defined -in the REPL or with-in the same file as the AD function. +in the REPL or within the same file as the AD function. """ refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) function refresh_rules(rule_kind) diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 5694f8af1..b10de051c 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -71,7 +71,7 @@ function define_tracked_overload(sig) y, y_pullback = rrule(args...; kwargs...) the_tape = get_tape(tracked_args) y_tracked = Tracked(y, the_tape) do ȳ - # pull this partial back and propagate it to the inputs partial stores + # pull this partial back and propagate it to the input's partial store _, ārgs = Iterators.peel(y_pullback(ȳ)) accum!.(tracked_args, ārgs) end From d97535f595ac03d8025e6404037d174bbf7d9ad6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Aug 2020 14:01:01 +0100 Subject: [PATCH 40/43] Update docs/src/autodiff/operator_overloading.md --- docs/src/autodiff/operator_overloading.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index b16f92bc9..eee50b0d2 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -7,7 +7,7 @@ an overload of an AD system's overloaded type. For example, using the signature type `Tuple{typeof(+), Real, Real}` to make `+(::DualNumber, ::DualNumber)` call the `frule` for `+`. A signature type tuple always has the form: -`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the first positional argument. One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; or more simply you can just use conditions for this. From 7d325098e13920c8f980764f7fe04ea20bc8f0d5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Aug 2020 14:54:24 +0100 Subject: [PATCH 41/43] More docs on generation --- docs/src/autodiff/operator_overloading.md | 37 +++++++++++++++++++++-- test/demos/forwarddiffzero.jl | 2 +- test/demos/reversediffzero.jl | 2 +- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md index eee50b0d2..2ed85c795 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/autodiff/operator_overloading.md @@ -11,11 +11,44 @@ A signature type tuple always has the form: first positional argument. One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; or more simply you can just use conditions for this. -The hook is automatically triggered whenever a package is loaded. +For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write: +```julia +const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} +function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F + @eval quote + # ... + end +end +define_overload(::Any) = nothing # don't do anything for any other signature + +on_new_rule(frule, define_overload) +``` + +or you might write: +```julia +const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64}) +function define_overload(sig) where F + sig = Base.unwrap_unionall(sig) # not really handling most UnionAll, + opT, argTs = Iterators.peel(sig.parameters) + all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return + @eval quote + # ... + end +end + +on_new_rule(frule, define_overload) +``` -`refresh_rules`(@ref) is used to manually trigger the hook function on any new rules. +The generation of overloaded code is the responsibility of the AD implementor. +Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this. +Its generally fairly simple, though can become complex if you need to handle complicated type-constraints. +Examples are shown below. + +The hook is automatically triggered whenever a package is loaded. +It can also be triggers manually using `refresh_rules`(@ref). This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. (Revise.jl will not automatically trigger). +When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on. `clear_new_rule_hooks!`(@ref) clears all registered hooks. It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 92edc5334..59a2429ad 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -28,7 +28,7 @@ Base.to_power_type(x::Dual) = x function define_dual_overload(sig) - sig = Base.unwrap_unionall(sig) + sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index b10de051c..adeebef9a 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -57,7 +57,7 @@ Base.to_power_type(x::Tracked) = x "What to do when a new rrule is declared" function define_tracked_overload(sig) - sig = Base.unwrap_unionall(sig) + sig = Base.unwrap_unionall(sig) # not really handling most UnionAll opT, argTs = Iterators.peel(sig.parameters) fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. From ecd2bb6cf7702376869a25daa5ab8d52096640dc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Aug 2020 15:58:46 +0100 Subject: [PATCH 42/43] test clear hooks --- test/ruleset_loading.jl | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 1a39f0a01..e1743f0f9 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -10,17 +10,31 @@ op = sig.parameters[1] push!(rrule_history, op) end + + @testset "new rules hit the hooks" begin + # Now define some rules + @scalar_rule x + y (1, 1) + @scalar_rule x - y (1, -1) + refresh_rules() + + @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) + @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) + end - # Now define some rules - @scalar_rule x + y (1, 1) - @scalar_rule x - y (1, -1) - refresh_rules() + @testset "# Make sure nothing happens anymore once we clear the hooks" begin + ChainRulesCore.clear_new_rule_hooks!(frule) + ChainRulesCore.clear_new_rule_hooks!(rrule) + + old_frule_history = copy(frule_history) + old_rrule_history = copy(rrule_history) + + @scalar_rule sin(x) cos(x) + refresh_rules() + + @test old_rrule_history == rrule_history + @test old_frule_history == frule_history + end - @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) - @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) - - ChainRulesCore.clear_new_rule_hooks!(frule) - ChainRulesCore.clear_new_rule_hooks!(rrule) end @testset "_primal_sig" begin From eccb8940af9ec69395361b5b5a233f5e94a4e65e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Aug 2020 15:58:58 +0100 Subject: [PATCH 43/43] wrap up code review --- docs/src/autodiff/overview.md | 2 +- src/rules.jl | 2 +- src/ruleset_loading.jl | 21 +++++++++++++-------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md index 277873ac9..2638d6382 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/autodiff/overview.md @@ -19,4 +19,4 @@ There are 3 main ways to access ChainRules rule sets in your AutoDiff system. - If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path. - This avoids having branches in your generated code. - This requires maintaining your own back-edges. - - This is pretty hard-code even by the standard of source code tranformations + - This is pretty hardcore even by the standard of source code tranformations. diff --git a/src/rules.jl b/src/rules.jl index b83cda45a..8074a00ed 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -91,4 +91,4 @@ true See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ -rrule(::Any, ::Vararg{Any}; kwargs...) = nothing \ No newline at end of file +rrule(::Any, ::Vararg{Any}; kwargs...) = nothing diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index b16d8f1ca..31b289d8e 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -4,22 +4,22 @@ function __init__() push!(Base.package_callbacks, pkgid -> refresh_rules()) end - -const NEW_RRULE_HOOKS = Function[] -const NEW_FRULE_HOOKS = Function[] -_hook_list(::typeof(rrule)) = NEW_RRULE_HOOKS -_hook_list(::typeof(frule)) = NEW_FRULE_HOOKS +# Holds all the hook functions that are invokes when a new rule is defined +const RRULE_DEFINITION_HOOKS = Function[] +const FRULE_DEFINITION_HOOKS = Function[] +_hook_list(::typeof(rrule)) = RRULE_DEFINITION_HOOKS +_hook_list(::typeof(frule)) = FRULE_DEFINITION_HOOKS """ on_new_rule(hook, frule | rrule) Register a `hook` function to run when new rules are defined. The hook receives a signature type-type as input, and generally will use `eval` to define -and overload of AD systems overloaded type. +an overload of an AD system's overloaded type For example, using the signature type `Tuple{typeof(+), Real, Real}` to make `+(::DualNumber, ::DualNumber)` call the `frule` for `+`. A signature type tuple always has the form: -`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2...}}`, where `pos_arg1` is the +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}...}`, where `pos_arg1` is the first positional argument. The hooks are automatically run on new rules whenever a package is loaded. @@ -76,8 +76,13 @@ It is *automatically* run when ever a package is loaded. It can also be manually called to run it directly, for example if a rule was defined in the REPL or within the same file as the AD function. """ -refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) +function refresh_rules() + refresh_rules(frule); + refresh_rules(rrule) +end + function refresh_rules(rule_kind) + isempty(_rule_list(rule_kind)) && return # if no hooks, exit early, nothing to run already_done_world_age = last_refresh(rule_kind)[] for method in _rule_list(rule_kind) _defined_world(method) < already_done_world_age && continue