[docs] add DifferentiationInterface to autodiff tutorial (#3836)
gdalle authored Oct 7, 2024
1 parent b94a811 commit 1aa0ff7
Showing 2 changed files with 111 additions and 6 deletions.
4 changes: 3 additions & 1 deletion docs/Project.toml
@@ -3,6 +3,7 @@ CDDLib = "3391f64e-dcde-5f30-b752-e11513730f60"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Clarabel = "61c947e1-3e6d-4ee4-985a-eec8c727bd6e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -43,12 +44,13 @@ CDDLib = "=0.9.4"
CSV = "0.10"
Clarabel = "=0.9.0"
DataFrames = "1"
DifferentiationInterface = "0.6.5"
DimensionalData = "0.27.3"
Distributions = "0.25"
Documenter = "=1.6.0"
DocumenterCitations = "1"
Dualization = "0.5"
Enzyme = "0.12.14"
Enzyme = "0.13.7"
ForwardDiff = "0.10"
GLPK = "=1.2.1"
HTTP = "1.5.4"
113 changes: 108 additions & 5 deletions docs/src/tutorials/nonlinear/operator_ad.jl
@@ -34,6 +34,7 @@
# This tutorial uses the following packages:

using JuMP
+import DifferentiationInterface
import Enzyme
import ForwardDiff
import Ipopt
@@ -248,18 +249,18 @@ Test.@test ≈(analytic_g, enzyme_g)
# differentiation.

# The code to implement the Hessian in Enzyme is complicated, so we will not
-# explain it in detail; see the [Enzyme documentation](https://enzymead.github.io/Enzyme.jl/v0.11.20/generated/autodiff/#Vector-forward-over-reverse).
+# explain it in detail; see the [Enzyme documentation](https://enzymead.github.io/Enzyme.jl/stable/generated/autodiff/#Vector-forward-over-reverse).

function enzyme_∇²f(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
    ## direction(i) returns a tuple with a `1` in the `i`'th entry and `0`
    ## otherwise
    direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
    ## As the inner function, compute the gradient using Reverse mode
-   f_deferred(x...) = Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1]
+   ∇f(x...) = Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, x...)[1]
    ## For the outer autodiff, use Forward mode.
    hess = Enzyme.autodiff(
        Enzyme.Forward,
-       f_deferred,
+       ∇f,
        ## Compute multiple evaluations of Forward mode, each time using `x` but
        ## initializing with a different direction.
        Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
@@ -296,10 +297,10 @@ function enzyme_derivatives(f::Function)
    end
    function ∇²f(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
-       f_deferred(x...) = Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1]
+       ∇f(x...) = Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, x...)[1]
        hess = Enzyme.autodiff(
            Enzyme.Forward,
-           f_deferred,
+           ∇f,
            Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
        )[1]
        for j in 1:N, i in 1:j
@@ -324,3 +325,105 @@ function enzyme_rosenbrock()
end

enzyme_rosenbrock()

# ## DifferentiationInterface

# Julia offers [many different autodiff packages](https://juliadiff.org/).
# [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)
# is a package that provides a single abstraction layer over several
# underlying autodiff libraries.

# All the necessary information about your choice of underlying autodiff
# package is encoded in a "backend object" like this one:

DifferentiationInterface.AutoForwardDiff()

# This type comes from another package called [ADTypes.jl](https://github.com/SciML/ADTypes.jl),
# but DifferentiationInterface re-exports it. Other options include
# `AutoZygote()` and `AutoFiniteDiff()`.
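
# As a minimal sketch of how a backend object is used on its own, here is a
# one-off gradient computation; the function `h` below is a toy example
# introduced purely for illustration:

h(y) = sum(abs2, y)
## The gradient of `h` at `[1.0, 2.0]` is `2 .* [1.0, 2.0] == [2.0, 4.0]`:
DifferentiationInterface.gradient(h, DifferentiationInterface.AutoForwardDiff(), [1.0, 2.0])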

# ### Gradient

# Apart from the extra `backend` keyword argument, the syntax is very similar
# to the ForwardDiff version earlier in this tutorial:

function di_∇f(
    g::AbstractVector{T},
    x::Vararg{T,N};
    backend = DifferentiationInterface.AutoForwardDiff(),
) where {T,N}
    DifferentiationInterface.gradient!(splat(f), g, backend, collect(x))
    return
end

# Let's check that we find the analytic solution:

di_g = zeros(2)
di_∇f(di_g, x...)
Test.@test ≈(analytic_g, di_g)
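
# Because the backend is a keyword argument, swapping the underlying autodiff
# engine is a one-line change. As a sketch (assuming the `AutoEnzyme` backend
# supports this function; `di_g_enzyme` is a name introduced for illustration):

di_g_enzyme = zeros(2)
di_∇f(di_g_enzyme, x...; backend = DifferentiationInterface.AutoEnzyme())
Test.@test ≈(analytic_g, di_g_enzyme)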

# ### Hessian

# The Hessian follows exactly the same logic, except we need only the lower
# triangle.

function di_∇²f(
    H::AbstractMatrix{T},
    x::Vararg{T,N};
    backend = DifferentiationInterface.AutoForwardDiff(),
) where {T,N}
    H_dense = DifferentiationInterface.hessian(splat(f), backend, collect(x))
    for i in 1:N, j in 1:i
        H[i, j] = H_dense[i, j]
    end
    return
end

# Let's check that we find the analytic solution:

di_H = zeros(2, 2)
di_∇²f(di_H, x...)
Test.@test ≈(analytic_H, di_H)
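
# If the same Hessian is evaluated at many points, DifferentiationInterface
# can amortize setup costs through its preparation mechanism. A minimal
# sketch, assuming the input keeps the same size at every call (`di_backend`,
# `x0`, and `prep` are names introduced here for illustration):

di_backend = DifferentiationInterface.AutoForwardDiff()
x0 = collect(x)
prep = DifferentiationInterface.prepare_hessian(splat(f), di_backend, x0)
## Subsequent evaluations reuse `prep` instead of rebuilding internal caches:
DifferentiationInterface.hessian(splat(f), prep, di_backend, x0)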

# ### JuMP example

# The code for computing the gradient and Hessian using DifferentiationInterface
# can be re-used for many operators. Thus, it is helpful to encapsulate it in
# a function:

"""
di_derivatives(f::Function; backend) -> Tuple{Function,Function}
Return a tuple of functions that evaluate the gradient and Hessian of `f` using
DifferentiationInterface.jl with any given `backend`.
"""
function di_derivatives(f::Function; backend)
    function ∇f(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        DifferentiationInterface.gradient!(splat(f), g, backend, collect(x))
        return
    end
    function ∇²f(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        H_dense =
            DifferentiationInterface.hessian(splat(f), backend, collect(x))
        for i in 1:N, j in 1:i
            H[i, j] = H_dense[i, j]
        end
        return
    end
    return ∇f, ∇²f
end

# Here's an example using `di_derivatives`:

function di_rosenbrock(; backend)
    model = Model(Ipopt.Optimizer)
    set_silent(model)
    @variable(model, x[1:2])
    @operator(model, op_rosenbrock, 2, f, di_derivatives(f; backend)...)
    @objective(model, Min, op_rosenbrock(x[1], x[2]))
    optimize!(model)
    Test.@test is_solved_and_feasible(model)
    return value.(x)
end

di_rosenbrock(; backend = DifferentiationInterface.AutoForwardDiff())
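
# The same model can be solved with any other supported backend; for example,
# a sketch assuming that Enzyme works for this operator through
# DifferentiationInterface:

di_rosenbrock(; backend = DifferentiationInterface.AutoEnzyme())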
