Skip to content

Commit

Permalink
add ForwardDiffZero as an API integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jul 21, 2020
1 parent d1ca907 commit f6a7179
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
76 changes: 76 additions & 0 deletions test/demos/forwarddiffzero.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"The simplest viable forward mode a AD, only supports `Float64`"
module ForwardDiffZero

using ChainRulesCore
using Test

struct Dual <: Real
primal::Float64
diff::Float64
end

primal(d::Dual) = d.primal
diff(d::Dual) = d.diff

primal(d::Real) = d
diff(d::Real) = 0.0

# needed for ^ to work from having `*` defined
Base.to_power_type(x::Dual) = x


function define_dual_overload(sig, ast)
opT, argTs = Iterators.peel(sig.parameters)
fieldcount(opT) == 0 || return # not handling functors
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.
op = opT.instance
opname = :($(parentmodule(op)).$(nameof(op)))
N = length(sig.parameters) - 1 # skip the op
fdef = quote
function $opname(dual_args::Vararg{Union{Dual, Float64},$N}; kwargs...)
ȧrgs = (NO_FIELDS, diff.(dual_args)...)
args = ($opname, primal.(dual_args)...)
res = frule(ȧrgs, args...; kwargs...)
res === nothing && error("Apparently no rule for $($sig)), but we really thought there was, args=($args)")
y, ẏ = res
return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error.
end
end
# @show fdef
@eval $fdef
end

#########################################
# Initial rule setup
@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)

on_new_rule(frule) do sig, ast
return define_dual_overload(sig, ast)
end

# add a rule later also
@frule function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
return (x * y, (Δx * y + x * Δy))
end

function derv(f, args...)
duals = Dual.(args, one.(args))
return diff(f(duals...))
end

@testset "ForwardDiffZero" begin
foo(x) = x + x
@test derv(foo, 1.6) == 2

bar(x) = x + 2.1 * x
@test derv(bar, 1.2) == 3.1

baz(x) = 2.0 * x^2 + 3.0*x + 1.2
@test derv(baz, 1.7) == 2*2.0*1.7 + 3.0

qux(x) = foo(x) + bar(x) + baz(x)
@test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2
end

end # module
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ using Test
end

include("rules.jl")

@testset "demos" begin
include("demos/forwarddiffzero.jl")
end
end

0 comments on commit f6a7179

Please sign in to comment.