diff --git a/Project.toml b/Project.toml index 6a14b179c..0d786877d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.5" +version = "0.9.6" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c58a79af6..da934f718 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -31,9 +31,9 @@ version = "0.8.2" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "1c593d1efa27437ed9dd365d1143c594b563e138" +git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.25.1" +version = "0.25.2" [[DocumenterTools]] deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"] @@ -78,9 +78,9 @@ version = "0.2.2" [[Parsers]] deps = ["Dates", "Test"] -git-tree-sha1 = "10134f2ee0b1978ae7752c41306e131a684e1f06" +git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.7" +version = "1.0.10" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -91,7 +91,7 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] diff --git a/docs/Project.toml b/docs/Project.toml index 66770708a..a39e1b1da 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,6 +2,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" [compat] Documenter = "0.25" diff --git a/docs/make.jl b/docs/make.jl index 8d1012ed1..8688bc02e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,7 @@ using ChainRulesCore using Documenter using DocumenterTools: Themes +using Markdown DocMeta.setdocmeta!( ChainRulesCore, @@ -36,6 +37,10 @@ makedocs( "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", "Debug Mode" => "debug_mode.md", + "Usage in AD" => [ + "Overview" => "autodiff/overview.md", + "Operator Overloading" => "autodiff/operator_overloading.md" + ], "Design" => [ "Many Differential Types" => "design/many_differentials.md", ], diff --git a/docs/src/api.md b/docs/src/api.md index 0aa9d44d2..ad34569e9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -27,8 +27,16 @@ Pages = [ Private = false ``` +## Ruleset Loading +```@autodocs +Modules = [ChainRulesCore] +Pages = ["ruleset_loading.jl"] +Private = false +``` + ## Internal ```@docs ChainRulesCore.AbstractDifferential ChainRulesCore.debug_mode +ChainRulesCore.clear_new_rule_hooks! ``` diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md new file mode 100644 index 000000000..2ed85c795 --- /dev/null +++ b/docs/src/autodiff/operator_overloading.md @@ -0,0 +1,80 @@ +# Operator Overloading + +The principal interface for using the operator overload generation method is [`on_new_rule`](@ref). +This function allows one to register a hook to be run every time a new rule is defined. +The hook receives a signature type-type as input, and generally will use `eval` to define +an overload of an AD system's overloaded type. +For example, using the signature type `Tuple{typeof(+), Real, Real}` to make +`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. +A signature type tuple always has the form: +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the +first positional argument. +One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; +or more simply you can just use conditions for this. +For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write: +```julia +const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} +function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F + @eval quote + # ... + end +end +define_overload(::Any) = nothing # don't do anything for any other signature + +on_new_rule(frule, define_overload) +``` + +or you might write: +```julia +const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64}) +function define_overload(sig) where F + sig = Base.unwrap_unionall(sig) # not really handling most UnionAll, + opT, argTs = Iterators.peel(sig.parameters) + all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return + @eval quote + # ... + end +end + +on_new_rule(frule, define_overload) +``` + +The generation of overloaded code is the responsibility of the AD implementor. +Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this. +Its generally fairly simple, though can become complex if you need to handle complicated type-constraints. +Examples are shown below. + +The hook is automatically triggered whenever a package is loaded. +It can also be triggers manually using `refresh_rules`(@ref). +This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. +(Revise.jl will not automatically trigger). +When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on. + +`clear_new_rule_hooks!`(@ref) clears all registered hooks. +It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. + +## Examples + +### ForwardDiffZero +The overload generation hook in this example is: `define_dual_overload`. + +````@eval +using Markdown +Markdown.parse(""" +```julia +$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) +``` +""") +```` + +### ReverseDiffZero +The overload generation hook in this example is: `define_tracked_overload`. + +````@eval +using Markdown +Markdown.parse(""" +```julia +$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) +``` +""") +```` diff --git a/docs/src/autodiff/overview.md b/docs/src/autodiff/overview.md new file mode 100644 index 000000000..2638d6382 --- /dev/null +++ b/docs/src/autodiff/overview.md @@ -0,0 +1,22 @@ +# Using ChainRules in your AD system + +This section is for authors of AD systems. +It assumes a pretty solid understanding of both Julia and automatic differentiation. +It explains how to make use of ChainRule's "rulesets" ([`frule`](@ref)s, [`rrule`](@ref)s,) +to avoid having to code all your own AD primitives / custom sensitives. + +There are 3 main ways to access ChainRules rule sets in your AutoDiff system. + +1. [Operation Overloading Generation](operator_overloading.html) + - This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`. + - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. + - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. +2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing` + - If the `rrule`/`frule` returns a rule result then use it, if it returns `nothing` then do normal AD path. + - In theory type inference optimizes these branchs out; in practice it may not. + - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. +3. Source code transform based on `rrule`/`frule` method-table + - If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path. + - This avoids having branches in your generated code. + - This requires maintaining your own back-edges. + - This is pretty hardcore even by the standard of source code tranformations. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 73b6231d3..ca2b0a3ce 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,9 +1,12 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! +using MuladdMacro: @muladd -export frule, rrule -export @scalar_rule, @thunk -export canonicalize, extern, unthunk +export on_new_rule, refresh_rules # generation tools +export frule, rrule # core function +export @scalar_rule, @thunk # definition helper macros +export canonicalize, extern, unthunk # differential operations +# differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk export NO_FIELDS @@ -20,5 +23,6 @@ include("differential_arithmetic.jl") include("rules.jl") include("rule_definition_tools.jl") +include("ruleset_loading.jl") end # module diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 83330e23b..2675c7598 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,6 +1,4 @@ # These are some macros (and supporting functions) to make it easier to define rules. -using MuladdMacro: @muladd - """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), diff --git a/src/rules.jl b/src/rules.jl index a55e6d4bf..8074a00ed 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,7 +1,3 @@ -##### -##### `frule`/`rrule` -##### - """ frule((Δf, Δx...), f, x...) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl new file mode 100644 index 000000000..31b289d8e --- /dev/null +++ b/src/ruleset_loading.jl @@ -0,0 +1,141 @@ +# Infastructure to support generating overloads from rules. +function __init__() + # Need to refresh rules when a package is loaded + push!(Base.package_callbacks, pkgid -> refresh_rules()) +end + +# Holds all the hook functions that are invokes when a new rule is defined +const RRULE_DEFINITION_HOOKS = Function[] +const FRULE_DEFINITION_HOOKS = Function[] +_hook_list(::typeof(rrule)) = RRULE_DEFINITION_HOOKS +_hook_list(::typeof(frule)) = FRULE_DEFINITION_HOOKS + +""" + on_new_rule(hook, frule | rrule) + +Register a `hook` function to run when new rules are defined. +The hook receives a signature type-type as input, and generally will use `eval` to define +an overload of an AD system's overloaded type +For example, using the signature type `Tuple{typeof(+), Real, Real}` to make +`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. +A signature type tuple always has the form: +`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}...}`, where `pos_arg1` is the +first positional argument. + +The hooks are automatically run on new rules whenever a package is loaded. +They can be manually triggered by [`refresh_rules`](@ref). +When a hook is first registered with `on_new_rule` it is run on all existing rules. +""" +function on_new_rule(hook_fun, rule_kind) + # apply the hook to the existing rules + ret = map(_rule_list(rule_kind)) do method + sig = _primal_sig(rule_kind, method) + _safe_hook_fun(hook_fun, sig) + end + + # register hook for new rules -- so all new rules get this function applied + push!(_hook_list(rule_kind), hook_fun) + return ret +end + +""" + clear_new_rule_hooks!(frule|rrule) + +Clears all hooks that were registered with corresponding [`on_new_rule`](@ref). +This is useful for while working interactively to define your rule generating hooks. +If you previously wrong an incorrect hook, you can use this to get rid of the old one. + +!!! warning + This absolutely should not be used in a package, as it will break any other AD system + using the rule hooks that might happen to be loaded. +""" +clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) + +""" + _rule_list(frule | rrule) + +Returns a list of all the methods of the currently defined rules of the given kind. +Excluding the fallback rule that returns `nothing` for every input. +""" +function _rule_list end +# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them +_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) + + +const LAST_REFRESH_RRULE = Ref(0) +const LAST_REFRESH_FRULE = Ref(0) +last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE +last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE + +""" + refresh_rules() + refresh_rules(frule | rrule) + +This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. +It is *automatically* run when ever a package is loaded. +It can also be manually called to run it directly, for example if a rule was defined +in the REPL or within the same file as the AD function. +""" +function refresh_rules() + refresh_rules(frule); + refresh_rules(rrule) +end + +function refresh_rules(rule_kind) + isempty(_rule_list(rule_kind)) && return # if no hooks, exit early, nothing to run + already_done_world_age = last_refresh(rule_kind)[] + for method in _rule_list(rule_kind) + _defined_world(method) < already_done_world_age && continue + sig = _primal_sig(rule_kind, method) + _trigger_new_rule_hooks(rule_kind, sig) + end + + last_refresh(rule_kind)[] = _current_world() + return nothing +end + +@static if VERSION >= v"1.2" + _current_world() = Base.get_world_counter() + _defined_world(method) = method.primary_world +else + _current_world() = ccall(:jl_get_world_counter, UInt, ()) + _defined_world(method) = method.min_world +end + +""" + _primal_sig(frule|rule, rule_method | rule_sig) + +Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. +""" +_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) +function _primal_sig(::typeof(frule), rule_sig::DataType) + @assert rule_sig.parameters[1] == typeof(frule) + # need to skip frule and the deriviative info, so starting from the 3rd + return Tuple{rule_sig.parameters[3:end]...} +end +function _primal_sig(::typeof(rrule), rule_sig::DataType) + @assert rule_sig.parameters[1] == typeof(rrule) + # need to skip rrule so starting from the 2rd + return Tuple{rule_sig.parameters[2:end]...} +end +function _primal_sig(rule_kind, rule_sig::UnionAll) + # This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those + # seems not to work + p_sig = _primal_sig(rule_kind, rule_sig.body) + return UnionAll(rule_sig.var, p_sig) +end + + +function _trigger_new_rule_hooks(rule_kind, sig) + for hook_fun in _hook_list(rule_kind) + _safe_hook_fun(hook_fun, sig) + end +end + +function _safe_hook_fun(hook_fun, sig) + try + hook_fun(sig) + catch err + @error "Error triggering hook" hook_fun sig exception=err + end +end diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl new file mode 100644 index 000000000..59a2429ad --- /dev/null +++ b/test/demos/forwarddiffzero.jl @@ -0,0 +1,90 @@ +"The simplest viable forward mode a AD, only supports `Float64`" +module ForwardDiffZero +using ChainRulesCore +using Test + +######################################### +# Initial rule setup +@scalar_rule x + y (1, 1) +@scalar_rule x - y (1, -1) +########################## +# Define the AD + +# Note that we never directly define Dual Number Arithmetic on Dual numbers +# instead it is automatically defined from the `frules` +struct Dual <: Real + primal::Float64 + partial::Float64 +end + +primal(d::Dual) = d.primal +partial(d::Dual) = d.partial + +primal(d::Real) = d +partial(d::Real) = 0.0 + +# needed for `^` to work from having `*` defined +Base.to_power_type(x::Dual) = x + + +function define_dual_overload(sig) + sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls + opT, argTs = Iterators.peel(sig.parameters) + fieldcount(opT) == 0 || return # not handling functors + all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. + + N = length(sig.parameters) - 1 # skip the op + fdef = quote + # we use the function call overloading form as it lets us avoid namespacing issues + # as we can directly interpolate the function type into to the AST. + function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) + ȧrgs = (NO_FIELDS, partial.(dual_args)...) + args = (op, primal.(dual_args)...) + y, ẏ = frule(ȧrgs, args...; kwargs...) + return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. + end + end + eval(fdef) +end + +# !Important!: Attach the define function to the `on_new_rule` hook +on_new_rule(define_dual_overload, frule) + +"Do a calculus. `f` should have a single input." +function derv(f, arg) + duals = Dual(arg, one(arg)) + return partial(f(duals...)) +end + +# End AD definition +################################ + +# add a rule later also +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) + return (x * y, Δx * y + x * Δy) +end + +# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call +refresh_rules(); + +@testset "ForwardDiffZero" begin + foo(x) = x + x + @test derv(foo, 1.6) == 2 + + bar(x) = x + 2.1 * x + @test derv(bar, 1.2) == 3.1 + + baz(x) = 2.0 * x^2 + 3.0*x + 1.2 + @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0 + + qux(x) = foo(x) + bar(x) + baz(x) + @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2 + + function quux(x) + y = 2.0*x + 3.0*x + return 4.0*y + 5.0*y + end + @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) +end + +end # module diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl new file mode 100644 index 000000000..adeebef9a --- /dev/null +++ b/test/demos/reversediffzero.jl @@ -0,0 +1,140 @@ +"The simplest viable reverse mode a AD, only supports `Float64`" +module ReverseDiffZero +using ChainRulesCore +using Test + +######################################### +# Initial rule setup +@scalar_rule x + y (1, 1) +@scalar_rule x - y (1, -1) +########################## +#Define the AD + +struct Tracked{F} <: Real + propagate::F + primal::Float64 + tape::Vector{Tracked} # a reference to a shared tape + partial::Base.RefValue{Float64} # current accumulated sensitivity +end + +"An intermediate value, a Branch in Nabla terms." +function Tracked(propagate, primal, tape) + v = Tracked(propagate, primal, tape, Ref(zero(primal))) + push!(tape, v) + return v +end + +"Marker for inputs (leaves) that don't need to propagate." +struct NoPropagate end + +"An input, a Leaf in Nabla terms. No inputs of its own to propagate to." +function Tracked(primal, tape) + # don't actually need to put these on the tape, since they don't need to propagate + return Tracked(NoPropagate(), primal, tape, Ref(zero(primal))) +end + +primal(d::Tracked) = d.primal +primal(d) = d + +partial(d::Tracked) = d.partial[] +partial(d) = nothing + +tape(d::Tracked) = d.tape +tape(d) = nothing + +"we have many inputs grab the tape from the first one that is tracked" +get_tape(ds) = something(tape.(ds)...) + +"propagate the currently stored partial back to my inputs." +propagate!(d::Tracked) = d.propagate(d.partial[]) + +"Accumulate the sensitivity, if the value is being tracked." +accum!(d::Tracked, x̄) = d.partial[] += x̄ +accum!(d, x̄) = nothing + +# needed for `^` to work from having `*` defined +Base.to_power_type(x::Tracked) = x + +"What to do when a new rrule is declared" +function define_tracked_overload(sig) + sig = Base.unwrap_unionall(sig) # not really handling most UnionAll + opT, argTs = Iterators.peel(sig.parameters) + fieldcount(opT) == 0 || return # not handling functors + all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. + + N = length(sig.parameters) - 1 # skip the op + fdef = quote + # we use the function call overloading form as it lets us avoid namespacing issues + # as we can directly interpolate the function type into to the AST. + function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...) + args = (op, primal.(tracked_args)...) + y, y_pullback = rrule(args...; kwargs...) + the_tape = get_tape(tracked_args) + y_tracked = Tracked(y, the_tape) do ȳ + # pull this partial back and propagate it to the input's partial store + _, ārgs = Iterators.peel(y_pullback(ȳ)) + accum!.(tracked_args, ārgs) + end + return y_tracked + end + end + eval(fdef) +end + +# !Important!: Attach the define function to the `on_new_rule` hook +on_new_rule(define_tracked_overload, rrule) + +"Do a calculus. `f` should have a single output." +function derv(f, args::Vararg; kwargs...) + the_tape = Vector{Tracked}() + tracked_inputs = Tracked.(args, Ref(the_tape)) + tracked_output = f(tracked_inputs...; kwargs...) + @assert tape(tracked_output) === the_tape + + # Now the backward pass + out = primal(tracked_output) + ōut = one(out) + accum!(tracked_output, ōut) + # By going down the tape backwards we know we will have fully accumulated partials + # before propagating them onwards + for op in reverse(the_tape) + propagate!(op) + end + return partial.(tracked_inputs) +end + +# End AD definition +################################ + +# add a rule later also +function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) + function times_pullback(ΔΩ) + # we will use thunks here to show we handle them fine. + return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + end + return x * y, times_pullback +end + +# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call +refresh_rules(); + +@testset "ReversedDiffZero" begin + foo(x) = x + x + @test derv(foo, 1.6) == (2.0,) + + bar(x) = x + 2.1 * x + @test derv(bar, 1.2) == (3.1,) + + baz(x) = 2.0 * x^2 + 3.0*x + 1.2 + @test derv(baz, 1.7) == (2 * 2.0 * 1.7 + 3.0,) + + qux(x) = foo(x) + bar(x) + baz(x) + @test derv(qux, 1.7) == ((2 * 2.0 * 1.7 + 3.0) + 3.1 + 2,) + + function quux(x) + y = 2.0*x + 3.0*x + return 4.0*y + 5.0*y + end + @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) +end +end # module diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl new file mode 100644 index 000000000..e1743f0f9 --- /dev/null +++ b/test/ruleset_loading.jl @@ -0,0 +1,72 @@ +@testset "ruleset_loading.jl" begin + @testset "on_new_rule" begin + frule_history = [] + rrule_history = [] + on_new_rule(frule) do sig + op = sig.parameters[1] + push!(frule_history, op) + end + on_new_rule(rrule) do sig + op = sig.parameters[1] + push!(rrule_history, op) + end + + @testset "new rules hit the hooks" begin + # Now define some rules + @scalar_rule x + y (1, 1) + @scalar_rule x - y (1, -1) + refresh_rules() + + @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) + @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) + end + + @testset "# Make sure nothing happens anymore once we clear the hooks" begin + ChainRulesCore.clear_new_rule_hooks!(frule) + ChainRulesCore.clear_new_rule_hooks!(rrule) + + old_frule_history = copy(frule_history) + old_rrule_history = copy(rrule_history) + + @scalar_rule sin(x) cos(x) + refresh_rules() + + @test old_rrule_history == rrule_history + @test old_frule_history == frule_history + end + + end + + @testset "_primal_sig" begin + _primal_sig = ChainRulesCore._primal_sig + @testset "frule" begin + @test isequal( # DataType without shared type but with constraint + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}), + Tuple{typeof(*), Int, Vector{Int}} + ) + @test isequal( # UnionAall without shared type but with constraint + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Int} where T<:Real), + Tuple{typeof(*), T, Int} where T<:Real + ) + @test isequal( # UnionAall with share type + _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Vector{T}} where T), + Tuple{typeof(*), T, Vector{T}} where T + ) + end + + @testset "rrule" begin + @test isequal( # DataType without shared type but with constraint + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), Int, Vector{Int}}), + Tuple{typeof(*), Int, Vector{Int}} + ) + @test isequal( # UnionAall without shared type but with constraint + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Int} where T<:Real), + Tuple{typeof(*), T, Int} where T<:Real + ) + @test isequal( # UnionAall with share type + _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Vector{T}} where T), + Tuple{typeof(*), T, Vector{T}} where T + ) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d58db5c9..8f995b354 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,5 +13,12 @@ using Test include("differentials/composite.jl") end + include("ruleset_loading.jl") include("rules.jl") + + + @testset "demos" begin + include("demos/forwarddiffzero.jl") + include("demos/reversediffzero.jl") + end end