From f6a71790c66491246e0a8a8ae1c9028b41fe2b9a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Jul 2020 19:25:37 +0100 Subject: [PATCH] add ForwardDiffZero as an API integration test --- test/demos/forwarddiffzero.jl | 76 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++ 2 files changed, 80 insertions(+) create mode 100644 test/demos/forwarddiffzero.jl diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl new file mode 100644 index 000000000..f6441e00a --- /dev/null +++ b/test/demos/forwarddiffzero.jl @@ -0,0 +1,76 @@ +"The simplest viable forward mode a AD, only supports `Float64`" +module ForwardDiffZero + +using ChainRulesCore +using Test + +struct Dual <: Real + primal::Float64 + diff::Float64 +end + +primal(d::Dual) = d.primal +diff(d::Dual) = d.diff + +primal(d::Real) = d +diff(d::Real) = 0.0 + +# needed for ^ to work from having `*` defined +Base.to_power_type(x::Dual) = x + + +function define_dual_overload(sig, ast) + opT, argTs = Iterators.peel(sig.parameters) + fieldcount(opT) == 0 || return # not handling functors + all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. + op = opT.instance + opname = :($(parentmodule(op)).$(nameof(op))) + N = length(sig.parameters) - 1 # skip the op + fdef = quote + function $opname(dual_args::Vararg{Union{Dual, Float64},$N}; kwargs...) + ȧrgs = (NO_FIELDS, diff.(dual_args)...) + args = ($opname, primal.(dual_args)...) + res = frule(ȧrgs, args...; kwargs...) + res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)") + y, ẏ = res + return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. + end + end + # @show fdef + @eval $fdef +end + +######################################### +# Initial rule setup +@scalar_rule x + y (One(), One()) +@scalar_rule x - y (One(), -1) + +on_new_rule(frule) do sig, ast + return define_dual_overload(sig, ast) +end + +# add a rule later also +@frule function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) + return (x * y, (Δx * y + x * Δy)) +end + +function derv(f, args...) + duals = Dual.(args, one.(args)) + return diff(f(duals...)) +end + +@testset "ForwardDiffZero" begin + foo(x) = x + x + @test derv(foo, 1.6) == 2 + + bar(x) = x + 2.1 * x + @test derv(bar, 1.2) == 3.1 + + baz(x) = 2.0 * x^2 + 3.0*x + 1.2 + @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0 + + qux(x) = foo(x) + bar(x) + baz(x) + @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2 +end + +end # module \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5d58db5c9..0f0b80bc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,4 +14,8 @@ using Test end include("rules.jl") + + @testset "demos" begin + include("demos/forwarddiffzero.jl") + end end