Skip to content

Commit

Permalink
Reinsert prep for some backends + remove specialization on f (#98)
Browse files Browse the repository at this point in the history
* Reinsert prep for some backends + remove specialization on f

* Remove FDExtras

* Useless type params
  • Loading branch information
gdalle authored Mar 25, 2024
1 parent 8944e2d commit 36d9c47
Show file tree
Hide file tree
Showing 37 changed files with 495 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@ DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode

## Primitives

function DI.value_and_pushforward(
f::F, backend::AutoForwardChainRules, x, dx, extras::Nothing
) where {F}
function DI.value_and_pushforward(f, backend::AutoForwardChainRules, x, dx, extras::Nothing)
rc = ruleconfig(backend)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, new_dy
end

function DI.value_and_pullback(
f::F, backend::AutoReverseChainRules, x, dy, extras::Nothing
) where {F}
function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, extras::Nothing)
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ function test_correctness(ba::AbstractADType, ::typeof(pushforward), scen::Scena

@testset "Primal value" begin
@test myisapprox(y_out, y)
@testset "Mutation" begin
if ismutable(y)
@test myisapprox(y_in, y)
end
end
end
@testset "Tangent value" begin
@test myisapprox(dy_out, dy_true; rtol=1e-3)
Expand Down Expand Up @@ -107,11 +102,6 @@ function test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scenario

@testset "Primal value" begin
@test myisapprox(y_out, y)
@testset "Mutation" begin
if ismutable(y)
@test myisapprox(y_in, y)
end
end
end
@testset "Cotangent value" begin
@test myisapprox(dx_out, dx_true; rtol=1e-3)
Expand Down Expand Up @@ -157,11 +147,6 @@ function test_correctness(ba::AbstractADType, ::typeof(derivative), scen::Scenar

@testset "Primal value" begin
@test myisapprox(y_out, y)
@testset "Mutation" begin
if ismutable(y)
@test myisapprox(y_in, y)
end
end
end
@testset "Derivative value" begin
@test myisapprox(der_out, der_true; rtol=1e-3)
Expand All @@ -176,7 +161,11 @@ function test_correctness(ba::AbstractADType, ::typeof(gradient), scen::Scenario
grad_true = if x isa Number
ForwardDiff.derivative(f, x)
else
only(Zygote.gradient(f, x))
try
ForwardDiff.gradient(f, x)
catch e
only(Zygote.gradient(f, x))
end
end

y_out1, grad_out1 = value_and_gradient(f, ba, x)
Expand Down Expand Up @@ -223,10 +212,6 @@ function test_correctness(ba::AbstractADType, ::typeof(jacobian), scen::Scenario
@test myisapprox(jac_out2, jac_true; rtol=1e-3)
@test myisapprox(jac_out3, jac_true; rtol=1e-3)
@test myisapprox(jac_out4, jac_true; rtol=1e-3)
@testset "Mutation" begin
@test myisapprox(jac_in2, jac_true; rtol=1e-3)
@test myisapprox(jac_in4, jac_true; rtol=1e-3)
end
end
return test_scen_intact(new_scen, scen)
end
Expand All @@ -242,15 +227,9 @@ function test_correctness(ba::AbstractADType, ::typeof(jacobian), scen::Scenario

@testset "Primal value" begin
@test myisapprox(y_out, y)
@testset "Mutation" begin
@test myisapprox(y_in, y)
end
end
@testset "Jacobian value" begin
@test myisapprox(jac_out, jac_true; rtol=1e-3)
@testset "Mutation" begin
@test myisapprox(jac_in, jac_true; rtol=1e-3)
end
end
return test_scen_intact(new_scen, scen)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ DI.supports_mutation(::AutoDiffractor) = DI.MutationNotSupported()
DI.mode(::AutoDiffractor) = ADTypes.AbstractForwardMode
DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode

function DI.value_and_pushforward(f::F, ::AutoDiffractor, x, dx, extras::Nothing) where {F}
function DI.value_and_pushforward(f, ::AutoDiffractor, x, dx, extras::Nothing)
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
y, dy = vpff((dx,))
return y, dy
Expand Down
4 changes: 1 addition & 3 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
## Pushforward

function DI.value_and_pushforward(
f::F, backend::AutoForwardEnzyme, x, dx, extras::Nothing
) where {F}
function DI.value_and_pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
dx_sametype = convert(typeof(x), copy(dx))
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
return y, new_dy
Expand Down
4 changes: 2 additions & 2 deletions ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
## Pushforward

function DI.value_and_pushforward!!(
f!::F, y, dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
) where {F}
f!, y, dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
)
dx_sametype = convert(typeof(x), copy(dx))
dy_sametype = convert(typeof(y), dy)
autodiff(
Expand Down
22 changes: 8 additions & 14 deletions ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,40 @@
## Pullback

function DI.value_and_pullback!!(
f::F, _dx, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
) where {F}
f, _dx, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end

function DI.value_and_pullback!!(
f::F, dx, ::AutoReverseEnzyme, x, dy::Number, extras::Nothing
) where {F}
function DI.value_and_pullback!!(f, dx, ::AutoReverseEnzyme, x, dy::Number, extras::Nothing)
dx_sametype = convert(typeof(x), dx)
dx_sametype = myzero!!(dx_sametype)
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
dx_sametype = mymul!!(dx_sametype, dy)
return y, myupdate!!(dx, dx_sametype)
end

function DI.value_and_pullback(
f::F, backend::AutoReverseEnzyme, x, dy::Number, extras
) where {F}
function DI.value_and_pullback(f, backend::AutoReverseEnzyme, x, dy::Number, extras)
dx = mysimilar(x)
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
end

## Gradient

function DI.gradient(f::F, backend::AutoReverseEnzyme, x, extras::Nothing) where {F}
function DI.gradient(f, backend::AutoReverseEnzyme, x, extras::Nothing)
return gradient(Reverse, f, x)
end

function DI.gradient!!(f::F, grad, backend::AutoReverseEnzyme, x, extras::Nothing) where {F}
function DI.gradient!!(f, grad, backend::AutoReverseEnzyme, x, extras::Nothing)
return gradient!(Reverse, grad, f, x)
end

function DI.gradient(f::F, backend::AutoReverseEnzyme, x::Number, extras::Nothing) where {F}
function DI.gradient(f, backend::AutoReverseEnzyme, x::Number, extras::Nothing)
return autodiff(Reverse, f, Active(x))
end

function DI.gradient!!(
f::F, grad, backend::AutoReverseEnzyme, x::Number, extras::Nothing
) where {F}
function DI.gradient!!(f, grad, backend::AutoReverseEnzyme, x::Number, extras::Nothing)
return autodiff(Reverse, f, Active(x))
end
8 changes: 3 additions & 5 deletions ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
## Pullback

function DI.value_and_pullback!!(
f!::F, y, _dx, ::AutoReverseEnzyme, x::Number, dy, extras::Nothing
) where {F}
f!, y, _dx, ::AutoReverseEnzyme, x::Number, dy, extras::Nothing
)
dy_sametype = convert(typeof(y), copy(dy))
_, new_dx = only(autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Active(x)))
return y, new_dx
end

function DI.value_and_pullback!!(
f!::F, y, dx, ::AutoReverseEnzyme, x, dy, extras::Nothing
) where {F}
function DI.value_and_pullback!!(f!, y, dx, ::AutoReverseEnzyme, x, dy, extras::Nothing)
dx_sametype = convert(typeof(x), dx)
dx_sametype = myzero!!(dx_sametype)
dy_sametype = convert(typeof(y), copy(dy))
Expand Down
14 changes: 6 additions & 8 deletions ext/DifferentiationInterfaceFastDifferentiationExt/allocating.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Pushforward

function DI.prepare_pushforward(f::F, ::AutoFastDifferentiation, x) where {F}
function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x)
x_var = if x isa Number
only(make_variables(:x))
else
Expand All @@ -16,8 +16,8 @@ function DI.prepare_pushforward(f::F, ::AutoFastDifferentiation, x) where {F}
end

function DI.value_and_pushforward(
f::F, ::AutoFastDifferentiation, x, dx, jvp_exe::RuntimeGeneratedFunction
) where {F}
f, ::AutoFastDifferentiation, x, dx, jvp_exe::RuntimeGeneratedFunction
)
y = f(x)
v_vec = vcat(myvec(x), myvec(dx))
jv_vec = jvp_exe(v_vec)
Expand All @@ -29,15 +29,13 @@ function DI.value_and_pushforward(
end

function DI.value_and_pushforward(
f::F, backend::AutoFastDifferentiation, x, dx, extras::Nothing
) where {F}
f, backend::AutoFastDifferentiation, x, dx, extras::Nothing
)
jvp_exe = DI.prepare_pushforward(f, backend, x)
return DI.value_and_pushforward(f, backend, x, dx, jvp_exe)
end

function DI.value_and_pushforward!!(
f::F, dy, backend::AutoFastDifferentiation, x, dx, extras
) where {F}
function DI.value_and_pushforward!!(f, dy, backend::AutoFastDifferentiation, x, dx, extras)
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, myupdate!!(dy, new_dy)
end
12 changes: 6 additions & 6 deletions ext/DifferentiationInterfaceFiniteDiffExt/allocating.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
## Pushforward

function DI.value_and_pushforward!!(
f::F, _dy::Number, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {F,fdtype}
f, _dy::Number, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {fdtype}
y = f(x)
step(t::Number)::Number = f(x .+ t .* dx)
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
return y, new_dy
end

function DI.value_and_pushforward!!(
f::F, dy::AbstractArray, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {F,fdtype}
f, dy::AbstractArray, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {fdtype}
y = f(x)
step(t::Number)::AbstractArray = f(x .+ t .* dx)
finite_difference_gradient!(
Expand All @@ -21,8 +21,8 @@ function DI.value_and_pushforward!!(
end

function DI.value_and_pushforward(
f::F, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {F,fdtype}
f, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
) where {fdtype}
y = f(x)
step(t::Number) = f(x .+ t .* dx)
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
Expand Down
4 changes: 2 additions & 2 deletions ext/DifferentiationInterfaceFiniteDiffExt/mutating.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
## Pushforward

function DI.value_and_pushforward!!(
f!::F,
f!,
y::AbstractArray,
dy::AbstractArray,
::AutoFiniteDiff{fdtype},
x,
dx,
extras::Nothing,
) where {F,fdtype}
) where {fdtype}
function step(t::Number)::AbstractArray
new_y = similar(y)
f!(new_y, x .+ t .* dx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (ht
end

function DI.value_and_pushforward(
f::F, backend::AutoFiniteDifferences{fdm}, x, dx, extras::Nothing
) where {F,fdm}
f, backend::AutoFiniteDifferences{fdm}, x, dx, extras::Nothing
) where {fdm}
y = f(x)
return y, jvp(backend.fdm, f, (x, dx))
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DifferentiationInterfaceForwardDiffExt

using ADTypes: AbstractADType, AutoForwardDiff
import DifferentiationInterface as DI
using DiffResults: DiffResults
using DiffResults: DiffResults, DiffResult, GradientResult
using ForwardDiff:
Chunk,
Dual,
Expand Down
66 changes: 65 additions & 1 deletion ext/DifferentiationInterfaceForwardDiffExt/allocating.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,72 @@
function DI.value_and_pushforward(f::F, ::AutoForwardDiff, x, dx, extras::Nothing) where {F}
## Pushforward

function DI.value_and_pushforward(f, ::AutoForwardDiff, x, dx, extras::Nothing)
T = tag_type(f, x)
xdual = make_dual(T, x, dx)
ydual = f(xdual)
y = my_value(T, ydual)
new_dy = my_derivative(T, ydual)
return y, new_dy
end

## Gradient

function DI.prepare_gradient(f, backend::AutoForwardDiff, x::AbstractArray)
return GradientConfig(f, x, choose_chunk(backend, x))
end

function DI.value_and_gradient!!(
f, grad::AbstractArray, ::AutoForwardDiff, x::AbstractArray, config::GradientConfig
)
result = DiffResult(zero(eltype(x)), grad)
result = gradient!(result, f, x, config)
return DiffResults.value(result), DiffResults.gradient(result)
end

function DI.value_and_gradient(
f, backend::AutoForwardDiff, x::AbstractArray, config::GradientConfig
)
grad = similar(x)
return DI.value_and_gradient!!(f, grad, backend, x, config)
end

function DI.gradient!!(
f, grad::AbstractArray, ::AutoForwardDiff, x::AbstractArray, config::GradientConfig
)
return gradient!(grad, f, x, config)
end

function DI.gradient(f, ::AutoForwardDiff, x::AbstractArray, config::GradientConfig)
return gradient(f, x, config)
end

## Jacobian

function DI.prepare_jacobian(f, backend::AutoForwardDiff, x::AbstractArray)
return JacobianConfig(f, x, choose_chunk(backend, x))
end

function DI.value_and_jacobian!!(
f, jac::AbstractMatrix, ::AutoForwardDiff, x::AbstractArray, config::JacobianConfig
)
y = f(x)
result = DiffResult(y, jac)
result = jacobian!(result, f, x, config)
return DiffResults.value(result), DiffResults.jacobian(result)
end

function DI.value_and_jacobian(
f, ::AutoForwardDiff, x::AbstractArray, config::JacobianConfig
)
return f(x), jacobian(f, x, config)
end

function DI.jacobian!!(
f, jac::AbstractMatrix, ::AutoForwardDiff, x::AbstractArray, config::JacobianConfig
)
return jacobian!(jac, f, x, config)
end

function DI.jacobian(f, ::AutoForwardDiff, x::AbstractArray, config::JacobianConfig)
return jacobian(f, x, config)
end
Loading

0 comments on commit 36d9c47

Please sign in to comment.