-
-
Notifications
You must be signed in to change notification settings - Fork 399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DNMY: Enzyme extension #3712
DNMY: Enzyme extension #3712
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #3712 +/- ##
==========================================
- Coverage 98.37% 97.62% -0.75%
==========================================
Files 43 44 +1
Lines 5736 5780 +44
==========================================
Hits 5643 5643
- Misses 93 137 +44 ☔ View full report in Codecov by Sentry. |
Okay, this needs some discussion, likely at a monthly developer call. Let's also put aside the exact syntax. Instead of pirating a method like this, we'd need to add some sort of type or flag for people to opt-in, but that is a small issue that can be resolved. I am somewhat in favor of this, but given the experience of #3707, I think we should be very careful about adding this. Particularly relevant is this discussion: #3413 (comment) I would be strongly in favor of making a requirement that new extensions must have a 1.0 release, and have no plans for a 2.0 release. This would rule out anything that has moved from v1.0.0 to v5.67.2 in a short time period, and it would rule out Enzyme, which is on v0.11 Another option is that we add a page to the documentation which shows how to construct the appropriate gradient and hessian oracles, but we don't add this to JuMP, either directly or as an extension. It's also worth evaluating the cost on compilation times for the tests and documentation if we add this. Enzyme is pretty heavy. |
Also, using your code I get: julia> using Enzyme
julia> function jump_operator(f::Function)
@inline function f!(y, x...)
y[1] = f(x...)
end
function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
y = zeros(T,1)
ry = ones(T,1)
rx = ntuple(N) do i
Active(x[i])
end
g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
return nothing
end
function gradient_deferred!(g, y, ry, rx...)
g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
return nothing
end
function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
y = zeros(T,1)
dy = ntuple(N) do i
ones(1)
end
g = zeros(T,N)
dg = ntuple(N) do i
zeros(T,N)
end
ry = ones(1)
dry = ntuple(N) do i
zeros(T,1)
end
rx = ntuple(N) do i
Active(x[i])
end
args = ntuple(N) do i
drx = ntuple(N) do j
if i == j
Active(one(T))
else
Active(zero(T))
end
end
BatchDuplicated(rx[i], drx)
end
autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...)
for i in 1:N
for j in 1:N
if i <= j
H[j,i] = dg[j][i]
end
end
end
return nothing
end
return gradient!, hessian!
end
jump_operator (generic function with 1 method)
julia> foo(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
foo (generic function with 1 method)
julia> ∇foo, ∇²foo = jump_operator(foo)
(var"#gradient!#9"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo)), var"#hessian!#12"{var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}}(var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo))))
julia> N = 3
3
julia> x = rand(N)
3-element Vector{Float64}:
0.23712902725864782
0.6699243680780806
0.530669076854107
julia> g = zeros(N)
3-element Vector{Float64}:
0.0
0.0
0.0
julia> H = zeros(N, N)
3×3 Matrix{Float64}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
julia> foo(x...)
1.593666919840647
julia> ∇foo(g, x...)
julia> ∇²foo(H, x...)
ERROR: Attempting to call an indirect active function whose runtime value is inactive:
Backtrace
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378
[2] enzyme_call
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
[3] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009
[4] runtime_generic_augfwd
@ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179
[5] runtime_generic_augfwd
@ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
[2] enzyme_call
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056 [inlined]
[3] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009 [inlined]
[4] runtime_generic_augfwd
@ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179 [inlined]
[5] runtime_generic_augfwd
@ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0 [inlined]
[6] fwddiffe3julia_runtime_generic_augfwd_3727_inner_1wrap
@ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
[7] macro expansion
@ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
[8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
[9] (::Enzyme.Compiler.ForwardModeThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5001
[10] runtime_generic_fwd(activity::Type{…}, width::Val{…}, RT::Val{…}, f::typeof(Enzyme.Compiler.runtime_generic_augfwd), df::Nothing, df_2::Nothing, df_3::Nothing, primal_1::Type{…}, shadow_1_1::Nothing, shadow_1_2::Nothing, shadow_1_3::Nothing, primal_2::Val{…}, shadow_2_1::Nothing, shadow_2_2::Nothing, shadow_2_3::Nothing, primal_3::Val{…}, shadow_3_1::Nothing, shadow_3_2::Nothing, shadow_3_3::Nothing, primal_4::Val{…}, shadow_4_1::Nothing, shadow_4_2::Nothing, shadow_4_3::Nothing, primal_5::typeof(foo), shadow_5_1::Nothing, shadow_5_2::Nothing, shadow_5_3::Nothing, primal_6::Nothing, shadow_6_1::Nothing, shadow_6_2::Nothing, shadow_6_3::Nothing, primal_7::Float64, shadow_7_1::Float64, shadow_7_2::Float64, shadow_7_3::Float64, primal_8::Base.RefValue{…}, shadow_8_1::Base.RefValue{…}, shadow_8_2::Base.RefValue{…}, shadow_8_3::Base.RefValue{…}, primal_9::Float64, shadow_9_1::Float64, shadow_9_2::Float64, shadow_9_3::Float64, primal_10::Base.RefValue{…}, shadow_10_1::Base.RefValue{…}, shadow_10_2::Base.RefValue{…}, shadow_10_3::Base.RefValue{…}, primal_11::Float64, shadow_11_1::Float64, shadow_11_2::Float64, shadow_11_3::Float64, primal_12::Base.RefValue{…}, shadow_12_1::Base.RefValue{…}, shadow_12_2::Base.RefValue{…}, shadow_12_3::Base.RefValue{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:116
[11] f! |
Here's some code I had when experimenting with this: abstract type AbstractADOperator end
#=
Enzyme
=#
import Enzyme
struct ADOperatorEnzyme <: AbstractADOperator end
function create_operator(f::Function, ::ADOperatorEnzyme)
@inline f!(y, x::Vararg{T,N}) where {T,N} = (y[1] = f(x...))
function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
g .= Enzyme.autodiff(
Enzyme.Reverse,
f!,
Enzyme.Const,
Enzyme.Duplicated(zeros(T, 1), ones(T, 1)),
Enzyme.Active.(x)...,
)[1][2:end]
return
end
function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
dg = ntuple(_ -> zeros(T, N), N)
args = ntuple(N) do i
return Enzyme.BatchDuplicated(
Enzyme.Active(x[i]),
ntuple(j -> Enzyme.Active(T(i == j)), N),
)
end
function gradient_deferred!(g, y, ry, rx...)
g .= Enzyme.autodiff_deferred(
Enzyme.Reverse,
f!,
Enzyme.Const,
Enzyme.Duplicated(y, ry),
rx...,
)[1][2:end]
return
end
Enzyme.autodiff(
Enzyme.Forward,
gradient_deferred!,
Enzyme.Const,
Enzyme.BatchDuplicated(zeros(T, N), dg),
Enzyme.BatchDuplicated(zeros(T, 1), ntuple(_ -> ones(T, 1), N)),
Enzyme.BatchDuplicated(ones(T, 1), ntuple(_ -> zeros(T, 1), N)),
args...,
)
for j in 1:N, i in 1:j
H[j, i] = dg[j][i]
end
return
end
return f, ∇f!, ∇²f!
end
#=
ForwardDiff
=#
import ForwardDiff
struct ADOperatorForwardDiff <: AbstractADOperator end
function create_operator(f::Function, ::ADOperatorForwardDiff)
function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
ForwardDiff.gradient!(g, y -> f(y...), collect(x))
return
end
function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
h = ForwardDiff.hessian(y -> f(y...), collect(x))
for i in 1:N, j in 1:i
H[i, j] = h[i, j]
end
return
end
return f, ∇f!, ∇²f!
end
#=
Examples
=#
import LinearAlgebra
using Test
function example_logsumexp()
f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
function ∇²f(H, x...)
y = collect(x)
g = exp.(y) / sum(exp.(y))
h = LinearAlgebra.diagm(g) - g * g'
for i in 1:length(y), j in 1:i
H[i, j] = h[i, j]
end
return
end
return f, ∇f, ∇²f
end
function example_rosenbrock()
f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
function ∇f(g, x, y)
g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
g[2] = 200 * (y - x^2)
return
end
function ∇²f(H, x, y)
H[1, 1] = 2 + 1200 * x^2 - 400 * y
H[2, 1] = -400 * x
H[2, 2] = 200
return
end
return f, ∇f, ∇²f
end
function test_example(example, N, config::AbstractADOperator)
true_f, true_∇f, true_∇²f = example()
f, ∇f, ∇²f = create_operator(true_f, config)
x = rand(N)
y = f(x...)
true_y = true_f(x...)
@test isapprox(y, true_y)
g, true_g = zeros(N), zeros(N)
∇f(g, x...)
true_∇f(true_g, x...)
@test isapprox(g, true_g)
H, true_H = zeros(N, N), zeros(N, N)
∇²f(H, x...)
true_∇²f(true_H, x...)
@test isapprox(H, true_H)
return
end
@testset "Examples" begin
for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
for (example, N) in (
example_rosenbrock => 2,
example_logsumexp => 3,
example_logsumexp => 20,
)
@testset "$example - $N - $config" begin
test_example(example, N, config)
end
end
end
end Running yields Examples | 16 2 18 11.3s
example_rosenbrock - 2 - ADOperatorForwardDiff() | 3 3 1.0s
example_logsumexp - 3 - ADOperatorForwardDiff() | 3 3 1.1s
example_logsumexp - 20 - ADOperatorForwardDiff() | 3 3 1.2s
example_rosenbrock - 2 - ADOperatorEnzyme() | 3 3 0.4s
example_logsumexp - 3 - ADOperatorEnzyme() | 2 1 3 1.0s
example_logsumexp - 20 - ADOperatorEnzyme() | 2 1 3 6.6s
ERROR: LoadError: Some tests did not pass: 16 passed, 0 failed, 2 errored, 0 broken. |
Okay, I've tightened things up considerably, and got rid of the Hessian error: abstract type AbstractADOperator end
#=
Enzyme
=#
import Enzyme
struct ADOperatorEnzyme <: AbstractADOperator end
function create_operator(f::Function, ::ADOperatorEnzyme)
function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
return
end
function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
hess = Enzyme.autodiff(
Enzyme.Forward,
(x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
)[1]
for j in 1:N, i in 1:j
H[j, i] = hess[j][i]
end
return
end
return f, ∇f!, ∇²f!
end
#=
ForwardDiff
=#
import ForwardDiff
struct ADOperatorForwardDiff <: AbstractADOperator end
function create_operator(f::Function, ::ADOperatorForwardDiff)
function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
ForwardDiff.gradient!(g, y -> f(y...), collect(x))
return
end
function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
h = ForwardDiff.hessian(y -> f(y...), collect(x))
for i in 1:N, j in 1:i
H[i, j] = h[i, j]
end
return
end
return f, ∇f!, ∇²f!
end
#=
Examples
=#
import LinearAlgebra
using Test
function example_logsumexp()
f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
function ∇²f(H, x...)
y = collect(x)
g = exp.(y) / sum(exp.(y))
h = LinearAlgebra.diagm(g) - g * g'
for i in 1:length(y), j in 1:i
H[i, j] = h[i, j]
end
return
end
return f, ∇f, ∇²f
end
function example_rosenbrock()
f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
function ∇f(g, x, y)
g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
g[2] = 200 * (y - x^2)
return
end
function ∇²f(H, x, y)
H[1, 1] = 2 + 1200 * x^2 - 400 * y
H[2, 1] = -400 * x
H[2, 2] = 200
return
end
return f, ∇f, ∇²f
end
function test_example(example, N, config::AbstractADOperator)
true_f, true_∇f, true_∇²f = example()
f, ∇f, ∇²f = create_operator(true_f, config)
x = rand(N)
y = f(x...)
true_y = true_f(x...)
@test isapprox(y, true_y)
g, true_g = zeros(N), zeros(N)
∇f(g, x...)
true_∇f(true_g, x...)
@test isapprox(g, true_g)
H, true_H = zeros(N, N), zeros(N, N)
∇²f(H, x...)
true_∇²f(true_H, x...)
@test isapprox(H, true_H)
return
end
@testset "Examples" begin
for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
for (example, N) in (
example_rosenbrock => 2,
example_logsumexp => 3,
example_logsumexp => 20,
)
@testset "$example - $N - $config" begin
test_example(example, N, config)
end
end
end
end |
Okay, so since this is 20 lines of code, I think this might better as a tutorial in the documentation. @blegat has asked for this before: #2348 (comment) It'll also let us show off Enzyme and ForwardDiff. I'll take a stab, and then we can discuss the relative merits of having the code as a JuMP extension vs asking people to copy-paste a snippet. |
Developer call says that the documentation https://jump.dev/JuMP.jl/dev/tutorials/nonlinear/operator_ad/ is sufficient. |
I made this PR initially to Enzyme EnzymeAD/Enzyme.jl#1337 , but @wsmoses recommended to make it an extension of JuMP. Let me know if this works and I can add this as a test.
This extends JuMP and allows a user in JuMP to differentiate an external function using Enzyme.
Use case: