Skip to content

Commit

Permalink
Add ReverseDiff, FiniteDiff and ChainRules backends (#9)
Browse files Browse the repository at this point in the history
* Add ReverseDiff and ChainRules backends

* Add Diffractor and FiniteDiff

* Explanation
  • Loading branch information
gdalle authored Jan 19, 2024
1 parent 9b3713b commit 69ea54c
Show file tree
Hide file tree
Showing 18 changed files with 425 additions and 160 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"

[compat]
ChainRulesCore = "1.19"
DocStringExtensions = "0.9"
FiniteDiff = "2.22"
Enzyme = "0.11"
ForwardDiff = "0.10"
LinearAlgebra = "1"
ReverseDiff = "1.15"
julia = "1.10"

[extras]
Expand Down
15 changes: 14 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ CurrentModule = DifferentiationInterface

Documentation for [DifferentiationInterface](https://github.com/gdalle/DifferentiationInterface.jl).

This is an interface to various autodiff backends for differentiating functions of the form `f(x) = y`, where `x` and `y` are either numbers or arrays.

## Public

```@autodocs
Expand All @@ -20,7 +22,18 @@ Modules = [DifferentiationInterface]
Public = false
```

## Math

Some implementation reminders:

| | pushforward | pullback |
| ---------------- | ---------------------------------------------------------------- | ------------------------------------------------------------------ |
| scalar -> scalar | derivative multiplied by the tangent | derivative multiplied by the cotangent |
| scalar -> vector | derivative vector multiplied componentwise by the tangent vector | dot product between the derivative vector and the cotangent vector |
| vector -> scalar | dot product between the gradient vector and the tangent vector | gradient vector multiplied componentwise by the cotangent |
| vector -> vector | Jacobian matrix multiplied by the tangent vector | transposed Jacobian matrix multiplied by the cotangent vector |

## Index

```@index
```
```
38 changes: 38 additions & 0 deletions ext/DifferentiationInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,42 @@ using ChainRulesCore
using DifferentiationInterface
using LinearAlgebra

ruleconfig(backend::ChainRulesBackend) = backend.ruleconfig

function DifferentiationInterface.pushforward!(
dy::Y, backend::ChainRulesBackend{<:RuleConfig{>:HasForwardsMode}}, f, x::X, dx::X
) where {X,Y<:Number}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return new_dy
end

function DifferentiationInterface.pushforward!(
dy::Y, backend::ChainRulesBackend{<:RuleConfig{>:HasForwardsMode}}, f, x::X, dx::X
) where {X,Y<:AbstractArray}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
dy .= new_dy
return dy
end

function DifferentiationInterface.pullback!(
dx::X, backend::ChainRulesBackend{<:RuleConfig{>:HasReverseMode}}, f, x::X, dy::Y
) where {X<:Number,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
return new_dx
end

function DifferentiationInterface.pullback!(
dx::X, backend::ChainRulesBackend{<:RuleConfig{>:HasReverseMode}}, f, x::X, dy::Y
) where {X<:AbstractArray,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
dx .= new_dx
return dx
end

end
13 changes: 8 additions & 5 deletions ext/DifferentiationInterfaceEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ using Enzyme
"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.jvp!(dy::Y, ::EnzymeBackend, f, x::X, dx::X) where {X,Y}
function DifferentiationInterface.pushforward!(
dy::Y, ::EnzymeBackend, f, x::X, dx::X
) where {X,Y}
return only(autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, dx)))
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.vjp!(
function DifferentiationInterface.pullback!(
dx::X, ::EnzymeBackend, f, x::X, dy::Y
) where {X<:Number,Y<:Union{Real,Nothing}}
return only(first(autodiff(Reverse, f, Active, Active(x)))) * dy
Expand All @@ -23,11 +25,12 @@ end
"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.vjp!(
function DifferentiationInterface.pullback!(
dx::X, ::EnzymeBackend, f, x::X, dy::Y
) where {X,Y<:Union{Real,Nothing}}
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
dx .*= dy # TODO: doesn't work with arbitrary dx
dx .*= dy
return dx
end

Expand Down
97 changes: 97 additions & 0 deletions ext/DifferentiationInterfaceFiniteDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
module DifferentiationInterfaceFiniteDiffExt

using DifferentiationInterface
using DocStringExtensions
using FiniteDiff
using LinearAlgebra

## Pushforward

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:Number}
new_dy = FiniteDiff.finite_difference_derivative(f, x) * dx
return new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:AbstractArray}
new_dy = FiniteDiff.finite_difference_derivative(f, x)
dy .= new_dy .* dx
return dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Number}
g = FiniteDiff.finite_difference_gradient(f, x)
new_dy = dot(g, dx)
return new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = FiniteDiff.finite_difference_jacobian(f, x)
mul!(dy, J, dx)
return dy
end

## Pullback

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::FiniteDiffBackend, f, x::X, dy::Y
) where {X<:Real,Y<:Real}
new_dx = dy * FiniteDiff.finite_difference_derivative(f, x)
return new_dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::FiniteDiffBackend, f, x::X, dy::Y
) where {X<:Real,Y<:AbstractArray}
new_dx = dot(dy, FiniteDiff.finite_difference_derivative(f, x))
return new_dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::FiniteDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
g = FiniteDiff.finite_difference_gradient(f, x)
dx .= g .* dy
return dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::FiniteDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
J = FiniteDiff.finite_difference_jacobian(f, x)
mul!(dx, transpose(J), dy)
return dx
end

end
74 changes: 6 additions & 68 deletions ext/DifferentiationInterfaceForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@ using DocStringExtensions
using ForwardDiff
using LinearAlgebra

## JVP

"""
$(TYPEDSIGNATURES)
JVP for a scalar -> scalar function: derivative multiplied by the tangent.
"""
function DifferentiationInterface.jvp!(
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:Real}
new_dy = ForwardDiff.derivative(f, x) * dx
Expand All @@ -21,10 +17,8 @@ end

"""
$(TYPEDSIGNATURES)
JVP for a scalar -> vector function: derivative vector multiplied componentwise by the tangent.
"""
function DifferentiationInterface.jvp!(
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:AbstractArray}
ForwardDiff.derivative!(dy, f, x)
Expand All @@ -34,80 +28,24 @@ end

"""
$(TYPEDSIGNATURES)
JVP for a vector -> scalar function: dot product between the gradient vector and the tangent vector.
"""
function DifferentiationInterface.jvp!(
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Real}
g = ForwardDiff.gradient(f, x) # TODO: allocates
g = ForwardDiff.gradient(f, x) # TODO: replace with duals, n times too slow
new_dy = dot(g, dx)
return new_dy
end

"""
$(TYPEDSIGNATURES)
JVP for a vector -> vector function: Jacobian matrix multiplied by the tangent vector.
"""
function DifferentiationInterface.jvp!(
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = ForwardDiff.jacobian(f, x) # TODO: allocates
J = ForwardDiff.jacobian(f, x) # TODO: replace with duals, n times too slow
mul!(dy, J, dx)
return dy
end

## VJP

"""
$(TYPEDSIGNATURES)
VJP for a scalar -> scalar function: derivative multiplied by the cotangent.
"""
function DifferentiationInterface.vjp!(
dx::X, ::ForwardDiffBackend, f, x::X, dy::Y
) where {X<:Real,Y<:Real}
new_dx = dy * ForwardDiff.derivative(f, x)
return new_dx
end

"""
$(TYPEDSIGNATURES)
VJP for a scalar -> vector function: dot product between the derivative vector and the cotangent vector.
"""
function DifferentiationInterface.vjp!(
dx::X, ::ForwardDiffBackend, f, x::X, dy::Y
) where {X<:Real,Y<:AbstractArray}
new_dx = dot(dy, ForwardDiff.derivative(f, x)) # TODO: allocates
return new_dx
end

"""
$(TYPEDSIGNATURES)
VJP for a vector -> scalar function: gradient vector multiplied componentwise by the cotangent.
"""
function DifferentiationInterface.vjp!(
dx::X, ::ForwardDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
ForwardDiff.gradient!(dx, f, x)
dx .*= dy
return dx
end

"""
$(TYPEDSIGNATURES)
VJP for a vector -> vector function: transposed Jacobian matrix multiplied by the cotangent vector.
"""
function DifferentiationInterface.vjp!(
dx::X, ::ForwardDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
J = ForwardDiff.jacobian(f, x) # TODO: allocates
mul!(dx, transpose(J), dy)
return dx
end

end
30 changes: 30 additions & 0 deletions ext/DifferentiationInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module DifferentiationInterfaceReverseDiffExt

using DifferentiationInterface
using DocStringExtensions
using ReverseDiff
using LinearAlgebra

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
ReverseDiff.gradient!(dx, f, x)
dx .*= dy
return dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
J = ReverseDiff.jacobian(f, x)
mul!(dx, transpose(J), dy)
return dx
end

end
Loading

0 comments on commit 69ea54c

Please sign in to comment.